mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(mm): wip schema changes
This commit is contained in:
parent
3534366146
commit
7cb0da1f66
@ -37,6 +37,7 @@ from invokeai.backend.model_manager.metadata import (
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
@ -152,9 +153,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
||||
|
||||
@ -379,15 +377,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_words:
|
||||
job.config_in["trigger_words"] = job.source_metadata.trigger_words
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
|
||||
# enter the metadata, if there is any
|
||||
if job.source_metadata:
|
||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except InvalidModelConfigException as excp:
|
||||
@ -525,6 +525,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
# def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||
# info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||
# if config: # used to override probe fields
|
||||
# for key, value in config.items():
|
||||
# setattr(info, key, value)
|
||||
# return info
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
|
@ -24,6 +24,7 @@ class Migration7Callback:
|
||||
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
|
||||
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
|
||||
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
|
||||
trigger_words TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_words')) VIRTUAL,
|
||||
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
|
@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, JsonValue, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from ..raw_model import RawModel
|
||||
@ -142,6 +142,10 @@ class ModelConfigBase(BaseModel):
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
||||
source_type: ModelSourceType = Field(description="The type of source")
|
||||
source_api_response: Optional[JsonValue] = Field(
|
||||
description="The original API response from the source", default=None
|
||||
)
|
||||
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
|
||||
|
||||
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
|
||||
|
||||
|
@ -19,7 +19,6 @@ from typing import Optional
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
|
||||
from huggingface_hub.hf_api import RepoSibling
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
@ -61,6 +60,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
# Little loop which tries fetching a revision corresponding to the selected variant.
|
||||
# If not available, then set variant to None and get the default.
|
||||
# If this too fails, raise exception.
|
||||
|
||||
model_info = None
|
||||
while not model_info:
|
||||
try:
|
||||
@ -73,12 +73,23 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
else:
|
||||
variant = None
|
||||
|
||||
files: list[RemoteModelFile] = []
|
||||
|
||||
_, name = id.split("/")
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id,
|
||||
name=name,
|
||||
files=parse_siblings(id, model_info.siblings, variant),
|
||||
)
|
||||
|
||||
for s in model_info.siblings or []:
|
||||
assert s.rfilename is not None
|
||||
assert s.size is not None
|
||||
files.append(
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant),
|
||||
path=Path(name, s.rfilename),
|
||||
size=s.size,
|
||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||
)
|
||||
)
|
||||
|
||||
return HuggingFaceMetadata(id=model_info.id, name=name, files=files)
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
@ -91,27 +102,3 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
return self.from_id(repo_id)
|
||||
else:
|
||||
raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")
|
||||
|
||||
|
||||
def parse_siblings(
|
||||
repo_id: str, siblings: Optional[list[RepoSibling]] = None, variant: Optional[ModelRepoVariant] = None
|
||||
) -> list[RemoteModelFile]:
|
||||
"""Parse the siblings list from the HuggingFace API into a list of RemoteModelFile objects."""
|
||||
if not siblings:
|
||||
return []
|
||||
|
||||
files: list[RemoteModelFile] = []
|
||||
|
||||
for s in siblings:
|
||||
assert s.rfilename is not None
|
||||
assert s.size is not None
|
||||
files.append(
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(repo_id, s.rfilename, revision=variant.value if variant else None),
|
||||
path=Path(s.rfilename),
|
||||
size=s.size,
|
||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||
)
|
||||
)
|
||||
|
||||
return files
|
||||
|
@ -8,6 +8,7 @@ import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.util import SilenceWarnings
|
||||
|
||||
from .config import (
|
||||
@ -149,6 +150,9 @@ class ModelProbe(object):
|
||||
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["source_type"] = fields.get("source_type")
|
||||
fields["source"] = fields.get("source") or model_path.as_posix()
|
||||
fields["key"] = fields.get("key", uuid_string())
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
@ -162,7 +166,7 @@ class ModelProbe(object):
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||
|
||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
|
Loading…
Reference in New Issue
Block a user