mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(mm): get metadata working
This commit is contained in:
parent
7cb0da1f66
commit
c3aa985c93
@ -3,8 +3,6 @@
|
|||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
from hashlib import sha1
|
|
||||||
from random import randbytes
|
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
@ -461,11 +459,8 @@ async def add_model_record(
|
|||||||
"""Add a model using the configuration information appropriate for its type."""
|
"""Add a model using the configuration information appropriate for its type."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||||
if config.key == "<NOKEY>":
|
|
||||||
config.key = sha1(randbytes(100)).hexdigest()
|
|
||||||
logger.info(f"Created model {config.key} for {config.name}")
|
|
||||||
try:
|
try:
|
||||||
record_store.add_model(config.key, config)
|
record_store.add_model(config)
|
||||||
except DuplicateModelException as e:
|
except DuplicateModelException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
@ -556,7 +556,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# make config relative to our root
|
# make config relative to our root
|
||||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
||||||
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||||
self.record_store.add_model(info.key, info)
|
self.record_store.add_model(info)
|
||||||
return info.key
|
return info.key
|
||||||
|
|
||||||
def _next_id(self) -> int:
|
def _next_id(self) -> int:
|
||||||
@ -583,7 +583,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
if not source.access_token:
|
if not source.access_token:
|
||||||
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
||||||
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
|
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
|
||||||
|
str(source.version_id)
|
||||||
|
)
|
||||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
remote_files = metadata.download_urls(session=self._session)
|
remote_files = metadata.download_urls(session=self._session)
|
||||||
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
||||||
@ -611,15 +613,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
# URLs from Civitai or HuggingFace will be handled specially
|
# URLs from Civitai or HuggingFace will be handled specially
|
||||||
url_patterns = {
|
|
||||||
r"^https?://civitai.com/": CivitaiMetadataFetch,
|
|
||||||
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
|
|
||||||
}
|
|
||||||
metadata = None
|
metadata = None
|
||||||
for pattern, fetcher in url_patterns.items():
|
fetcher = None
|
||||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
try:
|
||||||
metadata = fetcher(self._session).from_url(source.url)
|
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||||
break
|
except ValueError:
|
||||||
|
pass
|
||||||
|
kwargs: dict[str, Any] = {"session": self._session}
|
||||||
|
if fetcher is CivitaiMetadataFetch:
|
||||||
|
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
|
||||||
|
if fetcher is not None:
|
||||||
|
metadata = fetcher(**kwargs).from_url(source.url)
|
||||||
self._logger.debug(f"metadata={metadata}")
|
self._logger.debug(f"metadata={metadata}")
|
||||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||||
remote_files = metadata.download_urls(session=self._session)
|
remote_files = metadata.download_urls(session=self._session)
|
||||||
@ -849,3 +853,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._logger.info(f"{job.source}: model installation was cancelled")
|
self._logger.info(f"{job.source}: model installation was cancelled")
|
||||||
if self._event_bus:
|
if self._event_bus:
|
||||||
self._event_bus.emit_model_install_cancelled(str(job.source))
|
self._event_bus.emit_model_install_cancelled(str(job.source))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_fetcher_from_url(url: str):
|
||||||
|
if re.match(r"^https?://civitai.com/", url.lower()):
|
||||||
|
return CivitaiMetadataFetch
|
||||||
|
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||||
|
return HuggingFaceMetadataFetch
|
||||||
|
raise ValueError(f"Unsupported model source: '{url}'")
|
||||||
|
@ -64,7 +64,7 @@ class ModelRecordServiceBase(ABC):
|
|||||||
"""Abstract base class for storage and retrieval of model configs."""
|
"""Abstract base class for storage and retrieval of model configs."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model to the database.
|
Add a model to the database.
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
"""Return the underlying database."""
|
"""Return the underlying database."""
|
||||||
return self._db
|
return self._db
|
||||||
|
|
||||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model to the database.
|
Add a model to the database.
|
||||||
|
|
||||||
@ -95,8 +95,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
|
|
||||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||||
"""
|
"""
|
||||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
|
||||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -108,8 +106,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
VALUES (?,?);
|
VALUES (?,?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
key,
|
config.key,
|
||||||
json_serialized,
|
config.model_dump_json(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._db.conn.commit()
|
self._db.conn.commit()
|
||||||
@ -118,11 +116,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
if "UNIQUE constraint failed" in str(e):
|
if "UNIQUE constraint failed" in str(e):
|
||||||
if "models.path" in str(e):
|
if "models.path" in str(e):
|
||||||
msg = f"A model with path '{record.path}' is already installed"
|
msg = f"A model with path '{config.path}' is already installed"
|
||||||
elif "models.name" in str(e):
|
elif "models.name" in str(e):
|
||||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||||
else:
|
else:
|
||||||
msg = f"A model with key '{key}' is already installed"
|
msg = f"A model with key '{config.key}' is already installed"
|
||||||
raise DuplicateModelException(msg) from e
|
raise DuplicateModelException(msg) from e
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
@ -130,7 +128,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return self.get_model(key)
|
return self.get_model(config.key)
|
||||||
|
|
||||||
def del_model(self, key: str) -> None:
|
def del_model(self, key: str) -> None:
|
||||||
"""
|
"""
|
||||||
@ -263,14 +261,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
f"""--sql
|
||||||
select config, strftime('%s',updated_at) FROM models
|
SELECT config, strftime('%s',updated_at) FROM models
|
||||||
{where};
|
{where};
|
||||||
""",
|
""",
|
||||||
tuple(bindings),
|
tuple(bindings),
|
||||||
)
|
)
|
||||||
results = [
|
result = self._cursor.fetchall()
|
||||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
|
||||||
]
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||||
@ -347,6 +344,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||||
) -> PaginatedResults[ModelSummary]:
|
) -> PaginatedResults[ModelSummary]:
|
||||||
"""Return a paginated summary listing of each model in the database."""
|
"""Return a paginated summary listing of each model in the database."""
|
||||||
|
assert isinstance(order_by, ModelRecordOrderBy)
|
||||||
ordering = {
|
ordering = {
|
||||||
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
||||||
ModelRecordOrderBy.Type: "a.type",
|
ModelRecordOrderBy.Type: "a.type",
|
||||||
@ -355,14 +353,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
ModelRecordOrderBy.Format: "a.format",
|
ModelRecordOrderBy.Format: "a.format",
|
||||||
}
|
}
|
||||||
|
|
||||||
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
|
|
||||||
"""Fix up results so that there are no null values."""
|
|
||||||
result: Dict[str, Union[str, int, Set[str]]] = {}
|
|
||||||
for key, item in summary.items():
|
|
||||||
result[key] = item or ""
|
|
||||||
result["tags"] = set(json.loads(summary["tags"] or "[]"))
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Lock so that the database isn't updated while we're doing the two queries.
|
# Lock so that the database isn't updated while we're doing the two queries.
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
# query1: get the total number of model configs
|
# query1: get the total number of model configs
|
||||||
@ -377,11 +367,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
# query2: fetch key fields from the join of models and model_metadata
|
# query2: fetch key fields from the join of models and model_metadata
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
f"""--sql
|
||||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
SELECT config
|
||||||
json_extract(a.config, '$.description') as description,
|
FROM models
|
||||||
json_extract(b.metadata, '$.tags') as tags
|
|
||||||
FROM models AS a
|
|
||||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
|
||||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
OFFSET ?;
|
OFFSET ?;
|
||||||
@ -392,7 +379,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
rows = self._cursor.fetchall()
|
rows = self._cursor.fetchall()
|
||||||
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
|
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||||
return PaginatedResults(
|
return PaginatedResults(
|
||||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||||
)
|
)
|
||||||
|
@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, JsonValue, Tag, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||||
from typing_extensions import Annotated, Any, Dict
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
from ..raw_model import RawModel
|
from ..raw_model import RawModel
|
||||||
@ -142,9 +142,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
description: Optional[str] = Field(description="Model description", default=None)
|
description: Optional[str] = Field(description="Model description", default=None)
|
||||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
||||||
source_type: ModelSourceType = Field(description="The type of source")
|
source_type: ModelSourceType = Field(description="The type of source")
|
||||||
source_api_response: Optional[JsonValue] = Field(
|
source_api_response: Optional[str] = Field(description="The original API response from the source, as stringified JSON.", default=None)
|
||||||
description="The original API response from the source", default=None
|
|
||||||
)
|
|
||||||
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
|
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
|
||||||
|
|
||||||
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
|
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
|
||||||
|
@ -23,12 +23,13 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
|||||||
print(metadata.trained_words)
|
print(metadata.trained_words)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter, ValidationError
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
|
|
||||||
@ -56,7 +57,7 @@ StringSetAdapter = TypeAdapter(set[str])
|
|||||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||||
"""Fetch model metadata from Civitai."""
|
"""Fetch model metadata from Civitai."""
|
||||||
|
|
||||||
def __init__(self, session: Optional[Session] = None):
|
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||||
|
|
||||||
@ -64,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
this module without an internet connection.
|
this module without an internet connection.
|
||||||
"""
|
"""
|
||||||
self._requests = session or requests.Session()
|
self._requests = session or requests.Session()
|
||||||
|
self._api_key = api_key
|
||||||
|
|
||||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||||
"""
|
"""
|
||||||
@ -103,7 +105,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
May raise an `UnknownMetadataException`.
|
May raise an `UnknownMetadataException`.
|
||||||
"""
|
"""
|
||||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||||
model_json = self._requests.get(model_url).json()
|
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||||
return self._from_api_response(model_json)
|
return self._from_api_response(model_json)
|
||||||
|
|
||||||
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||||
@ -134,7 +136,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
url = url + f"?type={primary_file['type']}{metadata_string}"
|
url = url + f"?type={primary_file['type']}{metadata_string}"
|
||||||
model_files = [
|
model_files = [
|
||||||
RemoteModelFile(
|
RemoteModelFile(
|
||||||
url=url,
|
url=self._get_url_with_api_key(url),
|
||||||
path=Path(primary_file["name"]),
|
path=Path(primary_file["name"]),
|
||||||
size=int(primary_file["sizeKB"] * 1024),
|
size=int(primary_file["sizeKB"] * 1024),
|
||||||
sha256=primary_file["hashes"]["SHA256"],
|
sha256=primary_file["hashes"]["SHA256"],
|
||||||
@ -142,11 +144,16 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trigger_words = StringSetAdapter.validate_python(api_response["triggerWords"])
|
trigger_words = StringSetAdapter.validate_python(version_json.get("trainedWords"))
|
||||||
except TypeError:
|
except ValidationError:
|
||||||
trigger_words: set[str] = set()
|
trigger_words: set[str] = set()
|
||||||
|
|
||||||
return CivitaiMetadata(name=version_json["name"], files=model_files, trigger_words=trigger_words)
|
return CivitaiMetadata(
|
||||||
|
name=version_json["name"],
|
||||||
|
files=model_files,
|
||||||
|
trigger_words=trigger_words,
|
||||||
|
api_response=json.dumps(version_json),
|
||||||
|
)
|
||||||
|
|
||||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||||
"""
|
"""
|
||||||
@ -156,13 +163,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
"""
|
"""
|
||||||
if model_id is None:
|
if model_id is None:
|
||||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||||
version = self._requests.get(version_url).json()
|
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
|
||||||
if error := version.get("error"):
|
if error := version.get("error"):
|
||||||
raise UnknownMetadataException(error)
|
raise UnknownMetadataException(error)
|
||||||
model_id = version["modelId"]
|
model_id = version["modelId"]
|
||||||
|
|
||||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||||
model_json = self._requests.get(model_url).json()
|
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||||
return self._from_api_response(model_json, version_id)
|
return self._from_api_response(model_json, version_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -170,3 +177,12 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||||
metadata = CivitaiMetadata.model_validate_json(json)
|
metadata = CivitaiMetadata.model_validate_json(json)
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
def _get_url_with_api_key(self, url: str) -> str:
|
||||||
|
if not self._api_key:
|
||||||
|
return url
|
||||||
|
|
||||||
|
if "?" in url:
|
||||||
|
return f"{url}&token={self._api_key}"
|
||||||
|
|
||||||
|
return f"{url}?token={self._api_key}"
|
||||||
|
@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
|
|||||||
print(metadata.tags)
|
print(metadata.tags)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -89,7 +90,9 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return HuggingFaceMetadata(id=model_info.id, name=name, files=files)
|
return HuggingFaceMetadata(
|
||||||
|
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__)
|
||||||
|
)
|
||||||
|
|
||||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||||
"""
|
"""
|
||||||
|
@ -18,7 +18,7 @@ from pathlib import Path
|
|||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import configure_http_backend, hf_hub_url
|
from huggingface_hub import configure_http_backend, hf_hub_url
|
||||||
from pydantic import BaseModel, Field, JsonValue, TypeAdapter
|
from pydantic import BaseModel, Field, TypeAdapter
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
@ -93,7 +93,7 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
|||||||
|
|
||||||
type: Literal["civitai"] = "civitai"
|
type: Literal["civitai"] = "civitai"
|
||||||
trigger_words: set[str] = Field(description="Trigger words extracted from the API response")
|
trigger_words: set[str] = Field(description="Trigger words extracted from the API response")
|
||||||
api_response: Optional[JsonValue] = Field(description="Response from the Civitai API", default=None)
|
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||||
@ -101,7 +101,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
|||||||
|
|
||||||
type: Literal["huggingface"] = "huggingface"
|
type: Literal["huggingface"] = "huggingface"
|
||||||
id: str = Field(description="The HF model id")
|
id: str = Field(description="The HF model id")
|
||||||
api_response: Optional[JsonValue] = Field(description="Response from the HF API", default=None)
|
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
||||||
|
|
||||||
def download_urls(
|
def download_urls(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user