refactor(mm): add CheckpointConfigBase for all ckpt models

Author: psychedelicious
Date:   2024-03-01 15:21:35 +11:00
Parent: 0a614943f6
Commit: 76cbc745e1
4 changed files with 17 additions and 18 deletions
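
In outline, the refactor pulls the checkpoint-only fields (the legacy `config` path and the diffusers-conversion timestamp `last_modified`) up into a shared `CheckpointConfigBase`, so callers can branch on `isinstance(...)` instead of `hasattr(...)` or `config.format` comparisons. The sketch below shows only the resulting shape: it uses simplified field types and string literals in place of the real enums, and the `is_checkpoint` helper is illustrative, not actual InvokeAI code.

import time
from typing import Literal, Optional

from pydantic import BaseModel, Field


class ModelConfigBase(BaseModel):
    """Fields shared by every model config (simplified)."""

    key: str = Field(default="")
    description: Optional[str] = Field(default=None)


class CheckpointConfigBase(ModelConfigBase):
    """Base for all checkpoint-format configs."""

    format: Literal["checkpoint"] = "checkpoint"
    config: str = Field(description="path to the checkpoint model config file")
    last_modified: Optional[float] = Field(
        description="When this model was last converted to diffusers", default_factory=time.time
    )


class MainCheckpointConfig(CheckpointConfigBase):
    """Main checkpoint config; the old _MainConfig mixin's fields now live here."""

    type: Literal["main"] = "main"
    variant: str = "normal"


# Callers can now narrow by class rather than probing attributes:
def is_checkpoint(cfg: ModelConfigBase) -> bool:
    return isinstance(cfg, CheckpointConfigBase)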


@@ -24,6 +24,7 @@ from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
+    CheckpointConfigBase,
     InvalidModelConfigException,
     ModelRepoVariant,
     ModelType,
@@ -532,7 +533,8 @@ class ModelInstallService(ModelInstallServiceBase):
     ) -> str:
         # Note that we may be passed a pre-populated AnyModelConfig object,
         # in which case the key field should have been populated by the caller (e.g. in `install_path`).
-        config["key"] = config.get("key", uuid_string())
+        if config is not None:
+            config["key"] = config.get("key", uuid_string())
         info = info or ModelProbe.probe(model_path, config)
         override_key: Optional[str] = config.get("key") if config else None
@@ -546,7 +548,7 @@ class ModelInstallService(ModelInstallServiceBase):
         info.path = model_path.as_posix()

         # add 'main' specific fields
-        if hasattr(info, "config"):
+        if isinstance(info, CheckpointConfigBase):
             # make config relative to our root
             legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
             info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
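
The switch from `hasattr(info, "config")` to `isinstance(info, CheckpointConfigBase)` also narrows the type: after the check, a type checker knows `info.config` exists and is a `str`, which `hasattr` never guaranteed. A minimal sketch of the pattern, reusing the simplified classes above with a hypothetical helper name and parameters (not the real `ModelInstallService` method):

from pathlib import Path


def relativize_legacy_config(info: ModelConfigBase, root_dir: Path, legacy_conf_dir: Path) -> None:
    # Only checkpoint configs carry a legacy .yaml config path; isinstance()
    # both filters out the other config classes and narrows the static type of `info`.
    if isinstance(info, CheckpointConfigBase):
        legacy_conf = (root_dir / legacy_conf_dir / info.config).resolve()
        info.config = legacy_conf.relative_to(root_dir).as_posix()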


@@ -135,7 +135,6 @@ class ModelConfigBase(BaseModel):
     ) # if model is converted or otherwise modified, this will hold updated hash
     description: Optional[str] = Field(description="human readable description of the model", default=None)
     source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
-    last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)

     @staticmethod
     def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
@@ -160,6 +159,9 @@ class CheckpointConfigBase(ModelConfigBase):
     format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
     config: str = Field(description="path to the checkpoint model config file")
+    last_modified: Optional[float] = Field(
+        description="When this model was last converted to diffusers", default_factory=time.time
+    )


 class DiffusersConfigBase(ModelConfigBase):
@@ -191,7 +193,7 @@ class LoRADiffusersConfig(ModelConfigBase):
         return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")


-class VaeCheckpointConfig(ModelConfigBase):
+class VaeCheckpointConfig(CheckpointConfigBase):
     """Model config for standalone VAE models."""

     type: Literal[ModelType.Vae] = ModelType.Vae
@@ -257,25 +259,20 @@ class TextualInversionFolderConfig(ModelConfigBase):
         return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")


-class _MainConfig(ModelConfigBase):
-    """Model config for main models."""
-
-    variant: ModelVariantType = ModelVariantType.Normal
-    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
-    upcast_attention: bool = False
-
-
-class MainCheckpointConfig(CheckpointConfigBase, _MainConfig):
+class MainCheckpointConfig(CheckpointConfigBase):
     """Model config for main checkpoint models."""

     type: Literal[ModelType.Main] = ModelType.Main
+    variant: ModelVariantType = ModelVariantType.Normal
+    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
+    upcast_attention: bool = False

     @staticmethod
     def get_tag() -> Tag:
         return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")


-class MainDiffusersConfig(DiffusersConfigBase, _MainConfig):
+class MainDiffusersConfig(DiffusersConfigBase):
     """Model config for main diffusers models."""

     type: Literal[ModelType.Main] = ModelType.Main
@@ -382,6 +379,6 @@ class ModelConfigFactory(object):
         assert model is not None
         if key:
             model.key = key
-        if timestamp:
+        if isinstance(model, CheckpointConfigBase) and timestamp is not None:
             model.last_modified = timestamp
         return model # type: ignore
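
Because `last_modified` now lives only on `CheckpointConfigBase`, the factory has to prove the attribute exists before stamping it, hence the `isinstance` guard. A small usage sketch against the simplified classes above (hypothetical helper and an example legacy config filename, not the real `ModelConfigFactory`):

import time
from typing import Optional


def stamp_conversion_time(model: ModelConfigBase, timestamp: Optional[float] = None) -> ModelConfigBase:
    # Diffusers-format configs no longer carry last_modified, so guard by class.
    if isinstance(model, CheckpointConfigBase) and timestamp is not None:
        model.last_modified = timestamp
    return model


cfg = stamp_conversion_time(MainCheckpointConfig(config="v1-inference.yaml"), timestamp=time.time())
assert cfg.last_modified is not None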


@@ -25,7 +25,7 @@ class ControlNetLoader(GenericDiffusersLoader):
     """Class to load ControlNet models."""

     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-        if config.format != ModelFormat.Checkpoint:
+        if not isinstance(config, CheckpointConfigBase):
             return False
         elif (
             dest_path.exists()


@@ -17,7 +17,7 @@ from invokeai.backend.model_manager import (
     ModelVariantType,
     SubModelType,
 )
-from invokeai.backend.model_manager.config import MainCheckpointConfig
+from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

 from .. import ModelLoaderRegistry
@@ -55,7 +55,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
         return result

     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-        if config.format != ModelFormat.Checkpoint:
+        if not isinstance(config, CheckpointConfigBase):
             return False
         elif (
             dest_path.exists()
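
Both loaders now gate conversion on the config's class rather than its `format` value. The `elif` body is truncated in the hunks above, so the sketch below only illustrates the general idea (reuse a cached diffusers copy unless it is missing or older than the checkpoint and its recorded conversion time); it reuses the simplified classes above and is not the actual loader code.

from pathlib import Path


def needs_conversion_sketch(config: ModelConfigBase, model_path: Path, dest_path: Path) -> bool:
    # Diffusers-format configs are loadable as-is and never need conversion.
    if not isinstance(config, CheckpointConfigBase):
        return False
    # Illustrative staleness check: keep the cached conversion at dest_path only
    # if it exists and is at least as new as both the recorded conversion time
    # and the checkpoint file itself.
    if (
        dest_path.exists()
        and dest_path.stat().st_mtime >= (config.last_modified or 0.0)
        and dest_path.stat().st_mtime >= model_path.stat().st_mtime
    ):
        return False
    return True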