mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(mm): get metadata working
This commit is contained in:
@ -23,12 +23,13 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
@ -56,7 +57,7 @@ StringSetAdapter = TypeAdapter(set[str])
|
||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from Civitai."""
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
@ -64,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
self._api_key = api_key
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
@ -103,7 +105,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
May raise an `UnknownMetadataException`.
|
||||
"""
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json)
|
||||
|
||||
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
@ -134,7 +136,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
url = url + f"?type={primary_file['type']}{metadata_string}"
|
||||
model_files = [
|
||||
RemoteModelFile(
|
||||
url=url,
|
||||
url=self._get_url_with_api_key(url),
|
||||
path=Path(primary_file["name"]),
|
||||
size=int(primary_file["sizeKB"] * 1024),
|
||||
sha256=primary_file["hashes"]["SHA256"],
|
||||
@ -142,11 +144,16 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
]
|
||||
|
||||
try:
|
||||
trigger_words = StringSetAdapter.validate_python(api_response["triggerWords"])
|
||||
except TypeError:
|
||||
trigger_words = StringSetAdapter.validate_python(version_json.get("trainedWords"))
|
||||
except ValidationError:
|
||||
trigger_words: set[str] = set()
|
||||
|
||||
return CivitaiMetadata(name=version_json["name"], files=model_files, trigger_words=trigger_words)
|
||||
return CivitaiMetadata(
|
||||
name=version_json["name"],
|
||||
files=model_files,
|
||||
trigger_words=trigger_words,
|
||||
api_response=json.dumps(version_json),
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
"""
|
||||
@ -156,13 +163,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""
|
||||
if model_id is None:
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
version = self._requests.get(version_url).json()
|
||||
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
|
||||
if error := version.get("error"):
|
||||
raise UnknownMetadataException(error)
|
||||
model_id = version["modelId"]
|
||||
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json, version_id)
|
||||
|
||||
@classmethod
|
||||
@ -170,3 +177,12 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||
metadata = CivitaiMetadata.model_validate_json(json)
|
||||
return metadata
|
||||
|
||||
def _get_url_with_api_key(self, url: str) -> str:
|
||||
if not self._api_key:
|
||||
return url
|
||||
|
||||
if "?" in url:
|
||||
return f"{url}&token={self._api_key}"
|
||||
|
||||
return f"{url}?token={self._api_key}"
|
||||
|
@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
|
||||
print(metadata.tags)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -89,7 +90,9 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
)
|
||||
)
|
||||
|
||||
return HuggingFaceMetadata(id=model_info.id, name=name, files=files)
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__)
|
||||
)
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user