mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Model Manager Refactor: Install remote models and store their tags and other metadata (#5361)
* add basic functionality for model metadata fetching from hf and civitai
* add storage
* start unit tests
* add unit tests and documentation
* add missing dependency for pytests
* remove redundant fetch; add modified/published dates; updated docs
* add code to select diffusers files based on the variant type
* implement Civitai installs
* make huggingface parallel downloading work
* add unit tests for model installation manager
  - Fixed race condition on selection of download destination path
  - Add fixtures common to several model_manager_2 unit tests
  - Added dummy model files for testing diffusers and safetensors downloading/probing
  - Refactored code for selecting proper variant from list of huggingface repo files
  - Regrouped ordering of methods in model_install_default.py
* improve Civitai model downloading
  - Provide a better error message when Civitai requires an access token (doesn't give a 403 forbidden, but redirects to the HTML of an authorization page -- arrgh)
  - Handle case of Civitai providing a primary download link plus additional links for VAEs, config files, etc
* add routes for retrieving metadata and tags
* code tidying and documentation
* fix ruff errors
* add file needed to maintain test root directory in repo for unit tests
* fix self->cls in classmethod
* add pydantic plugin for mypy
* use TestSession instead of requests.Session to prevent any internet activity
  improve logging
  fix error message formatting
  fix logging again
  fix forward vs reverse slash issue in Windows install tests
* Several fixes of problems detected during PR review:
  - Implement cancel_model_install_job and get_model_install_job routes to allow for better control of model download and install.
  - Fix thread deadlock that occurred after cancelling an install.
  - Remove unneeded pytest_plugins section from tests/conftest.py
  - Remove unused _in_terminal_state() from model_install_default.
  - Remove outdated documentation from several spots.
  - Add workaround for Civitai API results which don't return correct URL for the default model.
* fix docs and tests to match get_job_by_source() rather than get_job()
* Update invokeai/backend/model_manager/metadata/fetch/huggingface.py
  Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
* Call CivitaiMetadata.model_validate_json() directly
  Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
* Second round of revisions suggested by @ryanjdick:
  - Fix type mismatch in `list_all_metadata()` route.
  - Do not have a default value for the model install job id
  - Remove static class variable declarations from non Pydantic classes
  - Change `id` field to `model_id` for the sqlite3 `model_tags` table.
  - Changed AFTER DELETE triggers to ON DELETE CASCADE for the metadata and tags tables.
  - Made the `id` field of the `model_metadata` table into a primary key to achieve uniqueness.
* Code cleanup suggested in PR review:
  - Narrowed the declaration of the `parts` attribute of the download progress event
  - Removed auto-conversion of str to Url in Url-containing sources
  - Fixed handling of `InvalidModelConfigException`
  - Made unknown sources raise `NotImplementedError` rather than `Exception`
  - Improved status reporting on cached HuggingFace access tokens
* Multiple fixes:
  - `job.total_size` returns a valid size for locally installed models
  - new route `list_models` returns a paged summary of model, name, description, tags and other essential info
  - fix a few type errors
* consolidated all invokeai root pytest fixtures into a single location
* Update invokeai/backend/model_manager/metadata/metadata_store.py
  Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
* Small tweaks in response to review comments:
  - Remove flake8 configuration from pyproject.toml
  - Use `id` rather than `modelId` for huggingface `ModelInfo` object
  - Use `last_modified` rather than `LastModified` for huggingface `ModelInfo` object
  - Add `sha256` field to file metadata downloaded from huggingface
  - Add `Invoker` argument to the model installer `start()` and `stop()` routines (but made it optional in order to facilitate use of the service outside the API)
  - Removed redundant `PRAGMA foreign_keys` from metadata store initialization code.
* Additional tweaks and minor bug fixes
  - Fix calculation of aggregate diffusers model size to only count the size of files, not files + directories (which gives different unit test results on different filesystems).
  - Refactor _get_metadata() and _get_download_urls() to have distinct code paths for Civitai, HuggingFace and URL sources.
  - Forward the `inplace` flag from the source to the job and added unit test for this.
  - Attach cached model metadata to the job rather than to the model install service.
* fix unit test that was breaking on windows due to CR/LF changing size of test json files
* fix ruff formatting
* a few last minor fixes before merging:
  - Turn job `error` and `error_type` into properties derived from the exception.
  - Add TODO comment about the reason for handling temporary directory destruction manually rather than using tempfile.tmpdir().
* add unit tests for reporting HTTP download errors

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
@@ -6,6 +6,7 @@ from .config import (
    InvalidModelConfigException,
    ModelConfigFactory,
    ModelFormat,
    ModelRepoVariant,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
@@ -15,15 +16,16 @@ from .probe import ModelProbe
from .search import ModelSearch

__all__ = [
    "ModelProbe",
    "ModelSearch",
    "AnyModelConfig",
    "BaseModelType",
    "ModelRepoVariant",
    "InvalidModelConfigException",
    "ModelConfigFactory",
    "BaseModelType",
    "ModelType",
    "SubModelType",
    "ModelVariantType",
    "ModelFormat",
    "ModelProbe",
    "ModelSearch",
    "ModelType",
    "ModelVariantType",
    "SchedulerPredictionType",
    "AnyModelConfig",
    "SubModelType",
]
@@ -99,6 +99,17 @@ class SchedulerPredictionType(str, Enum):
    Sample = "sample"


class ModelRepoVariant(str, Enum):
    """Various hugging face variants on the diffusers format."""

    DEFAULT = "default"  # model files without "fp16" or other qualifier
    FP16 = "fp16"
    FP32 = "fp32"
    ONNX = "onnx"
    OPENVINO = "openvino"
    FLAX = "flax"


class ModelConfigBase(BaseModel):
    """Base class for model configuration information."""

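A minimal sketch of how the new `ModelRepoVariant` enum might be used. Because it subclasses both `str` and `Enum`, a member can be looked up from its string value and compared with plain strings. The `variant_from_filename` helper is hypothetical and only illustrates the idea; the commit's actual file selection lives in `select_hf_files`, which is not shown in this excerpt.

```python
from enum import Enum


class ModelRepoVariant(str, Enum):
    """Copy of the enum added in this hunk."""

    DEFAULT = "default"  # model files without "fp16" or other qualifier
    FP16 = "fp16"
    FP32 = "fp32"
    ONNX = "onnx"
    OPENVINO = "openvino"
    FLAX = "flax"


def variant_from_filename(filename: str) -> ModelRepoVariant:
    """Hypothetical helper: guess the variant from a 'stem.<variant>.ext' filename."""
    parts = filename.split(".")
    # A qualifier such as "fp16" sits between the stem and the extension.
    if len(parts) >= 3:
        try:
            return ModelRepoVariant(parts[-2])
        except ValueError:
            pass  # unknown qualifier: fall through to the default variant
    return ModelRepoVariant.DEFAULT
```

Looking a member up by value (`ModelRepoVariant("fp16")`) is what the `download_urls()` docstring later in this commit alludes to.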
50
invokeai/backend/model_manager/metadata/__init__.py
Normal file
@@ -0,0 +1,50 @@
"""
Initialization file for invokeai.backend.model_manager.metadata

Usage:

from invokeai.backend.model_manager.metadata import (
    AnyModelRepoMetadata,
    CommercialUsage,
    LicenseRestrictions,
    HuggingFaceMetadata,
    CivitaiMetadata,
)

from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch

data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
    print("Commercial use of this model is allowed")
"""
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
from .metadata_base import (
    AnyModelRepoMetadata,
    AnyModelRepoMetadataValidator,
    BaseMetadata,
    CivitaiMetadata,
    CommercialUsage,
    HuggingFaceMetadata,
    LicenseRestrictions,
    ModelMetadataWithFiles,
    RemoteModelFile,
    UnknownMetadataException,
)
from .metadata_store import ModelMetadataStore

__all__ = [
    "AnyModelRepoMetadata",
    "AnyModelRepoMetadataValidator",
    "CivitaiMetadata",
    "CivitaiMetadataFetch",
    "CommercialUsage",
    "HuggingFaceMetadata",
    "HuggingFaceMetadataFetch",
    "LicenseRestrictions",
    "ModelMetadataStore",
    "BaseMetadata",
    "ModelMetadataWithFiles",
    "RemoteModelFile",
    "UnknownMetadataException",
]
21
invokeai/backend/model_manager/metadata/fetch/__init__.py
Normal file
@@ -0,0 +1,21 @@
"""
Initialization file for invokeai.backend.model_manager.metadata.fetch

Usage:
from invokeai.backend.model_manager.metadata.fetch import (
    CivitaiMetadataFetch,
    HuggingFaceMetadataFetch,
)
from invokeai.backend.model_manager.metadata import CivitaiMetadata

data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
    print("Commercial use of this model is allowed")
"""

from .civitai import CivitaiMetadataFetch
from .fetch_base import ModelMetadataFetchBase
from .huggingface import HuggingFaceMetadataFetch

__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]
187
invokeai/backend/model_manager/metadata/fetch/civitai.py
Normal file
@@ -0,0 +1,187 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team

"""
This module fetches model metadata objects from the Civitai model repository.
In addition to the `from_url()` and `from_id()` methods inherited from the
`ModelMetadataFetchBase` base class, it provides `from_civitai_modelid()`
and `from_civitai_versionid()`.

Civitai has two separate ID spaces: a model ID and a version ID. The
version ID corresponds to a specific model, and is the ID accepted by
`from_id()`. The model ID corresponds to a family of related models,
such as different training checkpoints or 16 vs 32-bit versions. The
`from_civitai_modelid()` method will accept a model ID and return the
metadata from the default version within this model set. The default
version is the same as what the user sees when they click on a model's
thumbnail.

Usage:

from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch

fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""

import re
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional

import requests
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

from ..metadata_base import (
    AnyModelRepoMetadata,
    CivitaiMetadata,
    CommercialUsage,
    LicenseRestrictions,
    RemoteModelFile,
    UnknownMetadataException,
)
from .fetch_base import ModelMetadataFetchBase

CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"

CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"


class CivitaiMetadataFetch(ModelMetadataFetchBase):
    """Fetch model metadata from Civitai."""

    def __init__(self, session: Optional[Session] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

        By providing a configurable Session object, we can support unit tests on
        this module without an internet connection.
        """
        self._requests = session or requests.Session()

    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
        """
        Given a URL to a CivitAI model or version page, return a ModelMetadata object.

        In the event that the URL points to a model page without the particular version
        indicated, the default model version is returned. Otherwise, the requested version
        is returned.
        """
        if match := re.match(CIVITAI_VERSION_PAGE_RE, str(url), re.IGNORECASE):
            model_id = match.group(1)
            version_id = match.group(2)
            return self.from_civitai_versionid(int(version_id), int(model_id))
        elif match := re.match(CIVITAI_MODEL_PAGE_RE, str(url), re.IGNORECASE):
            model_id = match.group(1)
            return self.from_civitai_modelid(int(model_id))
        elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url), re.IGNORECASE):
            version_id = match.group(1)
            return self.from_civitai_versionid(int(version_id))
        raise UnknownMetadataException(f"The url '{url}' does not match any known Civitai URL patterns")

    def from_id(self, id: str) -> AnyModelRepoMetadata:
        """
        Given a Civitai model version ID, return a ModelRepoMetadata object.

        May raise an `UnknownMetadataException`.
        """
        return self.from_civitai_versionid(int(id))

    def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
        """
        Return metadata from the default version of the indicated model.

        May raise an `UnknownMetadataException`.
        """
        model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
        model_json = self._requests.get(model_url).json()
        return self._from_model_json(model_json)

    def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
        try:
            version_id = version_id or model_json["modelVersions"][0]["id"]
        except TypeError as excp:
            raise UnknownMetadataException from excp

        # loop till we find the section containing the version requested
        version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
        if not version_sections:
            raise UnknownMetadataException(f"Version {version_id} not found in model metadata")

        version_json = version_sections[0]
        safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]

        # Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
        primary = [x for x in version_json["files"] if x.get("primary")]
        assert len(primary) == 1
        primary_file = primary[0]

        url = primary_file["downloadUrl"]
        if "?" not in url:  # work around apparent bug in civitai api
            metadata_string = ""
            for key, value in primary_file["metadata"].items():
                if not value:
                    continue
                metadata_string += f"&{key}={value}"
            url = url + f"?type={primary_file['type']}{metadata_string}"
        model_files = [
            RemoteModelFile(
                url=url,
                path=Path(primary_file["name"]),
                size=int(primary_file["sizeKB"] * 1024),
                sha256=primary_file["hashes"]["SHA256"],
            )
        ]
        return CivitaiMetadata(
            id=model_json["id"],
            name=version_json["name"],
            version_id=version_json["id"],
            version_name=version_json["name"],
            created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
            updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
            published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
            base_model_trained_on=version_json["baseModel"],  # note - need a dictionary to turn into a BaseModelType
            files=model_files,
            download_url=version_json["downloadUrl"],
            thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
            author=model_json["creator"]["username"],
            description=model_json["description"],
            version_description=version_json["description"] or "",
            tags=model_json["tags"],
            trained_words=version_json["trainedWords"],
            nsfw=model_json["nsfw"],
            restrictions=LicenseRestrictions(
                AllowNoCredit=model_json["allowNoCredit"],
                AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
                AllowDerivatives=model_json["allowDerivatives"],
                AllowDifferentLicense=model_json["allowDifferentLicense"],
            ),
        )

    def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
        """
        Return a CivitaiMetadata object given a model version id.

        May raise an `UnknownMetadataException`.
        """
        if model_id is None:
            version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
            version = self._requests.get(version_url).json()
            model_id = version["modelId"]

        model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
        model_json = self._requests.get(model_url).json()
        return self._from_model_json(model_json, version_id)

    @classmethod
    def from_json(cls, json: str) -> CivitaiMetadata:
        """Given the JSON representation of the metadata, return the corresponding Pydantic object."""
        metadata = CivitaiMetadata.model_validate_json(json)
        return metadata


def _fix_timezone(date: str) -> str:
    return re.sub(r"Z$", "+00:00", date)
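The URL dispatch in `from_url()` can be exercised without any network access. The sketch below reuses the three regular expressions from the file above; `classify_civitai_url` is a hypothetical helper, not part of the commit, and it raises `ValueError` where the real code raises `UnknownMetadataException`.

```python
import re
from typing import Optional, Tuple

CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"


def classify_civitai_url(url: str) -> Tuple[Optional[int], Optional[int]]:
    """Hypothetical helper: return (model_id, version_id); either may be None."""
    # The version pattern is tried first because a version page URL
    # also matches the more general model page pattern.
    if match := re.match(CIVITAI_VERSION_PAGE_RE, url, re.IGNORECASE):
        return int(match.group(1)), int(match.group(2))
    if match := re.match(CIVITAI_MODEL_PAGE_RE, url, re.IGNORECASE):
        return int(match.group(1)), None
    if match := re.match(CIVITAI_DOWNLOAD_RE, url, re.IGNORECASE):
        return None, int(match.group(1))
    raise ValueError(f"The url '{url}' does not match any known Civitai URL patterns")
```

A model ID with no version ID corresponds to the "default version" path (`from_civitai_modelid()`), while a version ID alone corresponds to `from_civitai_versionid()` with the extra lookup of `modelId`.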
61
invokeai/backend/model_manager/metadata/fetch/fetch_base.py
Normal file
@@ -0,0 +1,61 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team

"""
This module defines the base class for subclasses that fetch metadata from model repositories.

Usage:

from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch

fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""

from abc import ABC, abstractmethod
from typing import Optional

from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator


class ModelMetadataFetchBase(ABC):
    """Fetch metadata from remote generative model repositories."""

    @abstractmethod
    def __init__(self, session: Optional[Session] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

        By providing a configurable Session object, we can support unit tests on
        this module without an internet connection.
        """
        pass

    @abstractmethod
    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
        """
        Given a URL to a model repository, return a ModelMetadata object.

        This method will raise an `UnknownMetadataException`
        in the event that the requested model metadata is not found at the provided location.
        """
        pass

    @abstractmethod
    def from_id(self, id: str) -> AnyModelRepoMetadata:
        """
        Given an ID for a model, return a ModelMetadata object.

        This method will raise an `UnknownMetadataException`
        in the event that the requested model's metadata is not found at the provided id.
        """
        pass

    @classmethod
    def from_json(cls, json: str) -> AnyModelRepoMetadata:
        """Given the JSON representation of the metadata, return the corresponding Pydantic object."""
        metadata = AnyModelRepoMetadataValidator.validate_json(json)
        return metadata
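The base-class `from_json()` relies on a pydantic `TypeAdapter` over a union that is discriminated by the `type` field of each record. The stdlib-only sketch below is not pydantic and performs no field validation; it only illustrates the dispatch idea, namely that the value of `type` decides which class parses the record. The stub classes and the `_REGISTRY` mapping are hypothetical stand-ins.

```python
import json
from dataclasses import dataclass
from typing import Dict, Type, Union


@dataclass
class HuggingFaceStub:
    """Hypothetical stand-in for HuggingFaceMetadata."""

    type: str
    id: str


@dataclass
class CivitaiStub:
    """Hypothetical stand-in for CivitaiMetadata."""

    type: str
    id: int


# Maps the value of the "type" field to the class that should parse the record.
_REGISTRY: Dict[str, Type] = {"huggingface": HuggingFaceStub, "civitai": CivitaiStub}


def from_json(raw: str) -> Union[HuggingFaceStub, CivitaiStub]:
    """Pick the target class from the record's "type" field, then build it."""
    data = json.loads(raw)
    try:
        cls = _REGISTRY[data["type"]]
    except KeyError as excp:
        raise ValueError(f"unknown or missing metadata type in {raw!r}") from excp
    return cls(**data)
```

In the commit itself the same role is played by the `Literal["civitai"]` and `Literal["huggingface"]` fields combined with `Field(discriminator="type")`.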
92
invokeai/backend/model_manager/metadata/fetch/huggingface.py
Normal file
@@ -0,0 +1,92 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team

"""
This module fetches model metadata objects from the HuggingFace model repository,
using either a `repo_id` or the model page URL.

Usage:

from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch

fetcher = HuggingFaceMetadataFetch()
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""

import re
from pathlib import Path
from typing import Optional

import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.utils._errors import RepositoryNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

from ..metadata_base import (
    AnyModelRepoMetadata,
    HuggingFaceMetadata,
    RemoteModelFile,
    UnknownMetadataException,
)
from .fetch_base import ModelMetadataFetchBase

HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"


class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
    """Fetch model metadata from HuggingFace."""

    def __init__(self, session: Optional[Session] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

        By providing a configurable Session object, we can support unit tests on
        this module without an internet connection.
        """
        self._requests = session or requests.Session()
        configure_http_backend(backend_factory=lambda: self._requests)

    @classmethod
    def from_json(cls, json: str) -> HuggingFaceMetadata:
        """Given the JSON representation of the metadata, return the corresponding Pydantic object."""
        metadata = HuggingFaceMetadata.model_validate_json(json)
        return metadata

    def from_id(self, id: str) -> AnyModelRepoMetadata:
        """Return a HuggingFaceMetadata object given the model's repo_id."""
        try:
            model_info = HfApi().model_info(repo_id=id, files_metadata=True)
        except RepositoryNotFoundError as excp:
            raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp

        _, name = id.split("/")
        return HuggingFaceMetadata(
            id=model_info.id,
            author=model_info.author,
            name=name,
            last_modified=model_info.last_modified,
            tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
            tags=model_info.tags,
            files=[
                RemoteModelFile(
                    url=hf_hub_url(id, x.rfilename),
                    path=Path(name, x.rfilename),
                    size=x.size,
                    sha256=x.lfs.get("sha256") if x.lfs else None,
                )
                for x in model_info.siblings
            ],
        )

    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
        """
        Return a HuggingFaceMetadata object given the model's web page URL.

        In the case of an invalid or missing URL, raises an `UnknownMetadataException`.
        """
        if match := re.match(HF_MODEL_RE, str(url), re.IGNORECASE):
            repo_id = match.group(1)
            return self.from_id(repo_id)
        else:
            raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")
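As a small offline illustration of how `from_url()` and `from_id()` derive the `repo_id` and the model name, the sketch below applies `HF_MODEL_RE` and the same `split("/")` step. `parse_hf_url` is a hypothetical helper, and it raises `ValueError` in place of `UnknownMetadataException`.

```python
import re
from typing import Tuple

HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"


def parse_hf_url(url: str) -> Tuple[str, str]:
    """Hypothetical helper: return (repo_id, name) for a HuggingFace model page URL."""
    match = re.match(HF_MODEL_RE, url, re.IGNORECASE)
    if not match:
        raise ValueError(f"'{url}' does not look like a HuggingFace model page")
    repo_id = match.group(1)
    # repo_id has the form "author/name"; the name becomes the local model directory.
    _, name = repo_id.split("/")
    return repo_id, name
```

Because the character class excludes `/`, trailing path components such as `/tree/main` are ignored and only the `author/name` pair is captured.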
202
invokeai/backend/model_manager/metadata/metadata_base.py
Normal file
@ -0,0 +1,202 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""This module defines core text-to-image model metadata fields.
|
||||
|
||||
Metadata comprises any descriptive information that is not essential
|
||||
for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in invokeai.backend.model_manager.config.
|
||||
|
||||
Note that the "name" and "description" are also present in `config`
|
||||
records. This is intentional. The config record fields are intended to
|
||||
be editable by the user as a form of customization. The metadata
|
||||
versions of these fields are intended to be kept in sync with the
|
||||
remote repo.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
|
||||
from huggingface_hub import configure_http_backend, hf_hub_url
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
from ..util import select_hf_files
|
||||
|
||||
|
||||
class UnknownMetadataException(Exception):
|
||||
"""Raised when no metadata is available for a model."""
|
||||
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
|
||||
No = "None"
|
||||
Image = "Image"
|
||||
Rent = "Rent"
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
|
||||
class LicenseRestrictions(BaseModel):
|
||||
"""Broad categories of licensing restrictions."""
|
||||
|
||||
AllowNoCredit: bool = Field(
|
||||
description="if true, model can be redistributed without crediting author", default=False
|
||||
)
|
||||
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: CommercialUsage = Field(
|
||||
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set
|
||||
)
|
||||
|
||||
|
||||
class RemoteModelFile(BaseModel):
|
||||
"""Information about a downloadable file that forms part of a model."""
|
||||
|
||||
url: AnyHttpUrl = Field(description="The url to download this model file")
|
||||
path: Path = Field(description="The path to the file, relative to the model root")
|
||||
size: int = Field(description="The size of this file, in bytes")
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
"""Adds typing data for discriminated union."""
|
||||
|
||||
type: Literal["basemetadata"] = "basemetadata"
|
||||
|
||||
|
||||
class ModelMetadataWithFiles(ModelMetadataBase):
|
||||
"""Base class for metadata that contains a list of downloadable model file(s)."""
|
||||
|
||||
files: List[RemoteModelFile] = Field(description="model files and their sizes", default_factory=list)
|
||||
|
||||
def download_urls(
|
||||
self,
|
||||
variant: Optional[ModelRepoVariant] = None,
|
||||
subfolder: Optional[Path] = None,
|
||||
session: Optional[Session] = None,
|
||||
) -> List[RemoteModelFile]:
|
||||
"""
|
||||
Return a list of URLs needed to download the model.
|
||||
|
||||
:param variant: Return files needed to reconstruct the indicated variant (e.g. ModelRepoVariant('fp16'))
|
||||
:param subfolder: Return files in the designated subfolder only
|
||||
:param session: A request.Session object for offline testing
|
||||
|
||||
Note that the "variant" and "subfolder" concepts currently only apply to HuggingFace.
|
||||
However Civitai does have fields for the precision and format of its models, and may
|
||||
provide variant selection criteria in the future.
|
||||
"""
|
||||
return self.files
|
||||
|
||||
|
||||
class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
type: Literal["civitai"] = "civitai"
|
||||
id: int = Field(description="Civitai version identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was created")
|
||||
updated: datetime = Field(description="date the model was last modified")
|
||||
published: datetime = Field(description="date the model was published to Civitai")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_minmax: Tuple[float, float] = Field(
|
||||
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
|
||||
) # note: For future use
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
"""Return True if you must give credit for derivatives of this model and images generated from it."""
|
||||
return not self.restrictions.AllowNoCredit
|
||||
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
"""Return True if derivatives of this model can be redistributed."""
|
||||
return self.restrictions.AllowDerivatives
|
||||
|
||||
@property
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
||||
|
||||
|
||||
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by HuggingFace."""
|
||||
|
||||
type: Literal["huggingface"] = "huggingface"
|
||||
id: str = Field(description="huggingface model id")
|
||||
tag_dict: Dict[str, Any]
|
||||
last_modified: datetime = Field(description="date of last commit to repo")
|
||||
|
||||
def download_urls(
|
||||
self,
|
||||
variant: Optional[ModelRepoVariant] = None,
|
||||
subfolder: Optional[Path] = None,
|
||||
session: Optional[Session] = None,
|
||||
) -> List[RemoteModelFile]:
|
||||
"""
|
||||
Return list of downloadable files, filtering by variant and subfolder, if any.
|
||||
|
||||
:param variant: Return model files needed to reconstruct the indicated variant
|
||||
:param subfolder: Return model files from the designated subfolder only
|
||||
:param session: A request.Session object used for internet-free testing
|
||||
|
||||
Note that there is special variant-filtering behavior here:
|
||||
When the fp16 variant is requested and not available, the
|
||||
full-precision model is returned.
|
||||
"""
|
||||
session = session or Session()
|
||||
configure_http_backend(backend_factory=lambda: session) # used in testing
|
||||
|
||||
paths = select_hf_files.filter_files(
|
||||
[x.path for x in self.files], variant, subfolder
|
||||
) # all files in the model
|
||||
prefix = f"{subfolder}/" if subfolder else ""
|
||||
|
||||
# the next step reads model_index.json to determine which subdirectories belong
|
||||
# to the model
|
||||
if Path(f"{prefix}model_index.json") in paths:
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
|
||||
resp = session.get(url)
|
||||
resp.raise_for_status()
|
||||
submodels = resp.json()
|
||||
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
|
||||
paths.insert(0, Path(f"{prefix}model_index.json"))
|
||||
|
||||
return [x for x in self.files if x.path in paths]
|
||||
|
||||
|
||||
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
|
||||
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
|
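The `model_index.json` step in `HuggingFaceMetadata.download_urls` above can be illustrated in isolation. The following is a minimal sketch, not the InvokeAI implementation: `keep_submodel_files` is a hypothetical helper, the index is passed in as an already-parsed dict instead of being fetched over HTTP, and no subfolder is assumed.

```python
from pathlib import Path
from typing import Any, Dict, List


def keep_submodel_files(paths: List[Path], model_index: Dict[str, Any]) -> List[Path]:
    """Keep only files whose parent directory is a key of the parsed model_index.json.

    Hypothetical helper for illustration; assumes the model sits at the repo root.
    """
    # Top-level files have parent "." and are dropped unless "." is a key.
    kept = [p for p in paths if p.parent.as_posix() in model_index]
    # model_index.json itself lives at the root, so it is put back explicitly.
    kept.insert(0, Path("model_index.json"))
    return kept


# Example repo listing (made-up file names in the usual diffusers layout).
example_paths = [
    Path("model_index.json"),
    Path("README.txt"),
    Path("unet/config.json"),
    Path("unet/diffusion_pytorch_model.safetensors"),
    Path("extras/model.safetensors"),
]
example_index = {
    "_class_name": "StableDiffusionPipeline",
    "unet": ["diffusers", "UNet2DConditionModel"],
}
selected = keep_submodel_files(example_paths, example_index)
```

In this sketch only the `unet` files survive, because `extras` is not listed in the index and top-level files other than `model_index.json` are discarded.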
221
invokeai/backend/model_manager/metadata/metadata_store.py
Normal file
@ -0,0 +1,221 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .fetch import ModelMetadataFetchBase
|
||||
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
|
||||
class ModelMetadataStore:
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from a preexisting SqliteDatabase object.
|
||||
|
||||
:param db: SqliteDatabase object that provides the sqlite3 connection and threading lock
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
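The tag bookkeeping in `_update_tags` and the intersection logic in `search_by_tag` can be tried out against an in-memory database. This is a sketch under stated assumptions: the two-table schema below is a guess at the minimum those queries need (the real tables are created elsewhere in the project and are not shown in this diff), and the module-level functions are stand-ins for the methods above.

```python
import sqlite3
from typing import Optional, Set

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
# Assumed minimal schema; the real InvokeAI schema may differ.
cur.executescript(
    """
    CREATE TABLE tags (tag_id INTEGER PRIMARY KEY, tag_text TEXT UNIQUE);
    CREATE TABLE model_tags (model_id TEXT, tag_id INTEGER, UNIQUE(model_id, tag_id));
    """
)


def update_tags(model_key: str, tags: Set[str]) -> None:
    """Replace all tags of one model, creating tag rows on demand."""
    cur.execute("DELETE FROM model_tags WHERE model_id=?;", (model_key,))
    for tag in tags:
        cur.execute("INSERT OR IGNORE INTO tags (tag_text) VALUES (?);", (tag,))
        cur.execute("SELECT tag_id FROM tags WHERE tag_text=? LIMIT 1;", (tag,))
        tag_id = cur.fetchone()[0]
        cur.execute(
            "INSERT OR IGNORE INTO model_tags (model_id, tag_id) VALUES (?,?);",
            (model_key, tag_id),
        )
    conn.commit()


def search_by_tag(tags: Set[str]) -> Set[str]:
    """Return keys of models that carry every one of the given tags."""
    matches: Optional[Set[str]] = None
    for tag in tags:
        cur.execute(
            "SELECT a.model_id FROM model_tags AS a, tags AS b "
            "WHERE a.tag_id=b.tag_id AND b.tag_text=?;",
            (tag,),
        )
        keys = {row[0] for row in cur.fetchall()}
        # Intersect per-tag result sets so only models with all tags remain.
        matches = keys if matches is None else matches & keys
    return matches or set()


update_tags("model-1", {"sdxl", "anime"})
update_tags("model-2", {"sdxl"})
```

Because `update_tags` deletes the model's existing rows first, calling it again with a different set replaces the old tags instead of adding to them.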
@ -496,9 +496,9 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
scheduler_conf = json.load(file)
|
||||
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||
if scheduler_conf.get("prediction_type", "epsilon") == "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||
elif scheduler_conf.get("prediction_type", "epsilon") == "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
|
||||
|
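The hunk above switches to `dict.get` with a default, so a `scheduler_config.json` that omits `prediction_type` is treated as epsilon instead of raising `KeyError`. A minimal standalone sketch of that behavior follows; `PredictionType` and `prediction_type_from_json` are hypothetical stand-ins, not InvokeAI names.

```python
import json
from enum import Enum


class PredictionType(str, Enum):
    """Hypothetical stand-in for SchedulerPredictionType."""

    Epsilon = "epsilon"
    VPrediction = "v_prediction"


def prediction_type_from_json(text: str) -> PredictionType:
    """Parse scheduler config JSON, defaulting to epsilon when the key is absent."""
    conf = json.loads(text)
    value = conf.get("prediction_type", "epsilon")
    try:
        return PredictionType(value)
    except ValueError:
        # Present but unrecognized values are still reported as errors.
        raise ValueError(f"Unknown scheduler prediction type: {value}") from None
```

Only a missing key falls back to the default; an unrecognized value still fails loudly.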
132
invokeai/backend/model_manager/util/select_hf_files.py
Normal file
@ -0,0 +1,132 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Select the files from a HuggingFace repository needed for a particular model variant.
|
||||
|
||||
Usage:
|
||||
```
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.util.select_hf_files import filter_files
|
||||
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
|
||||
|
||||
metadata = HuggingFaceMetadataFetch().from_url("https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0")
|
||||
files_to_download = filter_files([x.path for x in metadata.files], variant=ModelRepoVariant.ONNX)
|
||||
```
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from ..config import ModelRepoVariant
|
||||
|
||||
|
||||
def filter_files(
|
||||
files: List[Path],
|
||||
variant: Optional[ModelRepoVariant] = None,
|
||||
subfolder: Optional[Path] = None,
|
||||
) -> List[Path]:
|
||||
"""
|
||||
Take a list of files in a HuggingFace repo root and return paths to files needed to load the model.
|
||||
|
||||
:param files: List of files relative to the repo root.
|
||||
:param subfolder: Filter by the indicated subfolder.
|
||||
:param variant: Filter by files belonging to a particular variant, such as fp16.
|
||||
|
||||
The file list can be obtained from the `files` field of HuggingFaceMetadata,
|
||||
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
|
||||
"""
|
||||
variant = variant or ModelRepoVariant.DEFAULT
|
||||
paths: List[Path] = []
|
||||
|
||||
# Start by filtering on model file extensions, discarding images, docs, etc
|
||||
for file in files:
|
||||
if file.name.endswith((".json", ".txt")):
|
||||
paths.append(file)
|
||||
elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")):
|
||||
paths.append(file)
|
||||
# BRITTLENESS WARNING!!
|
||||
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area of brittleness.
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
paths.append(file)
|
||||
|
||||
# limit search to subfolder if requested
|
||||
if subfolder:
|
||||
paths = [x for x in paths if x.parent == Path(subfolder)]
|
||||
|
||||
# _filter_by_variant uniquifies the paths and returns a set
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
|
||||
|
||||
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
result = set()
|
||||
basenames: Dict[Path, Path] = {}
|
||||
for path in files:
|
||||
if path.suffix == ".onnx":
|
||||
if variant == ModelRepoVariant.ONNX:
|
||||
result.add(path)
|
||||
|
||||
elif "openvino_model" in path.name:
|
||||
if variant == ModelRepoVariant.OPENVINO:
|
||||
result.add(path)
|
||||
|
||||
elif "flax_model" in path.name:
|
||||
if variant == ModelRepoVariant.FLAX:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
|
||||
ModelRepoVariant.FP16,
|
||||
ModelRepoVariant.FP32,
|
||||
ModelRepoVariant.DEFAULT,
|
||||
]:
|
||||
parent = path.parent
|
||||
suffixes = path.suffixes
|
||||
if len(suffixes) == 2:
|
||||
variant_label, suffix = suffixes
|
||||
basename = parent / Path(path.stem).stem
|
||||
else:
|
||||
variant_label = ""
|
||||
suffix = suffixes[0]
|
||||
basename = parent / path.stem
|
||||
|
||||
if previous := basenames.get(basename):
|
||||
if (
|
||||
previous.suffix != ".safetensors" and suffix == ".safetensors"
|
||||
): # replace non-safetensors with safetensors when available
|
||||
basenames[basename] = path
|
||||
if variant_label == f".{variant}":
|
||||
basenames[basename] = path
|
||||
elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
|
||||
basenames[basename] = path
|
||||
else:
|
||||
basenames[basename] = path
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
for v in basenames.values():
|
||||
result.add(v)
|
||||
|
||||
# If one of the architecture-related variants was specified and no files matched other than
|
||||
# config and text files then we return an empty set
|
||||
if (
|
||||
variant
|
||||
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
|
||||
and not any(variant.value in x.name for x in result)
|
||||
):
|
||||
return set()
|
||||
|
||||
# Prune folders that contain just a `config.json`. This happens when
|
||||
# the requested variant (e.g. "onnx") is missing
|
||||
directories: Dict[Path, int] = {}
|
||||
for x in result:
|
||||
if not x.parent:
|
||||
continue
|
||||
directories[x.parent] = directories.get(x.parent, 0) + 1
|
||||
|
||||
return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"}
|
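The weight-file selection inside `_filter_by_variant` is easier to follow on a small example. The sketch below reproduces only the fp16/fp32/default branch, using plain strings for the variant instead of `ModelRepoVariant`; `pick_weight_files` is a hypothetical name, and the input is assumed to contain weight files only.

```python
from pathlib import Path
from typing import Dict, List, Set


def pick_weight_files(files: List[Path], variant: str = "") -> Set[Path]:
    """Pick one weights file per basename, preferring `variant` and safetensors.

    `variant` is "fp16", "fp32" or "" (default); a simplification of the enum.
    """
    chosen: Dict[Path, Path] = {}
    for path in files:
        suffixes = path.suffixes
        if len(suffixes) == 2:  # e.g. model.fp16.safetensors
            label, suffix = suffixes
            basename = path.parent / Path(path.stem).stem
        else:  # e.g. model.safetensors
            label, suffix = "", suffixes[0]
            basename = path.parent / path.stem

        previous = chosen.get(basename)
        if previous is None:
            chosen[basename] = path  # first candidate always wins provisionally
            continue
        if previous.suffix != ".safetensors" and suffix == ".safetensors":
            chosen[basename] = path  # prefer safetensors over other formats
        if label == f".{variant}":
            chosen[basename] = path  # exact variant match
        elif not label and variant in ("", "fp32"):
            chosen[basename] = path  # unlabeled file is the full-precision one
    return set(chosen.values())


repo_files = [
    Path("unet/diffusion_pytorch_model.safetensors"),
    Path("unet/diffusion_pytorch_model.fp16.safetensors"),
    Path("vae/diffusion_pytorch_model.bin"),
    Path("vae/diffusion_pytorch_model.safetensors"),
]
fp16_choice = pick_weight_files(repo_files, "fp16")
default_choice = pick_weight_files(repo_files)
```

With `"fp16"` requested, the `unet` folder yields its fp16 file while `vae`, which has no fp16 file, falls back to the full-precision safetensors file. This matches the fallback described in the `download_urls` docstring.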