Rename format to model_format(still named format when work with config)

This commit is contained in:
Sergey Borisov 2023-06-20 03:25:08 +03:00 committed by psychedelicious
parent aceadacad4
commit e4dc9c5a04
7 changed files with 27 additions and 23 deletions

View File

@ -266,6 +266,8 @@ class ModelManager(object):
for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
@ -617,6 +619,8 @@ class ModelManager(object):
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
# alias for config file
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path

View File

@ -49,7 +49,7 @@ class ModelError(str, Enum):
class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
format: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
# do not save to config
error: Optional[ModelError] = Field(None)
@ -125,20 +125,20 @@ class ModelBase(metaclass=ABCMeta):
continue
fields = inspect.get_annotations(value)
if "format" not in fields:
raise Exception("Invalid config definition - format field not found")
if "model_format" not in fields:
raise Exception("Invalid config definition - model_format field not found")
format_type = typing.get_origin(fields["format"])
format_type = typing.get_origin(fields["model_format"])
if format_type not in {None, Literal, Union}:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["model_format"].__args__):
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
if format_type == Union:
f_fields = fields["format"].__args__
f_fields = fields["model_format"].__args__
else:
f_fields = (fields["format"],)
f_fields = (fields["model_format"],)
for field in f_fields:
@ -155,17 +155,17 @@ class ModelBase(metaclass=ABCMeta):
@classmethod
def create_config(cls, **kwargs) -> ModelConfigBase:
if "format" not in kwargs:
raise Exception("Field 'format' not found in model config")
if "model_format" not in kwargs:
raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs()
return configs[kwargs["format"]](**kwargs)
return configs[kwargs["model_format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
format=cls.detect_format(path),
model_format=cls.detect_format(path),
)
@classmethod

View File

@ -19,7 +19,7 @@ class ControlNetModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet

View File

@ -16,7 +16,7 @@ class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]]
model_format: Union[Literal["lycoris"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora

View File

@ -23,12 +23,12 @@ from omegaconf import OmegaConf
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
model_format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
model_format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@ -80,7 +80,7 @@ class StableDiffusion1Model(DiffusersModel):
return cls.create_config(
path=path,
format=model_format,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
@ -121,14 +121,14 @@ class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
model_format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
model_format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@ -191,7 +191,7 @@ class StableDiffusion2Model(DiffusersModel):
return cls.create_config(
path=path,
format=model_format,
model_format=model_format,
config=ckpt_config_path,
variant=variant,

View File

@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: None
model_format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion

View File

@ -24,7 +24,7 @@ class VaeModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae