From 7cb0da1f66245cd6bf51b3e764ccca1c14db5eb7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Mar 2024 19:16:25 +1100 Subject: [PATCH] refactor(mm): wip schema changes --- .../model_install/model_install_default.py | 21 ++++++--- .../sqlite_migrator/migrations/migration_7.py | 1 + invokeai/backend/model_manager/config.py | 6 ++- .../metadata/fetch/huggingface.py | 47 +++++++------------ invokeai/backend/model_manager/probe.py | 6 ++- 5 files changed, 42 insertions(+), 39 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index d676204854..5a75b80de5 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -37,6 +37,7 @@ from invokeai.backend.model_manager.metadata import ( ModelMetadataWithFiles, RemoteModelFile, ) +from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata from invokeai.backend.model_manager.probe import ModelProbe from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.util import Chdir, InvokeAILogger @@ -152,9 +153,6 @@ class ModelInstallService(ModelInstallServiceBase): ) -> str: # noqa D102 model_path = Path(model_path) config = config or {} - if not config.get("source"): - config["source"] = model_path.resolve().as_posix() - config["key"] = config.get("key", uuid_string()) info: AnyModelConfig = ModelProbe.probe(Path(model_path), config) @@ -379,15 +377,17 @@ class ModelInstallService(ModelInstallServiceBase): self._signal_job_running(job) job.config_in["source"] = str(job.source) job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__] + # enter the metadata, if there is any + if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)): + job.config_in["source_api_response"] = job.source_metadata.api_response + if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_words: + job.config_in["trigger_words"] = job.source_metadata.trigger_words + if job.inplace: key = self.register_path(job.local_path, job.config_in) else: key = self.install_path(job.local_path, job.config_in) job.config_out = self.record_store.get_model(key) - - # enter the metadata, if there is any - if job.source_metadata: - self._metadata_store.add_metadata(key, job.source_metadata) self._signal_job_completed(job) except InvalidModelConfigException as excp: @@ -525,6 +525,13 @@ class ModelInstallService(ModelInstallServiceBase): move(old_path, new_path) return new_path + # def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig: + # info: AnyModelConfig = ModelProbe.probe(Path(model_path)) + # if config: # used to override probe fields + # for key, value in config.items(): + # setattr(info, key, value) + # return info + def _register( self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None ) -> str: diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py index c78fa3ce8c..60a58c1a38 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py @@ -24,6 +24,7 @@ class Migration7Callback: source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL, source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL, source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL, + trigger_words TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_words')) VIRTUAL, -- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses config TEXT NOT NULL, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index e28071c346..45e0d5524e 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union import torch from diffusers.models.modeling_utils import ModelMixin -from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter +from pydantic import BaseModel, ConfigDict, Discriminator, Field, JsonValue, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict from ..raw_model import RawModel @@ -142,6 +142,10 @@ class ModelConfigBase(BaseModel): description: Optional[str] = Field(description="Model description", default=None) source: str = Field(description="The original source of the model (path, URL or repo_id).") source_type: ModelSourceType = Field(description="The type of source") + source_api_response: Optional[JsonValue] = Field( + description="The original API response from the source", default=None + ) + trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None) model_config = ConfigDict(use_enum_values=False, validate_assignment=True) diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index e148965a66..a42907c658 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -19,7 +19,6 @@ from typing import Optional import requests from huggingface_hub import HfApi, configure_http_backend, hf_hub_url -from huggingface_hub.hf_api import RepoSibling from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError from pydantic.networks import AnyHttpUrl from requests.sessions import Session @@ -61,6 +60,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): # Little loop which tries fetching a revision corresponding to the selected variant. # If not available, then set variant to None and get the default. # If this too fails, raise exception. + model_info = None while not model_info: try: @@ -73,12 +73,23 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): else: variant = None + files: list[RemoteModelFile] = [] + _, name = id.split("/") - return HuggingFaceMetadata( - id=model_info.id, - name=name, - files=parse_siblings(id, model_info.siblings, variant), - ) + + for s in model_info.siblings or []: + assert s.rfilename is not None + assert s.size is not None + files.append( + RemoteModelFile( + url=hf_hub_url(id, s.rfilename, revision=variant), + path=Path(name, s.rfilename), + size=s.size, + sha256=s.lfs.get("sha256") if s.lfs else None, + ) + ) + + return HuggingFaceMetadata(id=model_info.id, name=name, files=files) def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: """ @@ -91,27 +102,3 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): return self.from_id(repo_id) else: raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page") - - -def parse_siblings( - repo_id: str, siblings: Optional[list[RepoSibling]] = None, variant: Optional[ModelRepoVariant] = None -) -> list[RemoteModelFile]: - """Parse the siblings list from the HuggingFace API into a list of RemoteModelFile objects.""" - if not siblings: - return [] - - files: list[RemoteModelFile] = [] - - for s in siblings: - assert s.rfilename is not None - assert s.size is not None - files.append( - RemoteModelFile( - url=hf_hub_url(repo_id, s.rfilename, revision=variant.value if variant else None), - path=Path(s.rfilename), - size=s.size, - sha256=s.lfs.get("sha256") if s.lfs else None, - ) - ) - - return files diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index ca04b47331..c837993888 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -8,6 +8,7 @@ import torch from picklescan.scanner import scan_file_path import invokeai.backend.util.logging as logger +from invokeai.app.util.misc import uuid_string from invokeai.backend.util.util import SilenceWarnings from .config import ( @@ -149,6 +150,9 @@ class ModelProbe(object): probe = probe_class(model_path) + fields["source_type"] = fields.get("source_type") + fields["source"] = fields.get("source") or model_path.as_posix() + fields["key"] = fields.get("key", uuid_string()) fields["path"] = model_path.as_posix() fields["type"] = fields.get("type") or model_type fields["base"] = fields.get("base") or probe.get_base_type() @@ -162,7 +166,7 @@ class ModelProbe(object): fields["format"] = fields.get("format") or probe.get_format() fields["hash"] = fields.get("hash") or ModelHash().hash(model_path) - if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"): + if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() # additional fields needed for main and controlnet models