add basic functionality for model metadata fetching from hf and civitai

This commit is contained in:
Lincoln Stein
2023-12-17 22:19:29 -05:00
parent 77b74264a8
commit c610283158
3 changed files with 231 additions and 0 deletions

View File

@ -0,0 +1,92 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module defines core text-to-image model metadata fields.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in invokeai.backend.model_manager.config.
Note that the "name" and "description" are also present in `config`.
This may need reworking.
"""
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Set, Optional
from pydantic import BaseModel, Field
from pydantic.networks import AnyHttpUrl
class CommercialUsage(str, Enum):
"""Type of commercial usage allowed."""
No = "None"
Image = "Image"
Rent = "Rent"
RentCivit = "RentCivit"
Sell = "Sell"
class LicenseRestrictions(BaseModel):
"""Broad categories of licensing restrictions."""
AllowNoCredit: bool = Field(description="if true, model can be redistributed without crediting author", default=False)
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
AllowDifferentLicense: bool = Field(description="if true, derivatives of this model be redistributed under a different license", default=False)
AllowCommercialUse: CommercialUsage = Field(description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set)
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
class HuggingFaceMetadata(ModelMetadataBase):
"""Extended metadata fields provided by HuggingFace."""
id: str = Field(description="huggingface model id")
tag_dict: Dict[str, Any]
last_modified: datetime = Field(description="date of last commit to repo")
files: List[Path] = Field(description="sibling files that belong to this model", default_factory=list)
class CivitaiMetadata(ModelMetadataBase):
"""Extended metadata fields provided by Civitai."""
id: int = Field(description="Civitai model identifier")
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was posted to CivitAI")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(description="text description of the model's reversion; usually change history; may contain HTML")
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
weight_min: float = Field(description="minimum suggested value for a LoRA or other secondary model", default=-1.0) # note: For future use; not currently easily
weight_max: float = Field(description="maximum suggested value for a LoRA or other secondary model", default=+2.0) # recoverable from metadata
@property
def credit_required(self) -> bool:
"""Return True if you must give credit for derivatives of this model and images generated from it."""
return not self.restrictions.AllowNoCredit
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
return self.restrictions.AllowCommercialUse == CommercialUsage('None')
@property
def allow_derivatives(self) -> bool:
"""Return True if derivatives of this model can be redistributed."""
return self.restrictions.AllowDerivatives
@property
def allow_different_license(self) -> bool:
"""Return true if derivatives of this model can use a different license."""
return self.restrictions.AllowDifferentLicense

View File

@ -0,0 +1,139 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches metadata objects from HuggingFace and Civitai.
Usage:
from invokeai.backend.model_manager.metadata.fetch import MetadataFetch
metadata = MetadataFetch.from_civitai_url("https://civitai.com/models/58390/detail-tweaker-lora-lora")
print(metadata.description)
"""
import re
from pathlib import Path
from typing import Optional, Dict, Optional, Any
from datetime import datetime
import requests
from huggingface_hub import HfApi, configure_http_backend
from huggingface_hub.utils._errors import RepositoryNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.app.services.model_records import UnknownModelException
from .base import (
CivitaiMetadata,
HuggingFaceMetadata,
LicenseRestrictions,
CommercialUsage,
)
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
class MetadataFetch:
"""Fetch metadata from HuggingFace and Civitai URLs."""
_requests: Session
def __init__(self, session: Optional[Session]=None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
configure_http_backend(backend_factory = lambda: self._requests)
def from_huggingface_repoid(self, repo_id: str) -> HuggingFaceMetadata:
"""Return a HuggingFaceMetadata object given the model's repo_id."""
try:
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
except RepositoryNotFoundError as excp:
raise UnknownModelException(f"'{repo_id}' not found. See trace for details.") from excp
_, name = repo_id.split("/")
return HuggingFaceMetadata(
id = model_info.modelId,
author = model_info.author,
name = name,
last_modified = model_info.lastModified,
tags = model_info.tags,
tag_dict = model_info.card_data.to_dict(),
files = [Path(x.rfilename) for x in model_info.siblings]
)
def from_huggingface_url(self, url: AnyHttpUrl) -> HuggingFaceMetadata:
"""
Return a HuggingFaceMetadata object given the model's web page URL.
In the case of an invalid or missing URL, raises a ModelNotFound exception.
"""
if match := re.match(HF_MODEL_RE, str(url)):
repo_id = match.group(1)
return self.from_huggingface_repoid(repo_id)
else:
raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")
def from_civitai_versionid(self, version_id: int, model_metadata: Optional[Dict[str,Any]]=None) -> CivitaiMetadata:
"""Return Civitai metadata using a model's version id."""
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
model_url = CIVITAI_MODEL_ENDPOINT + str(version['modelId'])
model = model_metadata or self._requests.get(model_url).json()
safe_thumbnails = [x['url'] for x in version['images'] if x['nsfw']=='None']
return CivitaiMetadata(
id=version['modelId'],
name=version['model']['name'],
version_id=version['id'],
version_name=version['name'],
created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version['createdAt'])),
base_model_trained_on=version['baseModel'], # note - need a dictionary to turn into a BaseModelType
download_url=version['downloadUrl'],
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model['creator']['username'],
description=model['description'],
version_description=version['description'] or "",
tags=model['tags'],
trained_words=version['trainedWords'],
nsfw=version['model']['nsfw'],
restrictions=LicenseRestrictions(
AllowNoCredit=model['allowNoCredit'],
AllowCommercialUse=CommercialUsage(model['allowCommercialUse']),
AllowDerivatives=model['allowDerivatives'],
AllowDifferentLicense=model['allowDifferentLicense']
),
)
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
"""Return metadata from the default version of the indicated model."""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model = self._requests.get(model_url).json()
default_version = model['modelVersions'][0]['id']
return self.from_civitai_versionid(default_version, model)
def from_civitai_url(self, url: AnyHttpUrl) -> CivitaiMetadata:
"""Parse a Civitai URL that user is likely to pass and return its metadata."""
if match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)):
model_id = match.group(1)
return self.from_civitai_modelid(int(model_id))
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
version_id = match.group(1)
return self.from_civitai_versionid(int(version_id))
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
version_id = match.group(1)
return self.from_civitai_versionid(int(version_id))
raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns")