mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add basic functionality for model metadata fetching from hf and civitai
This commit is contained in:
0
invokeai/backend/model_manager/metadata/__init__.py
Normal file
0
invokeai/backend/model_manager/metadata/__init__.py
Normal file
92
invokeai/backend/model_manager/metadata/base.py
Normal file
92
invokeai/backend/model_manager/metadata/base.py
Normal file
@ -0,0 +1,92 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module defines core text-to-image model metadata fields.
|
||||
|
||||
Metadata comprises any descriptive information that is not essential
|
||||
for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in invokeai.backend.model_manager.config.
|
||||
|
||||
Note that the "name" and "description" are also present in `config`.
|
||||
This may need reworking.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
|
||||
No = "None"
|
||||
Image = "Image"
|
||||
Rent = "Rent"
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
class LicenseRestrictions(BaseModel):
|
||||
"""Broad categories of licensing restrictions."""
|
||||
|
||||
AllowNoCredit: bool = Field(description="if true, model can be redistributed without crediting author", default=False)
|
||||
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
|
||||
AllowDifferentLicense: bool = Field(description="if true, derivatives of this model be redistributed under a different license", default=False)
|
||||
AllowCommercialUse: CommercialUsage = Field(description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set)
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
|
||||
class HuggingFaceMetadata(ModelMetadataBase):
|
||||
"""Extended metadata fields provided by HuggingFace."""
|
||||
|
||||
id: str = Field(description="huggingface model id")
|
||||
tag_dict: Dict[str, Any]
|
||||
last_modified: datetime = Field(description="date of last commit to repo")
|
||||
files: List[Path] = Field(description="sibling files that belong to this model", default_factory=list)
|
||||
|
||||
class CivitaiMetadata(ModelMetadataBase):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
id: int = Field(description="Civitai model identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was posted to CivitAI")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(description="text description of the model's reversion; usually change history; may contain HTML")
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_min: float = Field(description="minimum suggested value for a LoRA or other secondary model", default=-1.0) # note: For future use; not currently easily
|
||||
weight_max: float = Field(description="maximum suggested value for a LoRA or other secondary model", default=+2.0) # recoverable from metadata
|
||||
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
"""Return True if you must give credit for derivatives of this model and images generated from it."""
|
||||
return not self.restrictions.AllowNoCredit
|
||||
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
return self.restrictions.AllowCommercialUse == CommercialUsage('None')
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
"""Return True if derivatives of this model can be redistributed."""
|
||||
return self.restrictions.AllowDerivatives
|
||||
|
||||
@property
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
139
invokeai/backend/model_manager/metadata/fetch.py
Normal file
139
invokeai/backend/model_manager/metadata/fetch.py
Normal file
@ -0,0 +1,139 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module fetches metadata objects from HuggingFace and Civitai.
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import MetadataFetch
|
||||
|
||||
metadata = MetadataFetch.from_civitai_url("https://civitai.com/models/58390/detail-tweaker-lora-lora")
|
||||
print(metadata.description)
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi, configure_http_backend
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
|
||||
from .base import (
|
||||
CivitaiMetadata,
|
||||
HuggingFaceMetadata,
|
||||
LicenseRestrictions,
|
||||
CommercialUsage,
|
||||
)
|
||||
|
||||
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
|
||||
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
|
||||
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
|
||||
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
|
||||
|
||||
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
|
||||
class MetadataFetch:
|
||||
"""Fetch metadata from HuggingFace and Civitai URLs."""
|
||||
|
||||
_requests: Session
|
||||
|
||||
def __init__(self, session: Optional[Session]=None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
By providing a configurable Session object, we can support unit tests on
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
configure_http_backend(backend_factory = lambda: self._requests)
|
||||
|
||||
def from_huggingface_repoid(self, repo_id: str) -> HuggingFaceMetadata:
|
||||
"""Return a HuggingFaceMetadata object given the model's repo_id."""
|
||||
try:
|
||||
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
|
||||
except RepositoryNotFoundError as excp:
|
||||
raise UnknownModelException(f"'{repo_id}' not found. See trace for details.") from excp
|
||||
|
||||
_, name = repo_id.split("/")
|
||||
return HuggingFaceMetadata(
|
||||
id = model_info.modelId,
|
||||
author = model_info.author,
|
||||
name = name,
|
||||
last_modified = model_info.lastModified,
|
||||
tags = model_info.tags,
|
||||
tag_dict = model_info.card_data.to_dict(),
|
||||
files = [Path(x.rfilename) for x in model_info.siblings]
|
||||
)
|
||||
|
||||
def from_huggingface_url(self, url: AnyHttpUrl) -> HuggingFaceMetadata:
|
||||
"""
|
||||
Return a HuggingFaceMetadata object given the model's web page URL.
|
||||
|
||||
In the case of an invalid or missing URL, raises a ModelNotFound exception.
|
||||
"""
|
||||
if match := re.match(HF_MODEL_RE, str(url)):
|
||||
repo_id = match.group(1)
|
||||
return self.from_huggingface_repoid(repo_id)
|
||||
else:
|
||||
raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_metadata: Optional[Dict[str,Any]]=None) -> CivitaiMetadata:
|
||||
"""Return Civitai metadata using a model's version id."""
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
version = self._requests.get(version_url).json()
|
||||
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(version['modelId'])
|
||||
model = model_metadata or self._requests.get(model_url).json()
|
||||
safe_thumbnails = [x['url'] for x in version['images'] if x['nsfw']=='None']
|
||||
|
||||
return CivitaiMetadata(
|
||||
id=version['modelId'],
|
||||
name=version['model']['name'],
|
||||
version_id=version['id'],
|
||||
version_name=version['name'],
|
||||
created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version['createdAt'])),
|
||||
base_model_trained_on=version['baseModel'], # note - need a dictionary to turn into a BaseModelType
|
||||
download_url=version['downloadUrl'],
|
||||
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
|
||||
author=model['creator']['username'],
|
||||
description=model['description'],
|
||||
version_description=version['description'] or "",
|
||||
tags=model['tags'],
|
||||
trained_words=version['trainedWords'],
|
||||
nsfw=version['model']['nsfw'],
|
||||
restrictions=LicenseRestrictions(
|
||||
AllowNoCredit=model['allowNoCredit'],
|
||||
AllowCommercialUse=CommercialUsage(model['allowCommercialUse']),
|
||||
AllowDerivatives=model['allowDerivatives'],
|
||||
AllowDifferentLicense=model['allowDifferentLicense']
|
||||
),
|
||||
)
|
||||
|
||||
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
|
||||
"""Return metadata from the default version of the indicated model."""
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model = self._requests.get(model_url).json()
|
||||
default_version = model['modelVersions'][0]['id']
|
||||
return self.from_civitai_versionid(default_version, model)
|
||||
|
||||
def from_civitai_url(self, url: AnyHttpUrl) -> CivitaiMetadata:
|
||||
"""Parse a Civitai URL that user is likely to pass and return its metadata."""
|
||||
if match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)):
|
||||
model_id = match.group(1)
|
||||
return self.from_civitai_modelid(int(model_id))
|
||||
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
|
||||
version_id = match.group(1)
|
||||
return self.from_civitai_versionid(int(version_id))
|
||||
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
|
||||
version_id = match.group(1)
|
||||
return self.from_civitai_versionid(int(version_id))
|
||||
raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns")
|
||||
|
||||
|
Reference in New Issue
Block a user