refactor(mm): get metadata working

This commit is contained in:
psychedelicious 2024-03-04 19:17:01 +11:00
parent 7cb0da1f66
commit c3aa985c93
8 changed files with 72 additions and 61 deletions

View File

@ -3,8 +3,6 @@
import pathlib
import shutil
from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
@ -461,11 +459,8 @@ async def add_model_record(
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
record_store.add_model(config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))

View File

@ -556,7 +556,7 @@ class ModelInstallService(ModelInstallServiceBase):
# make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(info.key, info)
self.record_store.add_model(info)
return info.key
def _next_id(self) -> int:
@ -583,7 +583,9 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
if not source.access_token:
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
str(source.version_id)
)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(session=self._session)
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
@ -611,15 +613,17 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from Civitai or HuggingFace will be handled specially
url_patterns = {
r"^https?://civitai.com/": CivitaiMetadataFetch,
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
}
metadata = None
for pattern, fetcher in url_patterns.items():
if re.match(pattern, str(source.url), re.IGNORECASE):
metadata = fetcher(self._session).from_url(source.url)
break
fetcher = None
try:
fetcher = self.get_fetcher_from_url(str(source.url))
except ValueError:
pass
kwargs: dict[str, Any] = {"session": self._session}
if fetcher is CivitaiMetadataFetch:
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
if fetcher is not None:
metadata = fetcher(**kwargs).from_url(source.url)
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
@ -849,3 +853,11 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"{job.source}: model installation was cancelled")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source))
@staticmethod
def get_fetcher_from_url(url: str):
if re.match(r"^https?://civitai.com/", url.lower()):
return CivitaiMetadataFetch
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")

View File

@ -64,7 +64,7 @@ class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@abstractmethod
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.

View File

@ -85,7 +85,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Return the underlying database."""
return self._db
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@ -95,8 +95,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
@ -108,8 +106,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
VALUES (?,?);
""",
(
key,
json_serialized,
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
@ -118,11 +116,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
@ -130,7 +128,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e
return self.get_model(key)
return self.get_model(config.key)
def del_model(self, key: str) -> None:
"""
@ -263,14 +261,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
f"""--sql
select config, strftime('%s',updated_at) FROM models
SELECT config, strftime('%s',updated_at) FROM models
{where};
""",
tuple(bindings),
)
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
result = self._cursor.fetchall()
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -347,6 +344,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
ModelRecordOrderBy.Type: "a.type",
@ -355,14 +353,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
ModelRecordOrderBy.Format: "a.format",
}
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
"""Fix up results so that there are no null values."""
result: Dict[str, Union[str, int, Set[str]]] = {}
for key, item in summary.items():
result[key] = item or ""
result["tags"] = set(json.loads(summary["tags"] or "[]"))
return result
# Lock so that the database isn't updated while we're doing the two queries.
with self._db.lock:
# query1: get the total number of model configs
@ -377,11 +367,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
# query2: fetch key fields from the join of models and model_metadata
self._cursor.execute(
f"""--sql
SELECT a.id as key, a.type, a.base, a.format, a.name,
json_extract(a.config, '$.description') as description,
json_extract(b.metadata, '$.tags') as tags
FROM models AS a
LEFT JOIN model_metadata AS b on a.id=b.id
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
@ -392,7 +379,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
),
)
rows = self._cursor.fetchall()
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)

View File

@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union
import torch
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, JsonValue, Tag, TypeAdapter
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from ..raw_model import RawModel
@ -142,9 +142,7 @@ class ModelConfigBase(BaseModel):
description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source")
source_api_response: Optional[JsonValue] = Field(
description="The original API response from the source", default=None
)
source_api_response: Optional[str] = Field(description="The original API response from the source, as stringified JSON.", default=None)
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)

View File

@ -23,12 +23,13 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import json
import re
from pathlib import Path
from typing import Any, Optional
import requests
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
@ -56,7 +57,7 @@ StringSetAdapter = TypeAdapter(set[str])
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
def __init__(self, session: Optional[Session] = None):
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
@ -64,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
this module without an internet connection.
"""
self._requests = session or requests.Session()
self._api_key = api_key
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
@ -103,7 +105,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
May raise an `UnknownMetadataException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json)
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
@ -134,7 +136,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
url = url + f"?type={primary_file['type']}{metadata_string}"
model_files = [
RemoteModelFile(
url=url,
url=self._get_url_with_api_key(url),
path=Path(primary_file["name"]),
size=int(primary_file["sizeKB"] * 1024),
sha256=primary_file["hashes"]["SHA256"],
@ -142,11 +144,16 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
]
try:
trigger_words = StringSetAdapter.validate_python(api_response["triggerWords"])
except TypeError:
trigger_words = StringSetAdapter.validate_python(version_json.get("trainedWords"))
except ValidationError:
trigger_words: set[str] = set()
return CivitaiMetadata(name=version_json["name"], files=model_files, trigger_words=trigger_words)
return CivitaiMetadata(
name=version_json["name"],
files=model_files,
trigger_words=trigger_words,
api_response=json.dumps(version_json),
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
"""
@ -156,13 +163,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""
if model_id is None:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
if error := version.get("error"):
raise UnknownMetadataException(error)
model_id = version["modelId"]
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json, version_id)
@classmethod
@ -170,3 +177,12 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = CivitaiMetadata.model_validate_json(json)
return metadata
def _get_url_with_api_key(self, url: str) -> str:
if not self._api_key:
return url
if "?" in url:
return f"{url}&token={self._api_key}"
return f"{url}?token={self._api_key}"

View File

@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""
import json
import re
from pathlib import Path
from typing import Optional
@ -89,7 +90,9 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
)
)
return HuggingFaceMetadata(id=model_info.id, name=name, files=files)
return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__)
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""

View File

@ -18,7 +18,7 @@ from pathlib import Path
from typing import List, Literal, Optional, Union
from huggingface_hub import configure_http_backend, hf_hub_url
from pydantic import BaseModel, Field, JsonValue, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from typing_extensions import Annotated
@ -93,7 +93,7 @@ class CivitaiMetadata(ModelMetadataWithFiles):
type: Literal["civitai"] = "civitai"
trigger_words: set[str] = Field(description="Trigger words extracted from the API response")
api_response: Optional[JsonValue] = Field(description="Response from the Civitai API", default=None)
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
class HuggingFaceMetadata(ModelMetadataWithFiles):
@ -101,7 +101,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="The HF model id")
api_response: Optional[JsonValue] = Field(description="Response from the HF API", default=None)
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
def download_urls(
self,