mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove civit AI model install resources
This commit is contained in:
parent
5118160282
commit
d1f859a446
@ -583,31 +583,6 @@ there is special-case code in the installer that looks for HuggingFace
|
|||||||
and Civitai URLs and fetches the corresponding model metadata from
|
and Civitai URLs and fetches the corresponding model metadata from
|
||||||
the corresponding repo.
|
the corresponding repo.
|
||||||
|
|
||||||
#### CivitaiModelSource
|
|
||||||
|
|
||||||
This is used for a model that is hosted by the Civitai web site.
|
|
||||||
|
|
||||||
| **Argument** | **Type** | **Default** | **Description** |
|
|
||||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
|
||||||
| `version_id` | int | None | The ID of the particular version of the desired model. |
|
|
||||||
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
|
|
||||||
|
|
||||||
Civitai has two model IDs, both of which are integers. The `model_id`
|
|
||||||
corresponds to a collection of model versions that may different in
|
|
||||||
arbitrary ways, such as derivation from different checkpoint training
|
|
||||||
steps, SFW vs NSFW generation, pruned vs non-pruned, etc. The
|
|
||||||
`version_id` points to a specific version. Please use the latter.
|
|
||||||
|
|
||||||
Some Civitai models require an access token to download. These can be
|
|
||||||
generated from the Civitai profile page of a logged-in
|
|
||||||
account. Somewhat annoyingly, if you fail to provide the access token
|
|
||||||
when downloading a model that needs it, Civitai generates a redirect
|
|
||||||
to a login page rather than a 403 Forbidden error. The installer
|
|
||||||
attempts to catch this event and issue an informative error
|
|
||||||
message. Otherwise you will get an "unrecognized model suffix" error
|
|
||||||
when the model prober tries to identify the type of the HTML login
|
|
||||||
page.
|
|
||||||
|
|
||||||
#### HFModelSource
|
#### HFModelSource
|
||||||
|
|
||||||
HuggingFace has the most complicated `ModelSource` structure:
|
HuggingFace has the most complicated `ModelSource` structure:
|
||||||
|
@ -287,9 +287,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
||||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
||||||
|
|
||||||
# MODEL IMPORT
|
|
||||||
civitai_api_key : Optional[str] = Field(default=os.environ.get("CIVITAI_API_KEY"), description="API key for CivitAI", json_schema_extra=Categories.Other)
|
|
||||||
|
|
||||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""Initialization file for model install service package."""
|
"""Initialization file for model install service package."""
|
||||||
|
|
||||||
from .model_install_base import (
|
from .model_install_base import (
|
||||||
CivitaiModelSource,
|
|
||||||
HFModelSource,
|
HFModelSource,
|
||||||
InstallStatus,
|
InstallStatus,
|
||||||
LocalModelSource,
|
LocalModelSource,
|
||||||
@ -23,5 +22,4 @@ __all__ = [
|
|||||||
"LocalModelSource",
|
"LocalModelSource",
|
||||||
"HFModelSource",
|
"HFModelSource",
|
||||||
"URLModelSource",
|
"URLModelSource",
|
||||||
"CivitaiModelSource",
|
|
||||||
]
|
]
|
||||||
|
@ -91,21 +91,6 @@ class LocalModelSource(StringLikeSource):
|
|||||||
return Path(self.path).as_posix()
|
return Path(self.path).as_posix()
|
||||||
|
|
||||||
|
|
||||||
class CivitaiModelSource(StringLikeSource):
|
|
||||||
"""A Civitai version id, with optional variant and access token."""
|
|
||||||
|
|
||||||
version_id: int
|
|
||||||
variant: Optional[ModelRepoVariant] = None
|
|
||||||
access_token: Optional[str] = None
|
|
||||||
type: Literal["civitai"] = "civitai"
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
"""Return string version of repoid when string rep needed."""
|
|
||||||
base: str = str(self.version_id)
|
|
||||||
base += f" ({self.variant})" if self.variant else ""
|
|
||||||
return base
|
|
||||||
|
|
||||||
|
|
||||||
class HFModelSource(StringLikeSource):
|
class HFModelSource(StringLikeSource):
|
||||||
"""
|
"""
|
||||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||||
@ -147,13 +132,12 @@ class URLModelSource(StringLikeSource):
|
|||||||
|
|
||||||
|
|
||||||
ModelSource = Annotated[
|
ModelSource = Annotated[
|
||||||
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||||
URLModelSource: ModelSourceType.Url,
|
URLModelSource: ModelSourceType.Url,
|
||||||
HFModelSource: ModelSourceType.HFRepoID,
|
HFModelSource: ModelSourceType.HFRepoID,
|
||||||
CivitaiModelSource: ModelSourceType.CivitAI,
|
|
||||||
LocalModelSource: ModelSourceType.Path,
|
LocalModelSource: ModelSourceType.Path,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,12 +33,11 @@ from invokeai.backend.model_manager.config import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import (
|
from invokeai.backend.model_manager.metadata import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
CivitaiMetadataFetch,
|
|
||||||
HuggingFaceMetadataFetch,
|
HuggingFaceMetadataFetch,
|
||||||
ModelMetadataWithFiles,
|
ModelMetadataWithFiles,
|
||||||
RemoteModelFile,
|
RemoteModelFile,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
||||||
from invokeai.backend.model_manager.probe import ModelProbe
|
from invokeai.backend.model_manager.probe import ModelProbe
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||||
@ -46,7 +45,6 @@ from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
|||||||
|
|
||||||
from .model_install_base import (
|
from .model_install_base import (
|
||||||
MODEL_SOURCE_TO_TYPE_MAP,
|
MODEL_SOURCE_TO_TYPE_MAP,
|
||||||
CivitaiModelSource,
|
|
||||||
HFModelSource,
|
HFModelSource,
|
||||||
InstallStatus,
|
InstallStatus,
|
||||||
LocalModelSource,
|
LocalModelSource,
|
||||||
@ -216,8 +214,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
if isinstance(source, LocalModelSource):
|
if isinstance(source, LocalModelSource):
|
||||||
install_job = self._import_local_model(source, config)
|
install_job = self._import_local_model(source, config)
|
||||||
self._install_queue.put(install_job) # synchronously install
|
self._install_queue.put(install_job) # synchronously install
|
||||||
elif isinstance(source, CivitaiModelSource):
|
|
||||||
install_job = self._import_from_civitai(source, config)
|
|
||||||
elif isinstance(source, HFModelSource):
|
elif isinstance(source, HFModelSource):
|
||||||
install_job = self._import_from_hf(source, config)
|
install_job = self._import_from_hf(source, config)
|
||||||
elif isinstance(source, URLModelSource):
|
elif isinstance(source, URLModelSource):
|
||||||
@ -381,10 +377,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job.config_in["source"] = str(job.source)
|
job.config_in["source"] = str(job.source)
|
||||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||||
# enter the metadata, if there is any
|
# enter the metadata, if there is any
|
||||||
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
|
||||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||||
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
|
|
||||||
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
|
|
||||||
|
|
||||||
if job.inplace:
|
if job.inplace:
|
||||||
key = self.register_path(job.local_path, job.config_in)
|
key = self.register_path(job.local_path, job.config_in)
|
||||||
@ -573,16 +567,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
inplace=source.inplace or False,
|
inplace=source.inplace or False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
|
||||||
if not source.access_token:
|
|
||||||
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
|
||||||
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
|
|
||||||
str(source.version_id)
|
|
||||||
)
|
|
||||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
|
||||||
remote_files = metadata.download_urls(session=self._session)
|
|
||||||
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
|
||||||
|
|
||||||
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
# Add user's cached access token to HuggingFace requests
|
# Add user's cached access token to HuggingFace requests
|
||||||
source.access_token = source.access_token or HfFolder.get_token()
|
source.access_token = source.access_token or HfFolder.get_token()
|
||||||
@ -613,8 +597,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
kwargs: dict[str, Any] = {"session": self._session}
|
kwargs: dict[str, Any] = {"session": self._session}
|
||||||
if fetcher is CivitaiMetadataFetch:
|
|
||||||
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
|
|
||||||
if fetcher is not None:
|
if fetcher is not None:
|
||||||
metadata = fetcher(**kwargs).from_url(source.url)
|
metadata = fetcher(**kwargs).from_url(source.url)
|
||||||
self._logger.debug(f"metadata={metadata}")
|
self._logger.debug(f"metadata={metadata}")
|
||||||
@ -631,7 +613,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def _import_remote_model(
|
def _import_remote_model(
|
||||||
self,
|
self,
|
||||||
source: HFModelSource | CivitaiModelSource | URLModelSource,
|
source: HFModelSource | URLModelSource,
|
||||||
remote_files: List[RemoteModelFile],
|
remote_files: List[RemoteModelFile],
|
||||||
metadata: Optional[AnyModelRepoMetadata],
|
metadata: Optional[AnyModelRepoMetadata],
|
||||||
config: Optional[Dict[str, Any]],
|
config: Optional[Dict[str, Any]],
|
||||||
@ -849,8 +831,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_fetcher_from_url(url: str):
|
def get_fetcher_from_url(url: str):
|
||||||
if re.match(r"^https?://civitai.com/", url.lower()):
|
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||||
return CivitaiMetadataFetch
|
|
||||||
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
|
||||||
return HuggingFaceMetadataFetch
|
return HuggingFaceMetadataFetch
|
||||||
raise ValueError(f"Unsupported model source: '{url}'")
|
raise ValueError(f"Unsupported model source: '{url}'")
|
||||||
|
@ -129,7 +129,6 @@ class ModelSourceType(str, Enum):
|
|||||||
Path = "path"
|
Path = "path"
|
||||||
Url = "url"
|
Url = "url"
|
||||||
HFRepoID = "hf_repo_id"
|
HFRepoID = "hf_repo_id"
|
||||||
CivitAI = "civitai"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDefaultSettings(BaseModel):
|
class ModelDefaultSettings(BaseModel):
|
||||||
|
@ -8,7 +8,6 @@ from invokeai.backend.model_manager.metadata import(
|
|||||||
CommercialUsage,
|
CommercialUsage,
|
||||||
LicenseRestrictions,
|
LicenseRestrictions,
|
||||||
HuggingFaceMetadata,
|
HuggingFaceMetadata,
|
||||||
CivitaiMetadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||||
@ -19,12 +18,11 @@ if data.allow_commercial_use:
|
|||||||
print("Commercial use of this model is allowed")
|
print("Commercial use of this model is allowed")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
|
from .fetch import HuggingFaceMetadataFetch, ModelMetadataFetchBase
|
||||||
from .metadata_base import (
|
from .metadata_base import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
AnyModelRepoMetadataValidator,
|
AnyModelRepoMetadataValidator,
|
||||||
BaseMetadata,
|
BaseMetadata,
|
||||||
CivitaiMetadata,
|
|
||||||
HuggingFaceMetadata,
|
HuggingFaceMetadata,
|
||||||
ModelMetadataWithFiles,
|
ModelMetadataWithFiles,
|
||||||
RemoteModelFile,
|
RemoteModelFile,
|
||||||
@ -34,8 +32,6 @@ from .metadata_base import (
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AnyModelRepoMetadata",
|
"AnyModelRepoMetadata",
|
||||||
"AnyModelRepoMetadataValidator",
|
"AnyModelRepoMetadataValidator",
|
||||||
"CivitaiMetadata",
|
|
||||||
"CivitaiMetadataFetch",
|
|
||||||
"HuggingFaceMetadata",
|
"HuggingFaceMetadata",
|
||||||
"HuggingFaceMetadataFetch",
|
"HuggingFaceMetadataFetch",
|
||||||
"ModelMetadataFetchBase",
|
"ModelMetadataFetchBase",
|
||||||
|
@ -14,8 +14,7 @@ if data.allow_commercial_use:
|
|||||||
print("Commercial use of this model is allowed")
|
print("Commercial use of this model is allowed")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .civitai import CivitaiMetadataFetch
|
|
||||||
from .fetch_base import ModelMetadataFetchBase
|
from .fetch_base import ModelMetadataFetchBase
|
||||||
from .huggingface import HuggingFaceMetadataFetch
|
from .huggingface import HuggingFaceMetadataFetch
|
||||||
|
|
||||||
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]
|
__all__ = ["ModelMetadataFetchBase", "HuggingFaceMetadataFetch"]
|
||||||
|
@ -1,188 +0,0 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
||||||
|
|
||||||
"""
|
|
||||||
This module fetches model metadata objects from the Civitai model repository.
|
|
||||||
In addition to the `from_url()` and `from_id()` methods inherited from the
|
|
||||||
`ModelMetadataFetchBase` base class.
|
|
||||||
|
|
||||||
Civitai has two separate ID spaces: a model ID and a version ID. The
|
|
||||||
version ID corresponds to a specific model, and is the ID accepted by
|
|
||||||
`from_id()`. The model ID corresponds to a family of related models,
|
|
||||||
such as different training checkpoints or 16 vs 32-bit versions. The
|
|
||||||
`from_civitai_modelid()` method will accept a model ID and return the
|
|
||||||
metadata from the default version within this model set. The default
|
|
||||||
version is the same as what the user sees when they click on a model's
|
|
||||||
thumbnail.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
|
||||||
|
|
||||||
fetcher = CivitaiMetadataFetch()
|
|
||||||
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
|
||||||
print(metadata.trained_words)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from pydantic import TypeAdapter, ValidationError
|
|
||||||
from pydantic.networks import AnyHttpUrl
|
|
||||||
from requests.sessions import Session
|
|
||||||
|
|
||||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
|
||||||
|
|
||||||
from ..metadata_base import (
|
|
||||||
AnyModelRepoMetadata,
|
|
||||||
CivitaiMetadata,
|
|
||||||
RemoteModelFile,
|
|
||||||
UnknownMetadataException,
|
|
||||||
)
|
|
||||||
from .fetch_base import ModelMetadataFetchBase
|
|
||||||
|
|
||||||
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
|
|
||||||
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
|
|
||||||
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
|
|
||||||
|
|
||||||
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
|
||||||
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
|
||||||
|
|
||||||
|
|
||||||
StringSetAdapter = TypeAdapter(set[str])
|
|
||||||
|
|
||||||
|
|
||||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|
||||||
"""Fetch model metadata from Civitai."""
|
|
||||||
|
|
||||||
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
|
|
||||||
"""
|
|
||||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
|
||||||
|
|
||||||
By providing a configurable Session object, we can support unit tests on
|
|
||||||
this module without an internet connection.
|
|
||||||
"""
|
|
||||||
self._requests = session or requests.Session()
|
|
||||||
self._api_key = api_key
|
|
||||||
|
|
||||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
|
||||||
"""
|
|
||||||
Given a URL to a CivitAI model or version page, return a ModelMetadata object.
|
|
||||||
|
|
||||||
In the event that the URL points to a model page without the particular version
|
|
||||||
indicated, the default model version is returned. Otherwise, the requested version
|
|
||||||
is returned.
|
|
||||||
"""
|
|
||||||
if match := re.match(CIVITAI_VERSION_PAGE_RE, str(url), re.IGNORECASE):
|
|
||||||
model_id = match.group(1)
|
|
||||||
version_id = match.group(2)
|
|
||||||
return self.from_civitai_versionid(int(version_id), int(model_id))
|
|
||||||
elif match := re.match(CIVITAI_MODEL_PAGE_RE, str(url), re.IGNORECASE):
|
|
||||||
model_id = match.group(1)
|
|
||||||
return self.from_civitai_modelid(int(model_id))
|
|
||||||
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url), re.IGNORECASE):
|
|
||||||
version_id = match.group(1)
|
|
||||||
return self.from_civitai_versionid(int(version_id))
|
|
||||||
raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")
|
|
||||||
|
|
||||||
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
|
|
||||||
"""
|
|
||||||
Given a Civitai model version ID, return a ModelRepoMetadata object.
|
|
||||||
|
|
||||||
:param id: An ID.
|
|
||||||
:param variant: A model variant from the ModelRepoVariant enum (currently ignored)
|
|
||||||
|
|
||||||
May raise an `UnknownMetadataException`.
|
|
||||||
"""
|
|
||||||
return self.from_civitai_versionid(int(id))
|
|
||||||
|
|
||||||
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
|
|
||||||
"""
|
|
||||||
Return metadata from the default version of the indicated model.
|
|
||||||
|
|
||||||
May raise an `UnknownMetadataException`.
|
|
||||||
"""
|
|
||||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
|
||||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
|
||||||
return self._from_api_response(model_json)
|
|
||||||
|
|
||||||
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
|
||||||
try:
|
|
||||||
version_id = version_id or api_response["modelVersions"][0]["id"]
|
|
||||||
except TypeError as excp:
|
|
||||||
raise UnknownMetadataException from excp
|
|
||||||
|
|
||||||
# loop till we find the section containing the version requested
|
|
||||||
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
|
|
||||||
if not version_sections:
|
|
||||||
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
|
|
||||||
|
|
||||||
version_json = version_sections[0]
|
|
||||||
|
|
||||||
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
|
|
||||||
primary = [x for x in version_json["files"] if x.get("primary")]
|
|
||||||
assert len(primary) == 1
|
|
||||||
primary_file = primary[0]
|
|
||||||
|
|
||||||
url = primary_file["downloadUrl"]
|
|
||||||
if "?" not in url: # work around apparent bug in civitai api
|
|
||||||
metadata_string = ""
|
|
||||||
for key, value in primary_file["metadata"].items():
|
|
||||||
if not value:
|
|
||||||
continue
|
|
||||||
metadata_string += f"&{key}={value}"
|
|
||||||
url = url + f"?type={primary_file['type']}{metadata_string}"
|
|
||||||
model_files = [
|
|
||||||
RemoteModelFile(
|
|
||||||
url=self._get_url_with_api_key(url),
|
|
||||||
path=Path(primary_file["name"]),
|
|
||||||
size=int(primary_file["sizeKB"] * 1024),
|
|
||||||
sha256=primary_file["hashes"]["SHA256"],
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
|
|
||||||
except ValidationError:
|
|
||||||
trigger_phrases: set[str] = set()
|
|
||||||
|
|
||||||
return CivitaiMetadata(
|
|
||||||
name=version_json["name"],
|
|
||||||
files=model_files,
|
|
||||||
trigger_phrases=trigger_phrases,
|
|
||||||
api_response=json.dumps(version_json),
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
|
||||||
"""
|
|
||||||
Return a CivitaiMetadata object given a model version id.
|
|
||||||
|
|
||||||
May raise an `UnknownMetadataException`.
|
|
||||||
"""
|
|
||||||
if model_id is None:
|
|
||||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
|
||||||
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
|
|
||||||
if error := version.get("error"):
|
|
||||||
raise UnknownMetadataException(error)
|
|
||||||
model_id = version["modelId"]
|
|
||||||
|
|
||||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
|
||||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
|
||||||
return self._from_api_response(model_json, version_id)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_json(cls, json: str) -> CivitaiMetadata:
|
|
||||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
|
||||||
metadata = CivitaiMetadata.model_validate_json(json)
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
def _get_url_with_api_key(self, url: str) -> str:
|
|
||||||
if not self._api_key:
|
|
||||||
return url
|
|
||||||
|
|
||||||
if "?" in url:
|
|
||||||
return f"{url}&token={self._api_key}"
|
|
||||||
|
|
||||||
return f"{url}?token={self._api_key}"
|
|
@ -78,14 +78,6 @@ class ModelMetadataWithFiles(ModelMetadataBase):
|
|||||||
return self.files
|
return self.files
|
||||||
|
|
||||||
|
|
||||||
class CivitaiMetadata(ModelMetadataWithFiles):
|
|
||||||
"""Extended metadata fields provided by Civitai."""
|
|
||||||
|
|
||||||
type: Literal["civitai"] = "civitai"
|
|
||||||
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
|
|
||||||
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||||
"""Extended metadata fields provided by HuggingFace."""
|
"""Extended metadata fields provided by HuggingFace."""
|
||||||
|
|
||||||
@ -130,5 +122,5 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
|||||||
return [x for x in self.files if x.path in paths]
|
return [x for x in self.files if x.path in paths]
|
||||||
|
|
||||||
|
|
||||||
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
|
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata], Field(discriminator="type")]
|
||||||
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
|
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user