refactor(mm): simplify model metadata schemas

This commit is contained in:
psychedelicious 2024-03-01 22:12:13 +11:00
parent 2c835fd550
commit 4471ea8ad1
4 changed files with 52 additions and 141 deletions

View File

@ -25,9 +25,7 @@ from .metadata_base import (
AnyModelRepoMetadataValidator,
BaseMetadata,
CivitaiMetadata,
CommercialUsage,
HuggingFaceMetadata,
LicenseRestrictions,
ModelMetadataWithFiles,
RemoteModelFile,
UnknownMetadataException,
@ -38,10 +36,8 @@ __all__ = [
"AnyModelRepoMetadataValidator",
"CivitaiMetadata",
"CivitaiMetadataFetch",
"CommercialUsage",
"HuggingFaceMetadata",
"HuggingFaceMetadataFetch",
"LicenseRestrictions",
"ModelMetadataFetchBase",
"BaseMetadata",
"ModelMetadataWithFiles",

View File

@ -24,21 +24,19 @@ print(metadata.trained_words)
"""
import re
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Optional
import requests
from pydantic import TypeAdapter
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from invokeai.backend.model_manager.config import ModelRepoVariant
from ..metadata_base import (
AnyModelRepoMetadata,
CivitaiMetadata,
CommercialUsage,
LicenseRestrictions,
RemoteModelFile,
UnknownMetadataException,
)
@ -52,6 +50,9 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
# Base URL of the Civitai models API; a model id is appended to it to form the request URL.
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
# Reusable pydantic adapter that validates a value as a set of strings.
StringSetAdapter = TypeAdapter(set[str])
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
@ -103,21 +104,20 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json)
return self._from_api_response(model_json)
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
try:
version_id = version_id or model_json["modelVersions"][0]["id"]
version_id = version_id or api_response["modelVersions"][0]["id"]
except TypeError as excp:
raise UnknownMetadataException from excp
# loop till we find the section containing the version requested
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
if not version_sections:
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
version_json = version_sections[0]
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
primary = [x for x in version_json["files"] if x.get("primary")]
@ -140,31 +140,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
sha256=primary_file["hashes"]["SHA256"],
)
]
return CivitaiMetadata(
id=model_json["id"],
name=version_json["name"],
version_id=version_json["id"],
version_name=version_json["name"],
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
files=model_files,
download_url=version_json["downloadUrl"],
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model_json["creator"]["username"],
description=model_json["description"],
version_description=version_json["description"] or "",
tags=model_json["tags"],
trained_words=version_json["trainedWords"],
nsfw=model_json["nsfw"],
restrictions=LicenseRestrictions(
AllowNoCredit=model_json["allowNoCredit"],
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
)
try:
trigger_words = StringSetAdapter.validate_python(api_response["triggerWords"])
except TypeError:
trigger_words: set[str] = set()
return CivitaiMetadata(name=version_json["name"], files=model_files, trigger_words=trigger_words)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
"""
@ -181,14 +163,10 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json, version_id)
return self._from_api_response(model_json, version_id)
@classmethod
def from_json(cls, json: str) -> CivitaiMetadata:
    """Build a `CivitaiMetadata` object from its JSON string representation.

    :param json: JSON text to parse and validate.
    :return: The validated `CivitaiMetadata` instance.
    """
    return CivitaiMetadata.model_validate_json(json)
def _fix_timezone(date: str) -> str:
return re.sub(r"Z$", "+00:00", date)

View File

@ -19,11 +19,12 @@ from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.hf_api import RepoSibling
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from invokeai.backend.model_manager.config import ModelRepoVariant
from ..metadata_base import (
AnyModelRepoMetadata,
@ -75,20 +76,8 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
_, name = id.split("/")
return HuggingFaceMetadata(
id=model_info.id,
author=model_info.author,
name=name,
last_modified=model_info.last_modified,
tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
tags=model_info.tags,
files=[
RemoteModelFile(
url=hf_hub_url(id, x.rfilename, revision=variant),
path=Path(name, x.rfilename),
size=x.size,
sha256=x.lfs.get("sha256") if x.lfs else None,
)
for x in model_info.siblings
],
files=parse_siblings(id, model_info.siblings, variant),
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
@ -102,3 +91,27 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
return self.from_id(repo_id)
else:
raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")
def parse_siblings(
    repo_id: str, siblings: Optional[list[RepoSibling]] = None, variant: Optional[ModelRepoVariant] = None
) -> list[RemoteModelFile]:
    """Convert the siblings list from the HuggingFace API into `RemoteModelFile` objects.

    :param repo_id: Repo id passed to `hf_hub_url` when building each file's URL.
    :param siblings: Sibling entries to convert; `None` or an empty list yields `[]`.
    :param variant: When given, its `.value` is used as the `revision` of each URL.
    :return: One `RemoteModelFile` per sibling, in input order.
    :raises AssertionError: If a sibling has no `rfilename` or no `size`.
    """
    parsed: list[RemoteModelFile] = []
    for sibling in siblings or []:
        # A file record cannot be built without a filename and a size.
        assert sibling.rfilename is not None
        assert sibling.size is not None

        revision = variant.value if variant else None
        url = hf_hub_url(repo_id, sibling.rfilename, revision=revision)
        relative_path = Path(sibling.rfilename)
        file_size = sibling.size
        # The hash is read from the `lfs` mapping when one is present; otherwise it stays None.
        checksum = sibling.lfs.get("sha256") if sibling.lfs else None

        parsed.append(RemoteModelFile(url=url, path=relative_path, size=file_size, sha256=checksum))
    return parsed

View File

@ -14,13 +14,11 @@ versions of these fields are intended to be kept in sync with the
remote repo.
"""
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
from typing import List, Literal, Optional, Union
from huggingface_hub import configure_http_backend, hf_hub_url
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import BaseModel, Field, JsonValue, TypeAdapter
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from typing_extensions import Annotated
@ -35,31 +33,6 @@ class UnknownMetadataException(Exception):
"""Raised when no metadata is available for a model."""
class CommercialUsage(str, Enum):
    """String-valued enumeration of the kinds of commercial usage a license may allow."""

    # Note: this member's value is the literal string "None", not Python's None.
    No = "None"

    # For the remaining members, the name and the string value are identical.
    Image = "Image"
    Rent = "Rent"
    RentCivit = "RentCivit"
    Sell = "Sell"
class LicenseRestrictions(BaseModel):
    """Broad categories of licensing restrictions.

    Every boolean flag defaults to False and `AllowCommercialUse` defaults to None,
    so an instance created without arguments grants no permissions.
    """

    # Field names are kept exactly as they were so existing keyword arguments and
    # serialized data continue to validate.
    AllowNoCredit: bool = Field(
        description="if true, model can be redistributed without crediting author", default=False
    )
    AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
    AllowDifferentLicense: bool = Field(
        description="if true, derivatives of this model can be redistributed under a different license", default=False
    )
    # Accepts either a set of usage types or a single usage type.
    AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
        description="Type(s) of commercial use allowed, or None if no commercial use is allowed.", default=None
    )
class RemoteModelFile(BaseModel):
"""Information about a downloadable file that forms part of a model."""
@ -82,11 +55,6 @@ class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="default settings for this model", default=None
)
class BaseMetadata(ModelMetadataBase):
@ -124,60 +92,16 @@ class CivitaiMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai"
id: int = Field(description="Civitai version identifier")
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was created")
updated: datetime = Field(description="date the model was last modified")
published: datetime = Field(description="date the model was published to Civitai")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
weight_minmax: Tuple[float, float] = Field(
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
) # note: For future use
@property
def credit_required(self) -> bool:
    """Return True if credit must be given, i.e. `restrictions.AllowNoCredit` is falsy."""
    may_omit_credit = self.restrictions.AllowNoCredit
    return not may_omit_credit
@property
def allow_commercial_use(self) -> bool:
    """Return True if commercial use is allowed.

    An unset (`None`) restriction counts as "not allowed". Otherwise commercial use is
    allowed unless `CommercialUsage.No` is among the recorded usage types.
    """
    permitted = self.restrictions.AllowCommercialUse
    if permitted is None:
        return False
    # The field may hold a single value or a set of values; normalize to a set.
    if not isinstance(permitted, set):
        permitted = {permitted}
    return CommercialUsage.No not in permitted
@property
def allow_derivatives(self) -> bool:
    """Return the `AllowDerivatives` flag: True if derivatives of this model can be redistributed."""
    license_terms = self.restrictions
    return license_terms.AllowDerivatives
@property
def allow_different_license(self) -> bool:
    """Return the `AllowDifferentLicense` flag: True if derivatives of this model can use a different license."""
    license_terms = self.restrictions
    return license_terms.AllowDifferentLicense
trigger_words: set[str] = Field(description="Trigger words extracted from the API response")
api_response: Optional[JsonValue] = Field(description="Response from the Civitai API", default=None)
class HuggingFaceMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="huggingface model id")
tag_dict: Dict[str, Any]
last_modified: datetime = Field(description="date of last commit to repo")
id: str = Field(description="The HF model id")
api_response: Optional[JsonValue] = Field(description="Response from the HF API", default=None)
def download_urls(
self,
@ -206,7 +130,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
# the next step reads model_index.json to determine which subdirectories belong
# to the model
if Path(f"{prefix}model_index.json") in paths:
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
resp = session.get(url)
resp.raise_for_status()
submodels = resp.json()