Merge branch 'main' into lstein/migrate-fix

This commit is contained in:
Kent Keirsey 2023-07-15 10:37:56 -04:00 committed by GitHub
commit 77b0129b4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 613 additions and 6 deletions

View File

@ -1,8 +1,7 @@
import os import os
import torch import torch
from enum import Enum from enum import Enum
from pathlib import Path from typing import Optional
from typing import Optional, Union, Literal
from .base import ( from .base import (
ModelBase, ModelBase,
ModelConfigBase, ModelConfigBase,
@ -14,6 +13,7 @@ from .base import (
calc_model_size_by_data, calc_model_size_by_data,
classproperty, classproperty,
InvalidModelException, InvalidModelException,
ModelNotFoundException,
) )
class ControlNetModelFormat(str, Enum): class ControlNetModelFormat(str, Enum):
@ -60,10 +60,20 @@ class ControlNetModel(ModelBase):
if child_type is not None: if child_type is not None:
raise Exception("There is no child models in controlnet model") raise Exception("There is no child models in controlnet model")
model = self.model_class.from_pretrained( model = None
self.model_path, for variant in ['fp16',None]:
torch_dtype=torch_dtype, try:
) model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except:
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size # calc more accurate size
self.model_size = calc_model_size_by_data(model) self.model_size = calc_model_size_by_data(model)
return model return model

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long