Multiple refinements on loaders:

- Cache stat collection enabled.
- Implemented ONNX loading.
- Add ability to specify the repo version variant in installer CLI.
- If caller asks for a repo version that doesn't exist, will fall back
  to empty version rather than raising an error.
This commit is contained in:
Lincoln Stein
2024-02-05 21:55:11 -05:00
committed by Brandon Rising
parent fdbd288956
commit 92843d55eb
18 changed files with 215 additions and 49 deletions

View File

@ -32,6 +32,8 @@ import requests
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from ..metadata_base import (
AnyModelRepoMetadata,
CivitaiMetadata,
@ -82,10 +84,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
return self.from_civitai_versionid(int(version_id))
raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")
def from_id(self, id: str) -> AnyModelRepoMetadata:
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
"""
Given a Civitai model version ID, return a ModelRepoMetadata object.
:param id: An ID.
:param variant: A model variant from the ModelRepoVariant enum (currently ignored)
May raise an `UnknownMetadataException`.
"""
return self.from_civitai_versionid(int(id))

View File

@ -18,6 +18,8 @@ from typing import Optional
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
@ -45,10 +47,13 @@ class ModelMetadataFetchBase(ABC):
pass
@abstractmethod
def from_id(self, id: str) -> AnyModelRepoMetadata:
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
"""
Given an ID for a model, return a ModelMetadata object.
:param id: An ID.
:param variant: A model variant from the ModelRepoVariant enum.
This method will raise a `UnknownMetadataException`
in the event that the requested model's metadata is not found at the provided id.
"""

View File

@ -19,10 +19,12 @@ from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from ..metadata_base import (
AnyModelRepoMetadata,
HuggingFaceMetadata,
@ -53,12 +55,22 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
metadata = HuggingFaceMetadata.model_validate_json(json)
return metadata
def from_id(self, id: str) -> AnyModelRepoMetadata:
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
"""Return a HuggingFaceMetadata object given the model's repo_id."""
try:
model_info = HfApi().model_info(repo_id=id, files_metadata=True)
except RepositoryNotFoundError as excp:
raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
# Little loop which tries fetching a revision corresponding to the selected variant.
# If not available, then set variant to None and get the default.
# If this too fails, raise exception.
model_info = None
while not model_info:
try:
model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant)
except RepositoryNotFoundError as excp:
raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
except RevisionNotFoundError:
if variant is None:
raise
else:
variant = None
_, name = id.split("/")
return HuggingFaceMetadata(
@ -70,7 +82,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
tags=model_info.tags,
files=[
RemoteModelFile(
url=hf_hub_url(id, x.rfilename),
url=hf_hub_url(id, x.rfilename, revision=variant),
path=Path(name, x.rfilename),
size=x.size,
sha256=x.lfs.get("sha256") if x.lfs else None,