mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(mm): wip schema changes
This commit is contained in:
@ -37,6 +37,7 @@ from invokeai.backend.model_manager.metadata import (
|
|||||||
ModelMetadataWithFiles,
|
ModelMetadataWithFiles,
|
||||||
RemoteModelFile,
|
RemoteModelFile,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
||||||
from invokeai.backend.model_manager.probe import ModelProbe
|
from invokeai.backend.model_manager.probe import ModelProbe
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||||
@ -152,9 +153,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
) -> str: # noqa D102
|
) -> str: # noqa D102
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
config = config or {}
|
config = config or {}
|
||||||
if not config.get("source"):
|
|
||||||
config["source"] = model_path.resolve().as_posix()
|
|
||||||
config["key"] = config.get("key", uuid_string())
|
|
||||||
|
|
||||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
||||||
|
|
||||||
@ -379,15 +377,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._signal_job_running(job)
|
self._signal_job_running(job)
|
||||||
job.config_in["source"] = str(job.source)
|
job.config_in["source"] = str(job.source)
|
||||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||||
|
# enter the metadata, if there is any
|
||||||
|
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
||||||
|
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||||
|
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_words:
|
||||||
|
job.config_in["trigger_words"] = job.source_metadata.trigger_words
|
||||||
|
|
||||||
if job.inplace:
|
if job.inplace:
|
||||||
key = self.register_path(job.local_path, job.config_in)
|
key = self.register_path(job.local_path, job.config_in)
|
||||||
else:
|
else:
|
||||||
key = self.install_path(job.local_path, job.config_in)
|
key = self.install_path(job.local_path, job.config_in)
|
||||||
job.config_out = self.record_store.get_model(key)
|
job.config_out = self.record_store.get_model(key)
|
||||||
|
|
||||||
# enter the metadata, if there is any
|
|
||||||
if job.source_metadata:
|
|
||||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
|
||||||
self._signal_job_completed(job)
|
self._signal_job_completed(job)
|
||||||
|
|
||||||
except InvalidModelConfigException as excp:
|
except InvalidModelConfigException as excp:
|
||||||
@ -525,6 +525,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
move(old_path, new_path)
|
move(old_path, new_path)
|
||||||
return new_path
|
return new_path
|
||||||
|
|
||||||
|
# def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||||
|
# info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||||
|
# if config: # used to override probe fields
|
||||||
|
# for key, value in config.items():
|
||||||
|
# setattr(info, key, value)
|
||||||
|
# return info
|
||||||
|
|
||||||
def _register(
|
def _register(
|
||||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -24,6 +24,7 @@ class Migration7Callback:
|
|||||||
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
|
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
|
||||||
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
|
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
|
||||||
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
|
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
|
||||||
|
trigger_words TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_words')) VIRTUAL,
|
||||||
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
|
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
|
||||||
config TEXT NOT NULL,
|
config TEXT NOT NULL,
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, Discriminator, Field, JsonValue, Tag, TypeAdapter
|
||||||
from typing_extensions import Annotated, Any, Dict
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
from ..raw_model import RawModel
|
from ..raw_model import RawModel
|
||||||
@ -142,6 +142,10 @@ class ModelConfigBase(BaseModel):
|
|||||||
description: Optional[str] = Field(description="Model description", default=None)
|
description: Optional[str] = Field(description="Model description", default=None)
|
||||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
||||||
source_type: ModelSourceType = Field(description="The type of source")
|
source_type: ModelSourceType = Field(description="The type of source")
|
||||||
|
source_api_response: Optional[JsonValue] = Field(
|
||||||
|
description="The original API response from the source", default=None
|
||||||
|
)
|
||||||
|
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
|
||||||
|
|
||||||
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
|
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ from typing import Optional
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
|
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
|
||||||
from huggingface_hub.hf_api import RepoSibling
|
|
||||||
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
|
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
@ -61,6 +60,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
# Little loop which tries fetching a revision corresponding to the selected variant.
|
# Little loop which tries fetching a revision corresponding to the selected variant.
|
||||||
# If not available, then set variant to None and get the default.
|
# If not available, then set variant to None and get the default.
|
||||||
# If this too fails, raise exception.
|
# If this too fails, raise exception.
|
||||||
|
|
||||||
model_info = None
|
model_info = None
|
||||||
while not model_info:
|
while not model_info:
|
||||||
try:
|
try:
|
||||||
@ -73,12 +73,23 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
else:
|
else:
|
||||||
variant = None
|
variant = None
|
||||||
|
|
||||||
|
files: list[RemoteModelFile] = []
|
||||||
|
|
||||||
_, name = id.split("/")
|
_, name = id.split("/")
|
||||||
return HuggingFaceMetadata(
|
|
||||||
id=model_info.id,
|
for s in model_info.siblings or []:
|
||||||
name=name,
|
assert s.rfilename is not None
|
||||||
files=parse_siblings(id, model_info.siblings, variant),
|
assert s.size is not None
|
||||||
|
files.append(
|
||||||
|
RemoteModelFile(
|
||||||
|
url=hf_hub_url(id, s.rfilename, revision=variant),
|
||||||
|
path=Path(name, s.rfilename),
|
||||||
|
size=s.size,
|
||||||
|
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return HuggingFaceMetadata(id=model_info.id, name=name, files=files)
|
||||||
|
|
||||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||||
"""
|
"""
|
||||||
@ -91,27 +102,3 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
return self.from_id(repo_id)
|
return self.from_id(repo_id)
|
||||||
else:
|
else:
|
||||||
raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")
|
raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")
|
||||||
|
|
||||||
|
|
||||||
def parse_siblings(
|
|
||||||
repo_id: str, siblings: Optional[list[RepoSibling]] = None, variant: Optional[ModelRepoVariant] = None
|
|
||||||
) -> list[RemoteModelFile]:
|
|
||||||
"""Parse the siblings list from the HuggingFace API into a list of RemoteModelFile objects."""
|
|
||||||
if not siblings:
|
|
||||||
return []
|
|
||||||
|
|
||||||
files: list[RemoteModelFile] = []
|
|
||||||
|
|
||||||
for s in siblings:
|
|
||||||
assert s.rfilename is not None
|
|
||||||
assert s.size is not None
|
|
||||||
files.append(
|
|
||||||
RemoteModelFile(
|
|
||||||
url=hf_hub_url(repo_id, s.rfilename, revision=variant.value if variant else None),
|
|
||||||
path=Path(s.rfilename),
|
|
||||||
size=s.size,
|
|
||||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return files
|
|
||||||
|
@ -8,6 +8,7 @@ import torch
|
|||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
from invokeai.backend.util.util import SilenceWarnings
|
from invokeai.backend.util.util import SilenceWarnings
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
@ -149,6 +150,9 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
probe = probe_class(model_path)
|
probe = probe_class(model_path)
|
||||||
|
|
||||||
|
fields["source_type"] = fields.get("source_type")
|
||||||
|
fields["source"] = fields.get("source") or model_path.as_posix()
|
||||||
|
fields["key"] = fields.get("key", uuid_string())
|
||||||
fields["path"] = model_path.as_posix()
|
fields["path"] = model_path.as_posix()
|
||||||
fields["type"] = fields.get("type") or model_type
|
fields["type"] = fields.get("type") or model_type
|
||||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||||
@ -162,7 +166,7 @@ class ModelProbe(object):
|
|||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||||
|
|
||||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
|
|
||||||
# additional fields needed for main and controlnet models
|
# additional fields needed for main and controlnet models
|
||||||
|
Reference in New Issue
Block a user