refactor(mm): get metadata working

This commit is contained in:
psychedelicious
2024-03-04 19:17:01 +11:00
parent 7cb0da1f66
commit c3aa985c93
8 changed files with 72 additions and 61 deletions

View File

@ -23,12 +23,13 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import json
import re
from pathlib import Path
from typing import Any, Optional
import requests
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
@ -56,7 +57,7 @@ StringSetAdapter = TypeAdapter(set[str])
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
def __init__(self, session: Optional[Session] = None):
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
@ -64,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
this module without an internet connection.
"""
self._requests = session or requests.Session()
self._api_key = api_key
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
@ -103,7 +105,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
May raise an `UnknownMetadataException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json)
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
@ -134,7 +136,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
url = url + f"?type={primary_file['type']}{metadata_string}"
model_files = [
RemoteModelFile(
url=url,
url=self._get_url_with_api_key(url),
path=Path(primary_file["name"]),
size=int(primary_file["sizeKB"] * 1024),
sha256=primary_file["hashes"]["SHA256"],
@ -142,11 +144,16 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
]
try:
trigger_words = StringSetAdapter.validate_python(api_response["triggerWords"])
except TypeError:
trigger_words = StringSetAdapter.validate_python(version_json.get("trainedWords"))
except ValidationError:
trigger_words: set[str] = set()
return CivitaiMetadata(name=version_json["name"], files=model_files, trigger_words=trigger_words)
return CivitaiMetadata(
name=version_json["name"],
files=model_files,
trigger_words=trigger_words,
api_response=json.dumps(version_json),
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
"""
@ -156,13 +163,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""
if model_id is None:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
if error := version.get("error"):
raise UnknownMetadataException(error)
model_id = version["modelId"]
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json, version_id)
@classmethod
@ -170,3 +177,12 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = CivitaiMetadata.model_validate_json(json)
return metadata
def _get_url_with_api_key(self, url: str) -> str:
    """
    Return `url` with the configured API key appended as a `token` query parameter.

    If no API key was provided to the fetcher (it is `None` or empty), the URL is
    returned unchanged. The key is percent-encoded before being appended so that
    reserved characters in it (`&`, `=`, `#`, `+`, whitespace, ...) cannot corrupt
    or truncate the query string.

    :param url: The URL to decorate; it may or may not already contain a query string.
    :return: The URL, with `token=<api key>` appended when a key is configured.
    """
    # Imported locally so this method is self-contained with respect to the module's imports.
    from urllib.parse import quote

    if not self._api_key:
        return url

    # Extend an existing query string with "&"; otherwise start a new one with "?".
    separator = "&" if "?" in url else "?"
    return f"{url}{separator}token={quote(self._api_key, safe='')}"

View File

@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""
import json
import re
from pathlib import Path
from typing import Optional
@ -89,7 +90,9 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
)
)
return HuggingFaceMetadata(id=model_info.id, name=name, files=files)
return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__)
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""