feat(mm): add source_type to model configs

This commit is contained in:
psychedelicious 2024-03-01 22:12:48 +11:00
parent 4471ea8ad1
commit 9378e47a06
4 changed files with 35 additions and 17 deletions

View File

@ -18,6 +18,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
@ -151,6 +152,13 @@ ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
CivitaiModelSource: ModelSourceType.CivitAI,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""

View File

@ -27,6 +27,7 @@ from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
InvalidModelConfigException,
ModelRepoVariant,
ModelSourceType,
ModelType,
)
from invokeai.backend.model_manager.metadata import (
@ -42,6 +43,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
CivitaiModelSource,
HFModelSource,
InstallStatus,
@ -140,6 +142,7 @@ class ModelInstallService(ModelInstallServiceBase):
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["source_type"] = ModelSourceType.Path
return self._register(model_path, config)
def install_path(
@ -153,7 +156,7 @@ class ModelInstallService(ModelInstallServiceBase):
config["source"] = model_path.resolve().as_posix()
config["key"] = config.get("key", uuid_string())
info: AnyModelConfig = self._probe_model(Path(model_path), config)
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@ -375,6 +378,7 @@ class ModelInstallService(ModelInstallServiceBase):
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
@ -521,13 +525,6 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path)
return new_path
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if config: # used to override probe fields
for key, value in config.items():
setattr(info, key, value)
return info
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
@ -538,8 +535,8 @@ class ModelInstallService(ModelInstallServiceBase):
info = info or ModelProbe.probe(model_path, config)
override_key: Optional[str] = config.get("key") if config else None
assert info.original_hash # always assigned by probe()
info.key = override_key or info.original_hash
assert info.hash # always assigned by probe()
info.key = override_key or info.hash
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):
@ -573,7 +570,7 @@ class ModelInstallService(ModelInstallServiceBase):
source=source,
config_in=config or {},
local_path=Path(source.path),
inplace=source.inplace,
inplace=source.inplace or False,
)
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
@ -630,7 +627,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_remote_model(
self,
source: ModelSource,
source: HFModelSource | CivitaiModelSource | URLModelSource,
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
@ -658,7 +655,7 @@ class ModelInstallService(ModelInstallServiceBase):
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if hasattr(source, "subfolder") and source.subfolder:
if isinstance(source, HFModelSource) and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:

View File

@ -120,6 +120,15 @@ class ModelRepoVariant(str, Enum):
FLAX = "flax"
class ModelSourceType(str, Enum):
"""Model source type."""
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
CivitAI = "civitai"
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
@ -128,7 +137,9 @@ class ModelConfigBase(BaseModel):
base: BaseModelType = Field(description="base model")
key: str = Field(description="unique key for model", default="<NOKEY>")
hash: Optional[str] = Field(description="original fasthash of model contents", default=None)
description: Optional[str] = Field(description="human readable description of the model", default=None)
description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The source of the model (e.g. path, URL, HF Repo ID)")
source_type: ModelSourceType = Field(description="The type of source")
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:

View File

@ -147,7 +147,6 @@ class ModelProbe(object):
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = ModelHash().hash(model_path)
probe = probe_class(model_path)
fields["path"] = model_path.as_posix()
@ -161,13 +160,16 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or hash
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae] and fields["format"] == ModelFormat.Checkpoint:
if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae]
and fields["format"] is ModelFormat.Checkpoint
):
fields["config_path"] = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],