add storage

This commit is contained in:
Lincoln Stein 2023-12-19 17:00:49 -05:00
parent c610283158
commit e86f3fe29e
10 changed files with 669 additions and 149 deletions

View File

@ -5,6 +5,7 @@ from invokeai.app.services.image_files.image_files_base import ImageFileStorageB
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -27,6 +28,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator = SqliteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.register_migration(build_migration_4())
migrator.run_migrations()
return db

View File

@ -0,0 +1,94 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration4Callback:
"""Callback to do step 4 of migration."""
def __call__(self, cursor: sqlite3.Cursor) -> None: # noqa D102
self._create_model_metadata(cursor)
self._create_model_tags(cursor)
self._create_tags(cursor)
self._create_triggers(cursor)
def _create_model_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Create the table used to store model metadata downloaded from remote sources."""
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_metadata (
id TEXT NOT NULL,
name TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.name')) VIRTUAL NOT NULL,
author TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.author')) VIRTUAL NOT NULL,
-- Serialized JSON representation of the whole metadata object,
-- which will contain additional fields from subclasses
metadata TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
FOREIGN KEY(id) REFERENCES model_config(id)
);
"""
)
def _create_model_tags(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_tags (
id TEXT NOT NULL,
tag_id INTEGER NOT NULL,
FOREIGN KEY(id) REFERENCES model_config(id),
FOREIGN KEY(tag_id) REFERENCES tags(tag_id),
UNIQUE(id,tag_id)
);
"""
)
def _create_tags(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tags (
tag_id INTEGER NOT NULL PRIMARY KEY,
tag_text TEXT NOT NULL UNIQUE
);
"""
)
def _create_triggers(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_metadata_updated_at
AFTER UPDATE
ON model_metadata FOR EACH ROW
BEGIN
UPDATE model_metadata SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_deleted
AFTER DELETE
ON model_config
BEGIN
DELETE from model_metadata WHERE id=old.id;
DELETE from model_tags WHERE id=old.id;
END;
"""
)
def build_migration_4() -> Migration:
"""
Build the migration from database version 3 to 4.
Adds the tables needed to store model metadata and tags.
"""
migration_4 = Migration(
from_version=2, # until migration_3 is merged, pretend we are doing 2-3
to_version=3,
callback=Migration4Callback(),
)
return migration_4

View File

@ -0,0 +1,38 @@
"""
Initialization file for invokeai.backend.model_manager.metadata
Usage:
from invokeai.backend.model_manager.metadata import(
AnyModelRepoMetadata,
CommercialUsage,
LicenseRestrictions,
HuggingFaceMetadata,
CivitaiMetadata,
)
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
CivitaiMetadata,
CommercialUsage,
HuggingFaceMetadata,
LicenseRestrictions,
)
__all__ = [
"AnyModelRepoMetadata",
"AnyModelRepoMetadataValidator",
"CommercialUsage",
"LicenseRestrictions",
"HuggingFaceMetadata",
"CivitaiMetadata",
]

View File

