mirror of https://github.com/invoke-ai/InvokeAI
refactor(mm): add CheckpointConfigBase for all ckpt models
parent 0a614943f6
commit 76cbc745e1
@@ -24,6 +24,7 @@ from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
+    CheckpointConfigBase,
     InvalidModelConfigException,
     ModelRepoVariant,
     ModelType,
@@ -532,7 +533,8 @@ class ModelInstallService(ModelInstallServiceBase):
     ) -> str:
         # Note that we may be passed a pre-populated AnyModelConfig object,
         # in which case the key field should have been populated by the caller (e.g. in `install_path`).
-        config["key"] = config.get("key", uuid_string())
+        if config is not None:
+            config["key"] = config.get("key", uuid_string())
         info = info or ModelProbe.probe(model_path, config)
         override_key: Optional[str] = config.get("key") if config else None
 
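The old line indexed into `config` unconditionally, which fails when no config dict is supplied; the new guard only injects a generated key when the caller actually passed one. A minimal sketch of the pattern, with `uuid4().hex` standing in for InvokeAI's `uuid_string()` helper:

```python
from typing import Any, Dict, Optional
from uuid import uuid4


def prepare_config(config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Inject a generated key only when a caller-supplied config dict exists."""
    if config is not None:
        # Keep a key the caller already set; otherwise generate one.
        config["key"] = config.get("key", uuid4().hex)
    return config


print(prepare_config(None))                # None -- the unguarded version would raise here
print(prepare_config({}))                  # {'key': '<random hex>'}
print(prepare_config({"key": "my-key"}))   # {'key': 'my-key'}
```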
@@ -546,7 +548,7 @@ class ModelInstallService(ModelInstallServiceBase):
         info.path = model_path.as_posix()
 
         # add 'main' specific fields
-        if hasattr(info, "config"):
+        if isinstance(info, CheckpointConfigBase):
             # make config relative to our root
             legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
             info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
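Replacing `hasattr(info, "config")` with `isinstance(info, CheckpointConfigBase)` ties the branch to the config class hierarchy instead of duck-typing on an attribute name (any config that happened to grow a `config` attribute would have matched the old check), and it lets a type checker narrow `info` so that `info.config` is known to exist. A minimal sketch of the narrowing with stand-in classes and made-up paths, not the real InvokeAI objects:

```python
from pathlib import Path


class ModelConfigBase:
    """Stand-in for the shared config base."""


class CheckpointConfigBase(ModelConfigBase):
    """Stand-in: checkpoint configs carry a path to a legacy .yaml config."""

    def __init__(self, config: str) -> None:
        self.config = config


def relativize_legacy_config(info: ModelConfigBase, root_dir: Path, legacy_conf_dir: str) -> None:
    """Rewrite the legacy config path relative to the app root, for checkpoint configs only."""
    if isinstance(info, CheckpointConfigBase):
        # On this branch a type checker knows `info.config` exists.
        legacy_conf = (root_dir / legacy_conf_dir / info.config).resolve()
        info.config = legacy_conf.relative_to(root_dir.resolve()).as_posix()


info = CheckpointConfigBase(config="v1-inference.yaml")
relativize_legacy_config(info, Path("/opt/invokeai"), "configs/stable-diffusion")
print(info.config)  # configs/stable-diffusion/v1-inference.yaml
```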
@@ -135,7 +135,6 @@ class ModelConfigBase(BaseModel):
     )  # if model is converted or otherwise modified, this will hold updated hash
     description: Optional[str] = Field(description="human readable description of the model", default=None)
     source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
-    last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
 
     @staticmethod
     def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
@@ -160,6 +159,9 @@ class CheckpointConfigBase(ModelConfigBase):
 
     format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
     config: str = Field(description="path to the checkpoint model config file")
+    last_modified: Optional[float] = Field(
+        description="When this model was last converted to diffusers", default_factory=time.time
+    )
 
 
 class DiffusersConfigBase(ModelConfigBase):
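`last_modified` moves off the shared `ModelConfigBase` (previous hunk) and onto `CheckpointConfigBase`, and its meaning narrows from a generic modification time to the time the model was last converted to diffusers. `default_factory=time.time` stamps each instance at construction, not once at class-definition time. A minimal pydantic sketch of that pattern, using a made-up class name:

```python
import time
from typing import Optional

from pydantic import BaseModel, Field


class CheckpointishConfig(BaseModel):
    """Made-up class showing the default_factory pattern used for last_modified."""

    config: str = Field(description="path to the checkpoint model config file")
    last_modified: Optional[float] = Field(
        description="When this model was last converted to diffusers",
        default_factory=time.time,  # called once per instance, at construction
    )


a = CheckpointishConfig(config="some-model.yaml")
time.sleep(0.01)
b = CheckpointishConfig(config="some-model.yaml")
assert a.last_modified is not None and b.last_modified is not None
assert b.last_modified > a.last_modified  # each instance gets its own timestamp
```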
@@ -191,7 +193,7 @@ class LoRADiffusersConfig(ModelConfigBase):
         return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
 
 
-class VaeCheckpointConfig(ModelConfigBase):
+class VaeCheckpointConfig(CheckpointConfigBase):
     """Model config for standalone VAE models."""
 
     type: Literal[ModelType.Vae] = ModelType.Vae
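Rebasing `VaeCheckpointConfig` onto `CheckpointConfigBase` means standalone VAE checkpoints now inherit the checkpoint-specific fields (`format`, `config`, `last_modified`) rather than carrying only the shared base fields. A simplified sketch of the resulting field composition, assuming pydantic v2 and with string literals standing in for the real enums:

```python
import time
from typing import Literal, Optional

from pydantic import BaseModel, Field


class ModelConfigBase(BaseModel):
    """Simplified stand-in for the shared base."""

    key: str = ""
    description: Optional[str] = None


class CheckpointConfigBase(ModelConfigBase):
    """Simplified stand-in: fields every checkpoint config now inherits."""

    format: Literal["checkpoint"] = "checkpoint"
    config: str = Field(description="path to the checkpoint model config file")
    last_modified: Optional[float] = Field(default_factory=time.time)


class VaeCheckpointConfig(CheckpointConfigBase):
    """Standalone VAE checkpoints pick up the shared checkpoint fields by inheritance."""

    type: Literal["vae"] = "vae"


print(sorted(VaeCheckpointConfig.model_fields))
# ['config', 'description', 'format', 'key', 'last_modified', 'type']
```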
@@ -257,25 +259,20 @@ class TextualInversionFolderConfig(ModelConfigBase):
         return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
 
 
-class _MainConfig(ModelConfigBase):
-    """Model config for main models."""
-
-    variant: ModelVariantType = ModelVariantType.Normal
-    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
-    upcast_attention: bool = False
-
-
-class MainCheckpointConfig(CheckpointConfigBase, _MainConfig):
+class MainCheckpointConfig(CheckpointConfigBase):
     """Model config for main checkpoint models."""
 
     type: Literal[ModelType.Main] = ModelType.Main
+    variant: ModelVariantType = ModelVariantType.Normal
+    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
+    upcast_attention: bool = False
 
     @staticmethod
     def get_tag() -> Tag:
         return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
 
 
-class MainDiffusersConfig(DiffusersConfigBase, _MainConfig):
+class MainDiffusersConfig(DiffusersConfigBase):
     """Model config for main diffusers models."""
 
     type: Literal[ModelType.Main] = ModelType.Main
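The `_MainConfig` mixin goes away: `MainCheckpointConfig` now derives only from `CheckpointConfigBase` and declares `variant`, `prediction_type`, and `upcast_attention` itself, while `MainDiffusersConfig` lists only `DiffusersConfigBase` as its base. The flatter hierarchy is what makes the `isinstance(config, CheckpointConfigBase)` checks elsewhere in this commit reliable for main models. A toy sketch of that split, showing only the class relationships (stand-in classes, real fields and types omitted):

```python
class CheckpointConfigBase:
    """Stand-in for the checkpoint config base."""


class DiffusersConfigBase:
    """Stand-in for the diffusers config base."""


class MainCheckpointConfig(CheckpointConfigBase):
    variant = "normal"
    prediction_type = "epsilon"
    upcast_attention = False


class MainDiffusersConfig(DiffusersConfigBase):
    pass


# Only the checkpoint flavour of a main model matches the checkpoint base.
for cfg in (MainCheckpointConfig(), MainDiffusersConfig()):
    print(type(cfg).__name__, isinstance(cfg, CheckpointConfigBase))
# MainCheckpointConfig True
# MainDiffusersConfig False
```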
@@ -382,6 +379,6 @@ class ModelConfigFactory(object):
         assert model is not None
         if key:
             model.key = key
-        if timestamp:
+        if isinstance(model, CheckpointConfigBase) and timestamp is not None:
             model.last_modified = timestamp
         return model  # type: ignore
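In the factory, the `last_modified` assignment is now gated on `isinstance(model, CheckpointConfigBase)`, which is required once the field exists only on that base; the switch from `if timestamp:` to `timestamp is not None` also stops a legitimate `0.0` timestamp from being silently skipped. A small stand-in sketch of that guard, not the real `ModelConfigFactory`:

```python
from typing import Optional


class ModelConfigBase:
    """Stand-in base: no last_modified field any more."""

    key: str = ""


class CheckpointConfigBase(ModelConfigBase):
    """Stand-in: only checkpoint configs carry last_modified now."""

    last_modified: Optional[float] = None


def apply_timestamp(model: ModelConfigBase, timestamp: Optional[float]) -> None:
    # `if timestamp:` would skip a legitimate 0.0; the explicit None check does not.
    if isinstance(model, CheckpointConfigBase) and timestamp is not None:
        model.last_modified = timestamp


ckpt, other = CheckpointConfigBase(), ModelConfigBase()
apply_timestamp(ckpt, 0.0)
apply_timestamp(other, 123.0)  # ignored: the base config has no last_modified field
print(ckpt.last_modified)                # 0.0
print(hasattr(other, "last_modified"))   # False
```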
@@ -25,7 +25,7 @@ class ControlNetLoader(GenericDiffusersLoader):
     """Class to load ControlNet models."""
 
     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-        if config.format != ModelFormat.Checkpoint:
+        if not isinstance(config, CheckpointConfigBase):
             return False
         elif (
             dest_path.exists()
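Both loaders replace the `config.format != ModelFormat.Checkpoint` comparison with `not isinstance(config, CheckpointConfigBase)`. Besides reading better, the isinstance guard narrows `config` for the rest of the method, so checkpoint-only fields (such as the legacy `config` path) type-check without casts. A minimal sketch of the guard shape with stand-in classes; the real methods go on to compare the checkpoint against `dest_path` rather than merely checking existence:

```python
from pathlib import Path


class ModelConfigBase:
    """Stand-in for the config union accepted by the loaders."""


class CheckpointConfigBase(ModelConfigBase):
    """Stand-in: checkpoint configs carry checkpoint-only fields."""

    config: str = "path/to/legacy.yaml"


def needs_conversion(config: ModelConfigBase, model_path: Path, dest_path: Path) -> bool:
    if not isinstance(config, CheckpointConfigBase):
        # Diffusers-format models are loaded as-is; nothing to convert.
        return False
    # From here on, type checkers treat `config` as CheckpointConfigBase,
    # so fields like `config.config` are accessible without a cast.
    return not dest_path.exists()


print(needs_conversion(ModelConfigBase(), Path("model.safetensors"), Path("cache/model")))  # False
```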
@@ -17,7 +17,7 @@ from invokeai.backend.model_manager import (
     ModelVariantType,
     SubModelType,
 )
-from invokeai.backend.model_manager.config import MainCheckpointConfig
+from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
 
 from .. import ModelLoaderRegistry
@@ -55,7 +55,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
         return result
 
     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-        if config.format != ModelFormat.Checkpoint:
+        if not isinstance(config, CheckpointConfigBase):
            return False
         elif (
             dest_path.exists()
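The hunk is truncated at `elif (dest_path.exists() ...`, which is where the loader decides whether an already-converted diffusers copy is stale. That logic is not shown in the diff, so the following is a purely hypothetical illustration of a typical modification-time staleness check, not the InvokeAI implementation:

```python
from pathlib import Path


def diffusers_copy_is_stale(checkpoint_path: Path, dest_path: Path) -> bool:
    """Hypothetical: reconvert when the cached diffusers copy is missing
    or older than the source checkpoint."""
    if not dest_path.exists():
        return True
    return checkpoint_path.stat().st_mtime > dest_path.stat().st_mtime
```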