mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename format to model_format(still named format when work with config)
This commit is contained in:
parent
aceadacad4
commit
e4dc9c5a04
@ -266,6 +266,8 @@ class ModelManager(object):
|
||||
for model_key, model_config in config.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
@ -617,6 +619,8 @@ class ModelManager(object):
|
||||
if model_class.save_to_config:
|
||||
# TODO: or exclude_unset better fits here?
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||
# alias for config file
|
||||
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
|
||||
|
||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||
config_file_path = conf_file or self.config_path
|
||||
|
@ -49,7 +49,7 @@ class ModelError(str, Enum):
|
||||
class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
description: Optional[str] = Field(None)
|
||||
format: Optional[str] = Field(None)
|
||||
model_format: Optional[str] = Field(None)
|
||||
# do not save to config
|
||||
error: Optional[ModelError] = Field(None)
|
||||
|
||||
@ -125,20 +125,20 @@ class ModelBase(metaclass=ABCMeta):
|
||||
continue
|
||||
|
||||
fields = inspect.get_annotations(value)
|
||||
if "format" not in fields:
|
||||
raise Exception("Invalid config definition - format field not found")
|
||||
if "model_format" not in fields:
|
||||
raise Exception("Invalid config definition - model_format field not found")
|
||||
|
||||
format_type = typing.get_origin(fields["format"])
|
||||
format_type = typing.get_origin(fields["model_format"])
|
||||
if format_type not in {None, Literal, Union}:
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
|
||||
|
||||
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["model_format"].__args__):
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
|
||||
|
||||
if format_type == Union:
|
||||
f_fields = fields["format"].__args__
|
||||
f_fields = fields["model_format"].__args__
|
||||
else:
|
||||
f_fields = (fields["format"],)
|
||||
f_fields = (fields["model_format"],)
|
||||
|
||||
|
||||
for field in f_fields:
|
||||
@ -155,17 +155,17 @@ class ModelBase(metaclass=ABCMeta):
|
||||
|
||||
@classmethod
|
||||
def create_config(cls, **kwargs) -> ModelConfigBase:
|
||||
if "format" not in kwargs:
|
||||
raise Exception("Field 'format' not found in model config")
|
||||
if "model_format" not in kwargs:
|
||||
raise Exception("Field 'model_format' not found in model config")
|
||||
|
||||
configs = cls._get_configs()
|
||||
return configs[kwargs["format"]](**kwargs)
|
||||
return configs[kwargs["model_format"]](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=cls.detect_format(path),
|
||||
model_format=cls.detect_format(path),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -19,7 +19,7 @@ class ControlNetModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.ControlNet
|
||||
|
@ -16,7 +16,7 @@ class LoRAModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||
model_format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
|
@ -23,12 +23,12 @@ from omegaconf import OmegaConf
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
model_format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
model_format: Literal["checkpoint"]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
@ -80,7 +80,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=model_format,
|
||||
model_format=model_format,
|
||||
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
@ -121,14 +121,14 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
model_format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
model_format: Literal["checkpoint"]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
@ -191,7 +191,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=model_format,
|
||||
model_format=model_format,
|
||||
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
|
@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: None
|
||||
model_format: None
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.TextualInversion
|
||||
|
@ -24,7 +24,7 @@ class VaeModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Vae
|
||||
|
Loading…
Reference in New Issue
Block a user