@ -1,139 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches metadata objects from HuggingFace and Civitai.
Usage:
from invokeai.backend.model_manager.metadata.fetch import MetadataFetch
metadata = MetadataFetch.from_civitai_url("https://civitai.com/models/58390/detail-tweaker-lora-lora")
print(metadata.description)
"""
import re
from pathlib import Path
from typing import Optional, Dict, Optional, Any
from datetime import datetime
import requests
from huggingface_hub import HfApi, configure_http_backend
from huggingface_hub.utils._errors import RepositoryNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.app.services.model_records import UnknownModelException
from .base import (
CivitaiMetadata,
HuggingFaceMetadata,
LicenseRestrictions,
CommercialUsage,
)
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
class MetadataFetch:
"""Fetch metadata from HuggingFace and Civitai URLs."""
_requests: Session
def __init__(self, session: Optional[Session]=None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
configure_http_backend(backend_factory = lambda: self._requests)
def from_huggingface_repoid(self, repo_id: str) -> HuggingFaceMetadata:
"""Return a HuggingFaceMetadata object given the model's repo_id."""
try:
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
except RepositoryNotFoundError as excp:
raise UnknownModelException(f"'{repo_id}' not found. See trace for details.") from excp
_, name = repo_id.split("/")
return HuggingFaceMetadata(
id = model_info.modelId,
author = model_info.author,
name = name,
last_modified = model_info.lastModified,
tags = model_info.tags,
tag_dict = model_info.card_data.to_dict(),
files = [Path(x.rfilename) for x in model_info.siblings]
)
def from_huggingface_url(self, url: AnyHttpUrl) -> HuggingFaceMetadata:
"""
Return a HuggingFaceMetadata object given the model's web page URL.
In the case of an invalid or missing URL, raises a ModelNotFound exception.
"""
if match := re.match(HF_MODEL_RE, str(url)):
repo_id = match.group(1)
return self.from_huggingface_repoid(repo_id)
else:
raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")
def from_civitai_versionid(self, version_id: int, model_metadata: Optional[Dict[str,Any]]=None) -> CivitaiMetadata:
"""Return Civitai metadata using a model's version id."""
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
model_url = CIVITAI_MODEL_ENDPOINT + str(version['modelId'])
model = model_metadata or self._requests.get(model_url).json()
safe_thumbnails = [x['url'] for x in version['images'] if x['nsfw']=='None']
return CivitaiMetadata(
id=version['modelId'],
name=version['model']['name'],
version_id=version['id'],
version_name=version['name'],
created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version['createdAt'])),
base_model_trained_on=version['baseModel'], # note - need a dictionary to turn into a BaseModelType
download_url=version['downloadUrl'],
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model['creator']['username'],
description=model['description'],
version_description=version['description'] or "",
tags=model['tags'],
trained_words=version['trainedWords'],
nsfw=version['model']['nsfw'],
restrictions=LicenseRestrictions(
AllowNoCredit=model['allowNoCredit'],
AllowCommercialUse=CommercialUsage(model['allowCommercialUse']),
AllowDerivatives=model['allowDerivatives'],
AllowDifferentLicense=model['allowDifferentLicense']
),
)
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
"""Return metadata from the default version of the indicated model."""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model = self._requests.get(model_url).json()
default_version = model['modelVersions'][0]['id']
return self.from_civitai_versionid(default_version, model)
def from_civitai_url(self, url: AnyHttpUrl) -> CivitaiMetadata:
"""Parse a Civitai URL that user is likely to pass and return its metadata."""
if match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)):
model_id = match.group(1)
return self.from_civitai_modelid(int(model_id))
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
version_id = match.group(1)
return self.from_civitai_versionid(int(version_id))
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
version_id = match.group(1)
return self.from_civitai_versionid(int(version_id))
raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns")

View File

