mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(mm): add CheckpointConfigBase
for all ckpt models
This commit is contained in:
parent
0a614943f6
commit
76cbc745e1
@ -24,6 +24,7 @@ from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
@ -532,7 +533,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str:
|
||||
# Note that we may be passed a pre-populated AnyModelConfig object,
|
||||
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
if config is not None:
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
override_key: Optional[str] = config.get("key") if config else None
|
||||
|
||||
@ -546,7 +548,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# add 'main' specific fields
|
||||
if hasattr(info, "config"):
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
|
@ -135,7 +135,6 @@ class ModelConfigBase(BaseModel):
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
||||
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
@ -160,6 +159,9 @@ class CheckpointConfigBase(ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config: str = Field(description="path to the checkpoint model config file")
|
||||
last_modified: Optional[float] = Field(
|
||||
description="When this model was last converted to diffusers", default_factory=time.time
|
||||
)
|
||||
|
||||
|
||||
class DiffusersConfigBase(ModelConfigBase):
|
||||
@ -191,7 +193,7 @@ class LoRADiffusersConfig(ModelConfigBase):
|
||||
return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
|
||||
|
||||
|
||||
class VaeCheckpointConfig(ModelConfigBase):
|
||||
class VaeCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
@ -257,25 +259,20 @@ class TextualInversionFolderConfig(ModelConfigBase):
|
||||
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
|
||||
|
||||
|
||||
class _MainConfig(ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfigBase, _MainConfig):
|
||||
class MainCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfigBase, _MainConfig):
|
||||
class MainDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
@ -382,6 +379,6 @@ class ModelConfigFactory(object):
|
||||
assert model is not None
|
||||
if key:
|
||||
model.key = key
|
||||
if timestamp:
|
||||
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
|
||||
model.last_modified = timestamp
|
||||
return model # type: ignore
|
||||
|
@ -25,7 +25,7 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
"""Class to load ControlNet models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
|
@ -17,7 +17,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -55,7 +55,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
return result
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
|
Loading…
Reference in New Issue
Block a user