refactor(mm): wip schema changes

This commit is contained in:
psychedelicious
2024-03-04 19:16:25 +11:00
parent 3534366146
commit 7cb0da1f66
5 changed files with 42 additions and 39 deletions

View File

@ -19,7 +19,6 @@ from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.hf_api import RepoSibling
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
@ -61,6 +60,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
# Little loop which tries fetching a revision corresponding to the selected variant.
# If not available, then set variant to None and get the default.
# If this too fails, raise exception.
model_info = None
while not model_info:
try:
@ -73,12 +73,23 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
else:
variant = None
files: list[RemoteModelFile] = []
_, name = id.split("/")
return HuggingFaceMetadata(
id=model_info.id,
name=name,
files=parse_siblings(id, model_info.siblings, variant),
)
for s in model_info.siblings or []:
assert s.rfilename is not None
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(id, s.rfilename, revision=variant),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
)
)
return HuggingFaceMetadata(id=model_info.id, name=name, files=files)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
@ -91,27 +102,3 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
return self.from_id(repo_id)
else:
raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")
def parse_siblings(
repo_id: str, siblings: Optional[list[RepoSibling]] = None, variant: Optional[ModelRepoVariant] = None
) -> list[RemoteModelFile]:
"""Parse the siblings list from the HuggingFace API into a list of RemoteModelFile objects."""
if not siblings:
return []
files: list[RemoteModelFile] = []
for s in siblings:
assert s.rfilename is not None
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(repo_id, s.rfilename, revision=variant.value if variant else None),
path=Path(s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
)
)
return files