# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team

"""
This module fetches model metadata objects from the HuggingFace model repository,
using either a `repo_id` or the model page URL.

Usage:

from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch

fetcher = HuggingFaceMetadataFetch()
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""

import json
import re
from pathlib import Path
from typing import Optional

import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

from invokeai.backend.model_manager.config import ModelRepoVariant

from ..metadata_base import (
    AnyModelRepoMetadata,
    HuggingFaceMetadata,
    RemoteModelFile,
    UnknownMetadataException,
)
from .fetch_base import ModelMetadataFetchBase

# Matches a HuggingFace model page URL and captures the "<owner>/<name>" repo_id in group 1.
# The dot in the host name is escaped so that it only matches a literal ".".
HF_MODEL_RE = r"https?://huggingface\.co/([\w\-.]+/[\w\-.]+)"


class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
    """Fetch model metadata from HuggingFace."""

    def __init__(self, session: Optional[Session] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

        By providing a configurable Session object, we can support unit tests on
        this module without an internet connection.
        """
        self._requests = session or requests.Session()
        # Route huggingface_hub's HTTP traffic through this session. configure_http_backend
        # is a module-level huggingface_hub function, so this presumably affects every
        # HfApi user in the process, not just this instance — confirm if that matters.
        configure_http_backend(backend_factory=lambda: self._requests)

    @classmethod
    def from_json(cls, json: str) -> HuggingFaceMetadata:
        """Given the JSON representation of the metadata, return the corresponding Pydantic object."""
        # NOTE: the `json` parameter shadows the `json` module inside this method only;
        # the name is kept so keyword callers keep working.
        metadata = HuggingFaceMetadata.model_validate_json(json)
        return metadata

    def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
        """
        Return a HuggingFaceMetadata object given the model's repo_id.

        :param id: The HuggingFace repo_id, normally of the form "<owner>/<name>".
        :param variant: Optional repo variant, used as the revision to fetch. If that
            revision does not exist, the default revision is fetched instead.
        :return: HuggingFaceMetadata describing the repo and every file in it.
        :raises UnknownMetadataException: If the repository cannot be found.
        :raises RevisionNotFoundError: If even the default revision cannot be found.
        """
        # Try the revision corresponding to the selected variant first. If it is not
        # available, set variant to None and retry with the default revision. If that
        # also fails, the RevisionNotFoundError propagates.
        model_info = None
        while not model_info:
            try:
                model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant)
            except RepositoryNotFoundError as excp:
                raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
            except RevisionNotFoundError:
                if variant is None:
                    raise
                variant = None

        # The model name is the last path component of the repo_id. Using rsplit keeps
        # "<owner>/<name>" ids unchanged while also tolerating ids with no owner prefix.
        name = id.rsplit("/", 1)[-1]

        files: list[RemoteModelFile] = []
        for s in model_info.siblings or []:
            assert s.rfilename is not None
            assert s.size is not None
            files.append(
                RemoteModelFile(
                    # Download URLs point at the revision actually fetched above.
                    url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
                    path=Path(name, s.rfilename),
                    size=s.size,
                    sha256=s.lfs.get("sha256") if s.lfs else None,
                )
            )

        # diffusers models have a `model_index.json` or `config.json` file
        # NOTE(review): `endswith` also matches names such as "foo_config.json";
        # confirm whether only exact file names were intended.
        is_diffusers = any(str(f.url).endswith(("model_index.json", "config.json")) for f in files)

        # These URLs will be exposed to the user - these are the single-file checkpoint
        # types we fully support. No checkpoint URLs (None) are offered for diffusers repos.
        ckpt_urls = (
            None
            if is_diffusers
            else [
                f.url
                for f in files
                if str(f.url).endswith(
                    (
                        ".safetensors",
                        ".bin",
                        ".pth",
                        ".pt",
                        ".ckpt",
                    )
                )
            ]
        )

        return HuggingFaceMetadata(
            id=model_info.id,
            name=name,
            files=files,
            api_response=json.dumps(model_info.__dict__, default=str),
            is_diffusers=is_diffusers,
            ckpt_urls=ckpt_urls,
        )

    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
        """
        Return a HuggingFaceMetadata object given the model's web page URL.

        :param url: A HuggingFace model page URL, e.g. "https://huggingface.co/<owner>/<name>".
            Trailing path segments (such as "/tree/main") are ignored.
        :raises UnknownMetadataException: If the URL does not look like a HuggingFace model
            page, or if the repository it names cannot be found.
        """
        match = re.match(HF_MODEL_RE, str(url), re.IGNORECASE)
        if match is None:
            raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")
        repo_id = match.group(1)
        return self.from_id(repo_id)