refactor(mm): wip schema changes

psychedelicious 2024-03-04 19:16:25 +11:00
parent 3534366146
commit 7cb0da1f66
5 changed files with 42 additions and 39 deletions

View File

@@ -37,6 +37,7 @@ from invokeai.backend.model_manager.metadata import (
ModelMetadataWithFiles,
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
@@ -152,9 +153,6 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["key"] = config.get("key", uuid_string())
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
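Note: the source/key defaulting dropped here does not go away; the same fallback is re-added inside ModelProbe.probe in the last file of this commit. A minimal standalone sketch of that defaulting, using uuid4 in place of the project's uuid_string helper:

    from pathlib import Path
    from typing import Any, Dict, Optional
    from uuid import uuid4

    def apply_defaults(model_path: Path, config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Fall back to the resolved POSIX path for "source" and a random hex id for "key"
        # only when the caller did not supply them.
        config = dict(config or {})
        if not config.get("source"):
            config["source"] = model_path.resolve().as_posix()
        config["key"] = config.get("key", uuid4().hex)
        return config

    print(apply_defaults(Path("models/my_lora.safetensors")))  # hypothetical path, for illustration only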
@@ -379,15 +377,17 @@ class ModelInstallService(ModelInstallServiceBase):
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_words:
job.config_in["trigger_words"] = job.source_metadata.trigger_words
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
# enter the metadata, if there is any
if job.source_metadata:
self._metadata_store.add_metadata(key, job.source_metadata)
self._signal_job_completed(job)
except InvalidModelConfigException as excp:
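As a rough standalone illustration of the config_in population above: the install source is mapped to a source-type value, and the API response and trigger words are copied in when the fetched metadata carries them. The enum and metadata class below are illustrative stand-ins, not the project's ModelSourceType or CivitaiMetadata:

    from dataclasses import dataclass, field
    from enum import Enum
    from typing import Any, Dict, Optional, Set

    class SourceType(str, Enum):      # stand-in for ModelSourceType
        local_path = "path"
        url = "url"

    @dataclass
    class FetchedMetadata:            # stand-in for CivitaiMetadata / HuggingFaceMetadata
        api_response: Optional[str] = None
        trigger_words: Set[str] = field(default_factory=set)

    def build_config_in(source: str, source_type: SourceType, meta: Optional[FetchedMetadata]) -> Dict[str, Any]:
        config: Dict[str, Any] = {"source": source, "source_type": source_type}
        if meta is not None:
            config["source_api_response"] = meta.api_response
            if meta.trigger_words:
                config["trigger_words"] = meta.trigger_words
        return config

    print(build_config_in("https://example.com/model.safetensors", SourceType.url,
                          FetchedMetadata(api_response="{}", trigger_words={"wizard"})))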
@@ -525,6 +525,13 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path)
return new_path
# def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
# info: AnyModelConfig = ModelProbe.probe(Path(model_path))
# if config: # used to override probe fields
# for key, value in config.items():
# setattr(info, key, value)
# return info
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:

View File

@@ -24,6 +24,7 @@ class Migration7Callback:
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
trigger_words TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_words')) VIRTUAL,
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
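A minimal standalone sketch of how the new generated columns behave, using an in-memory SQLite database (generated columns need SQLite 3.31+); the table is trimmed to the columns relevant to this migration and the sample data is made up:

    import json
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE models (
            id TEXT NOT NULL PRIMARY KEY,
            -- virtual columns are computed from the JSON config on read; nothing extra is stored
            source_api_response TEXT GENERATED ALWAYS AS (json_extract(config, '$.source_api_response')) VIRTUAL,
            trigger_words TEXT GENERATED ALWAYS AS (json_extract(config, '$.trigger_words')) VIRTUAL,
            config TEXT NOT NULL
        )
        """
    )
    config = {"trigger_words": ["style of X"], "source_api_response": "{}"}
    conn.execute("INSERT INTO models (id, config) VALUES (?, ?)", ("abc", json.dumps(config)))
    print(conn.execute("SELECT trigger_words FROM models WHERE id = 'abc'").fetchone()[0])
    # -> '["style of X"]', extracted from the JSON config at query time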

View File

@@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union
import torch
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from pydantic import BaseModel, ConfigDict, Discriminator, Field, JsonValue, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from ..raw_model import RawModel
@@ -142,6 +142,10 @@ class ModelConfigBase(BaseModel):
description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source")
source_api_response: Optional[JsonValue] = Field(
description="The original API response from the source", default=None
)
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
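A small self-contained pydantic sketch of the two new fields, assuming a pydantic version that exports JsonValue (as the import above does). The class is an illustrative stand-in, not the real ModelConfigBase:

    from typing import Optional

    from pydantic import BaseModel, ConfigDict, Field, JsonValue

    class ConfigSketch(BaseModel):  # stand-in for ModelConfigBase
        # JsonValue accepts any JSON-shaped structure: dict, list, str, number, bool or None
        source_api_response: Optional[JsonValue] = Field(default=None)
        # lists/tuples are coerced to a set and each element is validated as str
        trigger_words: Optional[set[str]] = Field(default=None)

        model_config = ConfigDict(validate_assignment=True)

    c = ConfigSketch(source_api_response={"id": 1234, "files": []}, trigger_words=["wizard", "wizard"])
    print(c.trigger_words)  # {'wizard'} -- duplicates collapse because the field is a set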

View File

@@ -19,7 +19,6 @@ from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.hf_api import RepoSibling
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
@@ -61,6 +60,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
# Little loop which tries fetching a revision corresponding to the selected variant.
# If not available, then set variant to None and get the default.
# If this too fails, raise exception.
model_info = None
while not model_info:
try:
@@ -73,12 +73,23 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
else:
variant = None
files: list[RemoteModelFile] = []
_, name = id.split("/")
return HuggingFaceMetadata(
id=model_info.id,
name=name,
files=parse_siblings(id, model_info.siblings, variant),
)
for s in model_info.siblings or []:
assert s.rfilename is not None
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(id, s.rfilename, revision=variant),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
)
)
return HuggingFaceMetadata(id=model_info.id, name=name, files=files)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
@@ -91,27 +102,3 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
return self.from_id(repo_id)
else:
raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")
def parse_siblings(
repo_id: str, siblings: Optional[list[RepoSibling]] = None, variant: Optional[ModelRepoVariant] = None
) -> list[RemoteModelFile]:
"""Parse the siblings list from the HuggingFace API into a list of RemoteModelFile objects."""
if not siblings:
return []
files: list[RemoteModelFile] = []
for s in siblings:
assert s.rfilename is not None
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(repo_id, s.rfilename, revision=variant.value if variant else None),
path=Path(s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
)
)
return files
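A minimal sketch of the URL construction used in from_id above; hf_hub_url only formats a download URL and never touches the network, so this runs offline. The repo id and filename are example values, not taken from the diff:

    from pathlib import Path

    from huggingface_hub import hf_hub_url

    repo_id = "stabilityai/sdxl-turbo"                 # example repo id
    filename = "sd_xl_turbo_1.0_fp16.safetensors"      # example filename

    # revision=None resolves to the default branch; passing a variant revision such as "fp16"
    # would point the URL at that revision instead, matching the variant loop in from_id.
    print(hf_hub_url(repo_id, filename, revision=None))
    print(Path(repo_id.split("/")[1], filename))       # relative path as stored on RemoteModelFile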

View File

@@ -8,6 +8,7 @@ import torch
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.util import SilenceWarnings
from .config import (
@@ -149,6 +150,9 @@ class ModelProbe(object):
probe = probe_class(model_path)
fields["source_type"] = fields.get("source_type")
fields["source"] = fields.get("source") or model_path.as_posix()
fields["key"] = fields.get("key", uuid_string())
fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type()
@@ -162,7 +166,7 @@ class ModelProbe(object):
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models
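The hasattr check replaced just above works at runtime but gives a type checker nothing to narrow on, whereas isinstance does. A tiny illustrative sketch of the difference, with stand-in classes rather than the real probe hierarchy:

    from typing import Optional, Union

    class ProbeBase:                      # stand-in for the probe base class
        pass

    class FolderProbe(ProbeBase):         # stand-in for FolderProbeBase
        def get_repo_variant(self) -> str:
            return "fp16"

    def repo_variant_of(probe: Union[ProbeBase, FolderProbe]) -> Optional[str]:
        # isinstance() narrows the type, so the get_repo_variant() call type-checks;
        # hasattr(probe, "get_repo_variant") behaves the same at runtime but not for mypy.
        if isinstance(probe, FolderProbe):
            return probe.get_repo_variant()
        return None

    print(repo_variant_of(FolderProbe()))  # fp16
    print(repo_variant_of(ProbeBase()))    # None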