mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename format to model_format(still named format when work with config)
This commit is contained in:
parent
82b73c50a0
commit
4cefe37723
@ -266,6 +266,8 @@ class ModelManager(object):
|
|||||||
for model_key, model_config in config.items():
|
for model_key, model_config in config.items():
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
# alias for config file
|
||||||
|
model_config["model_format"] = model_config.pop("format")
|
||||||
self.models[model_key] = model_class.create_config(**model_config)
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
@ -617,6 +619,8 @@ class ModelManager(object):
|
|||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
# TODO: or exclude_unset better fits here?
|
# TODO: or exclude_unset better fits here?
|
||||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||||
|
# alias for config file
|
||||||
|
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
|
||||||
|
|
||||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
|
@ -49,7 +49,7 @@ class ModelError(str, Enum):
|
|||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
path: str # or Path
|
path: str # or Path
|
||||||
description: Optional[str] = Field(None)
|
description: Optional[str] = Field(None)
|
||||||
format: Optional[str] = Field(None)
|
model_format: Optional[str] = Field(None)
|
||||||
# do not save to config
|
# do not save to config
|
||||||
error: Optional[ModelError] = Field(None)
|
error: Optional[ModelError] = Field(None)
|
||||||
|
|
||||||
@ -125,20 +125,20 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
fields = inspect.get_annotations(value)
|
fields = inspect.get_annotations(value)
|
||||||
if "format" not in fields:
|
if "model_format" not in fields:
|
||||||
raise Exception("Invalid config definition - format field not found")
|
raise Exception("Invalid config definition - model_format field not found")
|
||||||
|
|
||||||
format_type = typing.get_origin(fields["format"])
|
format_type = typing.get_origin(fields["model_format"])
|
||||||
if format_type not in {None, Literal, Union}:
|
if format_type not in {None, Literal, Union}:
|
||||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
|
||||||
|
|
||||||
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
|
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["model_format"].__args__):
|
||||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
|
||||||
|
|
||||||
if format_type == Union:
|
if format_type == Union:
|
||||||
f_fields = fields["format"].__args__
|
f_fields = fields["model_format"].__args__
|
||||||
else:
|
else:
|
||||||
f_fields = (fields["format"],)
|
f_fields = (fields["model_format"],)
|
||||||
|
|
||||||
|
|
||||||
for field in f_fields:
|
for field in f_fields:
|
||||||
@ -155,17 +155,17 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_config(cls, **kwargs) -> ModelConfigBase:
|
def create_config(cls, **kwargs) -> ModelConfigBase:
|
||||||
if "format" not in kwargs:
|
if "model_format" not in kwargs:
|
||||||
raise Exception("Field 'format' not found in model config")
|
raise Exception("Field 'model_format' not found in model config")
|
||||||
|
|
||||||
configs = cls._get_configs()
|
configs = cls._get_configs()
|
||||||
return configs[kwargs["format"]](**kwargs)
|
return configs[kwargs["model_format"]](**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=cls.detect_format(path),
|
model_format=cls.detect_format(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -19,7 +19,7 @@ class ControlNetModel(ModelBase):
|
|||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.ControlNet
|
assert model_type == ModelType.ControlNet
|
||||||
|
@ -16,7 +16,7 @@ class LoRAModel(ModelBase):
|
|||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
model_format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Lora
|
assert model_type == ModelType.Lora
|
||||||
|
@ -23,12 +23,12 @@ from omegaconf import OmegaConf
|
|||||||
class StableDiffusion1Model(DiffusersModel):
|
class StableDiffusion1Model(DiffusersModel):
|
||||||
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
model_format: Literal["diffusers"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
model_format: Literal["checkpoint"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
@ -80,7 +80,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
|
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=model_format,
|
model_format=model_format,
|
||||||
|
|
||||||
config=ckpt_config_path,
|
config=ckpt_config_path,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
@ -121,14 +121,14 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
|
|
||||||
# TODO: check that configs overwriten properly
|
# TODO: check that configs overwriten properly
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
model_format: Literal["diffusers"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
upcast_attention: bool
|
upcast_attention: bool
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
model_format: Literal["checkpoint"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
@ -191,7 +191,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
|
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=model_format,
|
model_format=model_format,
|
||||||
|
|
||||||
config=ckpt_config_path,
|
config=ckpt_config_path,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
|
@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
|
|||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: None
|
model_format: None
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.TextualInversion
|
assert model_type == ModelType.TextualInversion
|
||||||
|
@ -24,7 +24,7 @@ class VaeModel(ModelBase):
|
|||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Vae
|
assert model_type == ModelType.Vae
|
||||||
|
Loading…
x
Reference in New Issue
Block a user