From c61028315884011117747a9150e0e8d8f8ea56d1 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 17 Dec 2023 22:19:29 -0500
Subject: [PATCH] add basic functionality for model metadata fetching from hf and civitai

---
 .../model_manager/metadata/__init__.py      |   0
 .../backend/model_manager/metadata/base.py  |  92 ++++++++++++
 .../backend/model_manager/metadata/fetch.py | 139 ++++++++++++++++++
 3 files changed, 231 insertions(+)
 create mode 100644 invokeai/backend/model_manager/metadata/__init__.py
 create mode 100644 invokeai/backend/model_manager/metadata/base.py
 create mode 100644 invokeai/backend/model_manager/metadata/fetch.py

diff --git a/invokeai/backend/model_manager/metadata/__init__.py b/invokeai/backend/model_manager/metadata/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/invokeai/backend/model_manager/metadata/base.py b/invokeai/backend/model_manager/metadata/base.py
new file mode 100644
index 0000000000..e9ef8cd97f
--- /dev/null
+++ b/invokeai/backend/model_manager/metadata/base.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
+
+"""
+This module defines core text-to-image model metadata fields.
+
+Metadata comprises any descriptive information that is not essential
+for getting the model to run. For example "author" is metadata, while
+"type", "base" and "format" are not. The latter fields are part of the
+model's config, as defined in invokeai.backend.model_manager.config.
+
+Note that the "name" and "description" are also present in `config`.
+This may need reworking.
+"""
+
+from datetime import datetime
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Set, Optional
+
+from pydantic import BaseModel, Field
+from pydantic.networks import AnyHttpUrl
+
+class CommercialUsage(str, Enum):
+    """Type of commercial usage allowed."""
+
+    No = "None"
+    Image = "Image"
+    Rent = "Rent"
+    RentCivit = "RentCivit"
+    Sell = "Sell"
+
+class LicenseRestrictions(BaseModel):
+    """Broad categories of licensing restrictions."""
+
+    AllowNoCredit: bool = Field(description="if true, model can be redistributed without crediting author", default=False)
+    AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
+    AllowDifferentLicense: bool = Field(description="if true, derivatives of this model be redistributed under a different license", default=False)
+    AllowCommercialUse: CommercialUsage = Field(description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=CommercialUsage.No)
+
+class ModelMetadataBase(BaseModel):
+    """Base class for model metadata information."""
+
+    name: str = Field(description="model's name")
+    author: str = Field(description="model's author")
+    tags: Set[str] = Field(description="tags provided by model source")
+
+class HuggingFaceMetadata(ModelMetadataBase):
+    """Extended metadata fields provided by HuggingFace."""
+
+    id: str = Field(description="huggingface model id")
+    tag_dict: Dict[str, Any]
+    last_modified: datetime = Field(description="date of last commit to repo")
+    files: List[Path] = Field(description="sibling files that belong to this model", default_factory=list)
+
+class CivitaiMetadata(ModelMetadataBase):
+    """Extended metadata fields provided by Civitai."""
+
+    id: int = Field(description="Civitai model identifier")
+    version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
+    version_id: int = Field(description="Civitai model version identifier")
+    created: datetime = Field(description="date the model was posted to CivitAI")
+    description: str = Field(description="text description of model; may contain HTML")
+    version_description: str = Field(description="text description of the model's reversion; usually change history; may contain HTML")
+    nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
+    restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
+    trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
+    download_url: AnyHttpUrl = Field(description="download URL for this model")
+    base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
+    thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
+    weight_min: float = Field(description="minimum suggested value for a LoRA or other secondary model", default=-1.0)  # note: For future use; not currently easily
+    weight_max: float = Field(description="maximum suggested value for a LoRA or other secondary model", default=+2.0)  # recoverable from metadata
+
+
+    @property
+    def credit_required(self) -> bool:
+        """Return True if you must give credit for derivatives of this model and images generated from it."""
+        return not self.restrictions.AllowNoCredit
+
+    @property
+    def allow_commercial_use(self) -> bool:
+        """Return True if any form of commercial use is allowed (i.e. the usage type is not CommercialUsage.No)."""
+        return self.restrictions.AllowCommercialUse != CommercialUsage.No
+
+    @property
+    def allow_derivatives(self) -> bool:
+        """Return True if derivatives of this model can be redistributed."""
+        return self.restrictions.AllowDerivatives
+
+    @property
+    def allow_different_license(self) -> bool:
+        """Return true if derivatives of this model can use a different license."""
+        return self.restrictions.AllowDifferentLicense
diff --git a/invokeai/backend/model_manager/metadata/fetch.py b/invokeai/backend/model_manager/metadata/fetch.py
new file mode 100644
index 0000000000..346ecafa1e
--- /dev/null
+++ b/invokeai/backend/model_manager/metadata/fetch.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
+
+"""
+This module fetches metadata objects from HuggingFace and Civitai.
+
+Usage:
+
+from invokeai.backend.model_manager.metadata.fetch import MetadataFetch
+
+metadata = MetadataFetch().from_civitai_url("https://civitai.com/models/58390/detail-tweaker-lora-lora")
+print(metadata.description)
+"""
+
+import re
+from pathlib import Path
+from typing import Any, Dict, Optional
+from datetime import datetime
+
+import requests
+from huggingface_hub import HfApi, configure_http_backend
+from huggingface_hub.utils._errors import RepositoryNotFoundError
+from pydantic.networks import AnyHttpUrl
+from requests.sessions import Session
+
+from invokeai.app.services.model_records import UnknownModelException
+
+from .base import (
+    CivitaiMetadata,
+    HuggingFaceMetadata,
+    LicenseRestrictions,
+    CommercialUsage,
+)
+
+HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
+CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
+CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)(?:/[^?]*)?\?modelVersionId=(\d+)"
+CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
+
+CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
+CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
+
+class MetadataFetch:
+    """Fetch metadata from HuggingFace and Civitai URLs."""
+
+    _requests: Session
+
+    def __init__(self, session: Optional[Session]=None):
+        """
+        Initialize the fetcher with an optional requests.sessions.Session object.
+
+        By providing a configurable Session object, we can support unit tests on
+        this module without an internet connection.
+        """
+        self._requests = session or requests.Session()
+        configure_http_backend(backend_factory = lambda: self._requests)
+
+    def from_huggingface_repoid(self, repo_id: str) -> HuggingFaceMetadata:
+        """Return a HuggingFaceMetadata object given the model's repo_id; raises UnknownModelException if not found."""
+        try:
+            model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
+        except RepositoryNotFoundError as excp:
+            raise UnknownModelException(f"'{repo_id}' not found. See trace for details.") from excp
+
+        name = repo_id.split("/")[-1]  # tolerate repo ids that lack an "author/" prefix
+        return HuggingFaceMetadata(
+            id = model_info.modelId,
+            author = model_info.author,
+            name = name,
+            last_modified = model_info.lastModified,
+            tags = model_info.tags,
+            tag_dict = model_info.card_data.to_dict() if model_info.card_data else {},  # repo may have no model card
+            files = [Path(x.rfilename) for x in model_info.siblings]
+        )
+
+    def from_huggingface_url(self, url: AnyHttpUrl) -> HuggingFaceMetadata:
+        """
+        Return a HuggingFaceMetadata object given the model's web page URL.
+
+        In the case of an invalid or missing URL, raises an UnknownModelException.
+        """
+        if match := re.match(HF_MODEL_RE, str(url)):
+            repo_id = match.group(1)
+            return self.from_huggingface_repoid(repo_id)
+        else:
+            raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")
+
+    def from_civitai_versionid(self, version_id: int, model_metadata: Optional[Dict[str,Any]]=None) -> CivitaiMetadata:
+        """Return Civitai metadata using a model's version id."""
+        version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
+        version = self._requests.get(version_url).json()
+
+        model_url = CIVITAI_MODEL_ENDPOINT + str(version['modelId'])
+        model = model_metadata or self._requests.get(model_url).json()
+        safe_thumbnails = [x['url'] for x in version['images'] if x['nsfw']=='None']
+
+        return CivitaiMetadata(
+            id=version['modelId'],
+            name=version['model']['name'],
+            version_id=version['id'],
+            version_name=version['name'],
+            created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version['createdAt'])),
+            base_model_trained_on=version['baseModel'],  # note - need a dictionary to turn into a BaseModelType
+            download_url=version['downloadUrl'],
+            thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
+            author=model['creator']['username'],
+            description=model['description'],
+            version_description=version['description'] or "",
+            tags=model['tags'],
+            trained_words=version['trainedWords'],
+            nsfw=version['model']['nsfw'],
+            restrictions=LicenseRestrictions(
+                AllowNoCredit=model['allowNoCredit'],
+                AllowCommercialUse=CommercialUsage(model['allowCommercialUse']),
+                AllowDerivatives=model['allowDerivatives'],
+                AllowDifferentLicense=model['allowDifferentLicense']
+            ),
+        )
+
+    def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
+        """Return metadata from the default version of the indicated model."""
+        model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
+        model = self._requests.get(model_url).json()
+        default_version = model['modelVersions'][0]['id']
+        return self.from_civitai_versionid(default_version, model)
+
+    def from_civitai_url(self, url: AnyHttpUrl) -> CivitaiMetadata:
+        """Parse a Civitai model-page, version-page or download URL and return its metadata; raises UnknownModelException otherwise."""
+        if match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):  # must precede the model-page test, whose pattern it also matches
+            version_id = match.group(2)  # group(1) is the model id
+            return self.from_civitai_versionid(int(version_id))
+        elif match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)):
+            model_id = match.group(1)
+            return self.from_civitai_modelid(int(model_id))
+        elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
+            version_id = match.group(1)
+            return self.from_civitai_versionid(int(version_id))
+        raise UnknownModelException(f"The url '{url}' does not match any known Civitai URL patterns")
+
+