@ -0,0 +1,21 @@
"""
Initialization file for invokeai.backend.model_manager.metadata.fetch
Usage:
from invokeai.backend.model_manager.metadata.fetch import (
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
)
from invokeai.backend.model_manager.metadata import CivitaiMetadata
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch_base import ModelMetadataFetchBase
from .fetch_civitai import CivitaiMetadataFetch
from .fetch_huggingface import HuggingFaceMetadataFetch
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]

View File

@ -0,0 +1,63 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module is the base class for subclasses that fetch metadata from model repositories
Usage:
from invokeai.backend.model_manager.metadata.fetch import CivitAIMetadataFetch
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
from abc import ABC, abstractmethod
from typing import Optional
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
class ModelMetadataFetchBase(ABC):
"""Fetch metadata from remote generative model repositories."""
@abstractmethod
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
pass
@abstractmethod
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Given a URL to a model repository, return a ModelMetadata object.
This method will raise an `invokeai.app.services.model_records.UnknownModelException`
in the event that the requested model metadata is not found at the provided location.
"""
pass
@abstractmethod
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""
Given an ID for a model, return a ModelMetadata object.
This method will raise an `invokeai.app.services.model_records.UnknownModelException`
in the event that the requested model's metadata is not found at the provided id.
"""
pass
@classmethod
def from_json(self, json: str) -> AnyModelRepoMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
return (
metadata # mypy complains that metadata is a <typing special form> and issues a type checking error. Why?
)

View File

@ -0,0 +1,150 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches model metadata objects from the Civitai model repository.
In addition to the `from_url()` and `from_id()` methods inherited from the
`ModelMetadataFetchBase` base class.
Civitai has two separate ID spaces: a model ID and a version ID. The
version ID corresponds to a specific model, and is the ID accepted by
`from_id()`. The model ID corresponds to a family of related models,
such as different training checkpoints or 16 vs 32-bit versions. The
`from_civitai_modelid()` method will accept a model ID and return the
metadata from the default version within this model set. The default
version is the same as what the user sees when they click on a model's
thumbnail.
Usage:
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import re
from datetime import datetime
from typing import Any, Dict, Optional
import requests
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.app.services.model_records import UnknownModelException
from ..metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
CivitaiMetadata,
CommercialUsage,
LicenseRestrictions,
)
from .fetch_base import ModelMetadataFetchBase
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
_requests: Session
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Given a URL to a CivitAI model or version page, return a ModelMetadata object.
In the event that the URL points to a model page without the particular version
indicated, the default model version is returned. Otherwise, the requested version
is returned.
"""
if match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)):
model_id = match.group(1)
return self.from_civitai_modelid(int(model_id))
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
version_id = match.group(1)
return self._from_civitai_versionid(int(version_id))
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
version_id = match.group(1)
return self._from_civitai_versionid(int(version_id))
raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns")
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""
Given a Civitai model version ID, return a ModelRepoMetadata object.
May raise an `UnknownModelException`.
"""
return self._from_civitai_versionid(int(id))
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
"""
Return metadata from the default version of the indicated model.
May raise an `UnknownModelException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model = self._requests.get(model_url).json()
default_version = model["modelVersions"][0]["id"]
return self._from_civitai_versionid(default_version, model)
def _from_civitai_versionid(
self, version_id: int, model_metadata: Optional[Dict[str, Any]] = None
) -> CivitaiMetadata:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
model_url = CIVITAI_MODEL_ENDPOINT + str(version["modelId"])
model = model_metadata or self._requests.get(model_url).json()
safe_thumbnails = [x["url"] for x in version["images"] if x["nsfw"] == "None"]
# It would be more elegant to define a Pydantic BaseModel that matches the Civitai metadata JSON.
# However the contents of the JSON does not exactly match the documentation at
# https://github.com/civitai/civitai/wiki/REST-API-Reference, and it feels safer to cherry pick
# a subset of the fields.
#
# In addition, there are some fields that I want to pick up from the model JSON, such as `tags`,
# that are not present in the version JSON.
return CivitaiMetadata(
id=version["modelId"],
name=version["model"]["name"],
version_id=version["id"],
version_name=version["name"],
created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version["createdAt"])),
base_model_trained_on=version["baseModel"], # note - need a dictionary to turn into a BaseModelType
download_url=version["downloadUrl"],
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model["creator"]["username"],
description=model["description"],
version_description=version["description"] or "",
tags=model["tags"],
trained_words=version["trainedWords"],
nsfw=version["model"]["nsfw"],
restrictions=LicenseRestrictions(
AllowNoCredit=model["allowNoCredit"],
AllowCommercialUse=CommercialUsage(model["allowCommercialUse"]),
AllowDerivatives=model["allowDerivatives"],
AllowDifferentLicense=model["allowDifferentLicense"],
),
)
@classmethod
def from_json(cls, json: str) -> CivitaiMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
assert isinstance(metadata, CivitaiMetadata)
return metadata

View File

@ -0,0 +1,83 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches model metadata objects from the HuggingFace model repository,
using either a `repo_id` or the model page URL.
Usage:
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
fetcher = HuggingFaceMetadataFetch()
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""
import re
from pathlib import Path
from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend
from huggingface_hub.utils._errors import RepositoryNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.app.services.model_records import UnknownModelException
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata
from .fetch_base import ModelMetadataFetchBase
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from HuggingFace."""
_requests: Session
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
configure_http_backend(backend_factory=lambda: self._requests)
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""Return a HuggingFaceMetadata object given the model's repo_id."""
try:
model_info = HfApi().model_info(repo_id=id, files_metadata=True)
except RepositoryNotFoundError as excp:
raise UnknownModelException(f"'{id}' not found. See trace for details.") from excp
_, name = id.split("/")
return HuggingFaceMetadata(
id=model_info.modelId,
author=model_info.author,
name=name,
last_modified=model_info.lastModified,
tag_dict=model_info.card_data.to_dict(),
tags=model_info.tags,
files=[Path(x.rfilename) for x in model_info.siblings],
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Return a HuggingFaceMetadata object given the model's web page URL.
In the case of an invalid or missing URL, raises a ModelNotFound exception.
"""
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
if match := re.match(HF_MODEL_RE, str(url)):
repo_id = match.group(1)
return self.from_id(repo_id)
else:
raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")
@classmethod
def from_json(cls, json: str) -> HuggingFaceMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
assert isinstance(metadata, HuggingFaceMetadata)
return metadata

View File

@ -15,10 +15,12 @@ This may need reworking.
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Set, Optional
from typing import Any, Dict, List, Literal, Optional, Set, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
class CommercialUsage(str, Enum):
"""Type of commercial usage allowed."""
@ -29,13 +31,21 @@ class CommercialUsage(str, Enum):
RentCivit = "RentCivit"
Sell = "Sell"
class LicenseRestrictions(BaseModel):
"""Broad categories of licensing restrictions."""
AllowNoCredit: bool = Field(description="if true, model can be redistributed without crediting author", default=False)
AllowNoCredit: bool = Field(
description="if true, model can be redistributed without crediting author", default=False
)
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
AllowDifferentLicense: bool = Field(description="if true, derivatives of this model be redistributed under a different license", default=False)
AllowCommercialUse: CommercialUsage = Field(description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set)
AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False
)
AllowCommercialUse: CommercialUsage = Field(
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set
)
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
@ -44,32 +54,41 @@ class ModelMetadataBase(BaseModel):
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
class HuggingFaceMetadata(ModelMetadataBase):
"""Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="huggingface model id")
tag_dict: Dict[str, Any]
last_modified: datetime = Field(description="date of last commit to repo")
files: List[Path] = Field(description="sibling files that belong to this model", default_factory=list)
class CivitaiMetadata(ModelMetadataBase):
"""Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai"
id: int = Field(description="Civitai model identifier")
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was posted to CivitAI")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(description="text description of the model's reversion; usually change history; may contain HTML")
version_description: str = Field(
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
weight_min: float = Field(description="minimum suggested value for a LoRA or other secondary model", default=-1.0) # note: For future use; not currently easily
weight_max: float = Field(description="maximum suggested value for a LoRA or other secondary model", default=+2.0) # recoverable from metadata
weight_min: float = Field(
description="minimum suggested value for a LoRA or other secondary model", default=-1.0
) # note: For future use; not currently easily
weight_max: float = Field(
description="maximum suggested value for a LoRA or other secondary model", default=+2.0
) # recoverable from metadata
@property
def credit_required(self) -> bool:
@ -79,7 +98,7 @@ class CivitaiMetadata(ModelMetadataBase):
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
return self.restrictions.AllowCommercialUse == CommercialUsage('None')
return self.restrictions.AllowCommercialUse == CommercialUsage("None")
@property
def allow_derivatives(self) -> bool:
@ -90,3 +109,7 @@ class CivitaiMetadata(ModelMetadataBase):
def allow_different_license(self) -> bool:
"""Return true if derivatives of this model can use a different license."""
return self.restrictions.AllowDifferentLicense
AnyModelRepoMetadata = Annotated[Union[HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)

View File

@ -0,0 +1,185 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import Set
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata
class ModelMetadataStore:
"""Store, search and fetch model metadata retrieved from remote repositories."""
_db: SqliteDatabase
_cursor: sqlite3.Cursor
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
self._enable_foreign_key_constraints()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Set[str] = set()
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
except sqlite3.Error as e:
raise e
return matches
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
def _enable_foreign_key_constraints(self) -> None:
self._cursor.execute("PRAGMA foreign_keys = ON;")