refactor(mm): get metadata working

This commit is contained in:
psychedelicious
2024-03-04 19:17:01 +11:00
parent 7cb0da1f66
commit c3aa985c93
8 changed files with 72 additions and 61 deletions

View File

@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union
import torch
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, JsonValue, Tag, TypeAdapter
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from ..raw_model import RawModel
@ -142,9 +142,7 @@ class ModelConfigBase(BaseModel):
description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source")
source_api_response: Optional[JsonValue] = Field(
description="The original API response from the source", default=None
)
source_api_response: Optional[str] = Field(description="The original API response from the source, as stringified JSON.", default=None)
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)

View File

@ -23,12 +23,13 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import json
import re
from pathlib import Path
from typing import Any, Optional
import requests
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
@ -56,7 +57,7 @@ StringSetAdapter = TypeAdapter(set[str])
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
def __init__(self, session: Optional[Session] = None):
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
@ -64,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
this module without an internet connection.
"""
self._requests = session or requests.Session()
self._api_key = api_key
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
@ -103,7 +105,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
May raise an `UnknownMetadataException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json)
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
@ -134,7 +136,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
url = url + f"?type={primary_file['type']}{metadata_string}"
model_files = [
RemoteModelFile(
url=url,
url=self._get_url_with_api_key(url),
path=Path(primary_file["name"]),
size=int(primary_file["sizeKB"] * 1024),
sha256=primary_file["hashes"]["SHA256"],
@ -142,11 +144,16 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
]
try:
trigger_words = StringSetAdapter.validate_python(api_response["triggerWords"])
except TypeError:
trigger_words = StringSetAdapter.validate_python(version_json.get("trainedWords"))
except ValidationError:
trigger_words: set[str] = set()
return CivitaiMetadata(name=version_json["name"], files=model_files, trigger_words=trigger_words)
return CivitaiMetadata(
name=version_json["name"],
files=model_files,
trigger_words=trigger_words,
api_response=json.dumps(version_json),
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
"""
@ -156,13 +163,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""
if model_id is None:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
if error := version.get("error"):
raise UnknownMetadataException(error)
model_id = version["modelId"]
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json, version_id)
@classmethod
@ -170,3 +177,12 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = CivitaiMetadata.model_validate_json(json)
return metadata
def _get_url_with_api_key(self, url: str) -> str:
    """Return `url` with the configured API key appended as a `token` query parameter.

    The URL is returned unchanged when no API key is set (`None` or empty).
    Otherwise the key is appended with `&` when the URL already contains a
    `?`, and with `?` when it does not.

    Args:
        url: The URL to append the token to.

    Returns:
        The URL, with `token=<percent-encoded api key>` appended when an API key is set.
    """
    # Imported locally so this method does not rely on a module-level import.
    from urllib.parse import quote

    if not self._api_key:
        return url
    # Percent-encode the key so reserved characters (e.g. `&`, `#`, `=`, `+`)
    # cannot truncate the token or inject extra query parameters.
    token = quote(self._api_key, safe="")
    separator = "&" if "?" in url else "?"
    return f"{url}{separator}token={token}"

View File

@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""
import json
import re
from pathlib import Path
from typing import Optional
@ -89,7 +90,9 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
)
)
return HuggingFaceMetadata(id=model_info.id, name=name, files=files)
return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__)
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""

View File

@ -18,7 +18,7 @@ from pathlib import Path
from typing import List, Literal, Optional, Union
from huggingface_hub import configure_http_backend, hf_hub_url
from pydantic import BaseModel, Field, JsonValue, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from typing_extensions import Annotated
@ -93,7 +93,7 @@ class CivitaiMetadata(ModelMetadataWithFiles):
type: Literal["civitai"] = "civitai"
trigger_words: set[str] = Field(description="Trigger words extracted from the API response")
api_response: Optional[JsonValue] = Field(description="Response from the Civitai API", default=None)
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
class HuggingFaceMetadata(ModelMetadataWithFiles):
@ -101,7 +101,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="The HF model id")
api_response: Optional[JsonValue] = Field(description="Response from the HF API", default=None)
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
def download_urls(
self,