feat(mm): add source_type to model configs

This commit is contained in:
psychedelicious 2024-03-01 22:12:48 +11:00
parent 4471ea8ad1
commit 9378e47a06
4 changed files with 35 additions and 17 deletions

View File

@ -18,6 +18,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase from ..model_metadata import ModelMetadataStoreBase
@ -151,6 +152,13 @@ ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type") Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
] ]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
CivitaiModelSource: ModelSourceType.CivitAI,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel): class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request.""" """Object that tracks the current status of an install request."""

View File

@ -27,6 +27,7 @@ from invokeai.backend.model_manager.config import (
CheckpointConfigBase, CheckpointConfigBase,
InvalidModelConfigException, InvalidModelConfigException,
ModelRepoVariant, ModelRepoVariant,
ModelSourceType,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.metadata import ( from invokeai.backend.model_manager.metadata import (
@ -42,6 +43,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device from invokeai.backend.util.devices import choose_precision, choose_torch_device
from .model_install_base import ( from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
CivitaiModelSource, CivitaiModelSource,
HFModelSource, HFModelSource,
InstallStatus, InstallStatus,
@ -140,6 +142,7 @@ class ModelInstallService(ModelInstallServiceBase):
config = config or {} config = config or {}
if not config.get("source"): if not config.get("source"):
config["source"] = model_path.resolve().as_posix() config["source"] = model_path.resolve().as_posix()
config["source_type"] = ModelSourceType.Path
return self._register(model_path, config) return self._register(model_path, config)
def install_path( def install_path(
@ -153,7 +156,7 @@ class ModelInstallService(ModelInstallServiceBase):
config["source"] = model_path.resolve().as_posix() config["source"] = model_path.resolve().as_posix()
config["key"] = config.get("key", uuid_string()) config["key"] = config.get("key", uuid_string())
info: AnyModelConfig = self._probe_model(Path(model_path), config) info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
if preferred_name := config.get("name"): if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix) preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@ -375,6 +378,7 @@ class ModelInstallService(ModelInstallServiceBase):
job.bytes = job.total_bytes job.bytes = job.total_bytes
self._signal_job_running(job) self._signal_job_running(job)
job.config_in["source"] = str(job.source) job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
if job.inplace: if job.inplace:
key = self.register_path(job.local_path, job.config_in) key = self.register_path(job.local_path, job.config_in)
else: else:
@ -521,13 +525,6 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path) move(old_path, new_path)
return new_path return new_path
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if config: # used to override probe fields
for key, value in config.items():
setattr(info, key, value)
return info
def _register( def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str: ) -> str:
@ -538,8 +535,8 @@ class ModelInstallService(ModelInstallServiceBase):
info = info or ModelProbe.probe(model_path, config) info = info or ModelProbe.probe(model_path, config)
override_key: Optional[str] = config.get("key") if config else None override_key: Optional[str] = config.get("key") if config else None
assert info.original_hash # always assigned by probe() assert info.hash # always assigned by probe()
info.key = override_key or info.original_hash info.key = override_key or info.hash
model_path = model_path.absolute() model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path): if model_path.is_relative_to(self.app_config.models_path):
@ -573,7 +570,7 @@ class ModelInstallService(ModelInstallServiceBase):
source=source, source=source,
config_in=config or {}, config_in=config or {},
local_path=Path(source.path), local_path=Path(source.path),
inplace=source.inplace, inplace=source.inplace or False,
) )
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
@ -630,7 +627,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_remote_model( def _import_remote_model(
self, self,
source: ModelSource, source: HFModelSource | CivitaiModelSource | URLModelSource,
remote_files: List[RemoteModelFile], remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata], metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]], config: Optional[Dict[str, Any]],
@ -658,7 +655,7 @@ class ModelInstallService(ModelInstallServiceBase):
# In the event that there is a subfolder specified in the source, # In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid # we need to remove it from the destination path in order to avoid
# creating unwanted subfolders # creating unwanted subfolders
if hasattr(source, "subfolder") and source.subfolder: if isinstance(source, HFModelSource) and source.subfolder:
root = Path(remote_files[0].path.parts[0]) root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder subfolder = root / source.subfolder
else: else:

View File

@ -120,6 +120,15 @@ class ModelRepoVariant(str, Enum):
FLAX = "flax" FLAX = "flax"
class ModelSourceType(str, Enum):
"""Model source type."""
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
CivitAI = "civitai"
class ModelConfigBase(BaseModel): class ModelConfigBase(BaseModel):
"""Base class for model configuration information.""" """Base class for model configuration information."""
@ -128,7 +137,9 @@ class ModelConfigBase(BaseModel):
base: BaseModelType = Field(description="base model") base: BaseModelType = Field(description="base model")
key: str = Field(description="unique key for model", default="<NOKEY>") key: str = Field(description="unique key for model", default="<NOKEY>")
hash: Optional[str] = Field(description="original fasthash of model contents", default=None) hash: Optional[str] = Field(description="original fasthash of model contents", default=None)
description: Optional[str] = Field(description="human readable description of the model", default=None) description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The source of the model (e.g. path, URL, HF Repo ID)")
source_type: ModelSourceType = Field(description="The type of source")
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:

View File

@ -147,7 +147,6 @@ class ModelProbe(object):
if not probe_class: if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = ModelHash().hash(model_path)
probe = probe_class(model_path) probe = probe_class(model_path)
fields["path"] = model_path.as_posix() fields["path"] = model_path.as_posix()
@ -161,13 +160,16 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
) )
fields["format"] = fields.get("format") or probe.get_format() fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or hash fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"): if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models # additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae] and fields["format"] == ModelFormat.Checkpoint: if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae]
and fields["format"] is ModelFormat.Checkpoint
):
fields["config_path"] = cls._get_checkpoint_config_path( fields["config_path"] = cls._get_checkpoint_config_path(
model_path, model_path,
model_type=fields["type"], model_type=fields["type"],