From 76cbc745e1373952411ea1328ffc5d3ea6833126 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 1 Mar 2024 15:21:35 +1100 Subject: [PATCH] refactor(mm): add `CheckpointConfigBase` for all ckpt models --- .../model_install/model_install_default.py | 6 +++-- invokeai/backend/model_manager/config.py | 23 ++++++++----------- .../load/model_loaders/controlnet.py | 2 +- .../load/model_loaders/stable_diffusion.py | 4 ++-- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index b91f961099..5a4c765d82 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -24,6 +24,7 @@ from invokeai.app.util.misc import uuid_string from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, + CheckpointConfigBase, InvalidModelConfigException, ModelRepoVariant, ModelType, @@ -532,7 +533,8 @@ class ModelInstallService(ModelInstallServiceBase): ) -> str: # Note that we may be passed a pre-populated AnyModelConfig object, # in which case the key field should have been populated by the caller (e.g. in `install_path`). - config["key"] = config.get("key", uuid_string()) + if config is not None: + config["key"] = config.get("key", uuid_string()) info = info or ModelProbe.probe(model_path, config) override_key: Optional[str] = config.get("key") if config else None @@ -546,7 +548,7 @@ class ModelInstallService(ModelInstallServiceBase): info.path = model_path.as_posix() # add 'main' specific fields - if hasattr(info, "config"): + if isinstance(info, CheckpointConfigBase): # make config relative to our root legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve() info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix() diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index e01379cc1f..a37fc7de7f 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -135,7 +135,6 @@ class ModelConfigBase(BaseModel): ) # if model is converted or otherwise modified, this will hold updated hash description: Optional[str] = Field(description="human readable description of the model", default=None) source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None) - last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time) @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: @@ -160,6 +159,9 @@ class CheckpointConfigBase(ModelConfigBase): format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint config: str = Field(description="path to the checkpoint model config file") + last_modified: Optional[float] = Field( + description="When this model was last converted to diffusers", default_factory=time.time + ) class DiffusersConfigBase(ModelConfigBase): @@ -191,7 +193,7 @@ class LoRADiffusersConfig(ModelConfigBase): return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}") -class VaeCheckpointConfig(ModelConfigBase): +class VaeCheckpointConfig(CheckpointConfigBase): """Model config for standalone VAE models.""" type: Literal[ModelType.Vae] = ModelType.Vae @@ -257,25 +259,20 @@ class TextualInversionFolderConfig(ModelConfigBase): return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}") -class _MainConfig(ModelConfigBase): - """Model config for main models.""" - - variant: ModelVariantType = ModelVariantType.Normal - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False - - -class MainCheckpointConfig(CheckpointConfigBase, _MainConfig): +class MainCheckpointConfig(CheckpointConfigBase): """Model config for main checkpoint models.""" type: Literal[ModelType.Main] = ModelType.Main + variant: ModelVariantType = ModelVariantType.Normal + prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon + upcast_attention: bool = False @staticmethod def get_tag() -> Tag: return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}") -class MainDiffusersConfig(DiffusersConfigBase, _MainConfig): +class MainDiffusersConfig(DiffusersConfigBase): """Model config for main diffusers models.""" type: Literal[ModelType.Main] = ModelType.Main @@ -382,6 +379,6 @@ class ModelConfigFactory(object): assert model is not None if key: model.key = key - if timestamp: + if isinstance(model, CheckpointConfigBase) and timestamp is not None: model.last_modified = timestamp return model # type: ignore diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index 66bba755ae..4285d3b2da 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -25,7 +25,7 @@ class ControlNetLoader(GenericDiffusersLoader): """Class to load ControlNet models.""" def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: - if config.format != ModelFormat.Checkpoint: + if not isinstance(config, CheckpointConfigBase): return False elif ( dest_path.exists() diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index cfe0e6f83a..0deb723ca0 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -17,7 +17,7 @@ from invokeai.backend.model_manager import ( ModelVariantType, SubModelType, ) -from invokeai.backend.model_manager.config import MainCheckpointConfig +from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from .. import ModelLoaderRegistry @@ -55,7 +55,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader): return result def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: - if config.format != ModelFormat.Checkpoint: + if not isinstance(config, CheckpointConfigBase): return False elif ( dest_path.exists()