diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 83a42917a7..f5718a78e1 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -5,6 +5,7 @@ from invokeai.app.services.image_files.image_files_base import ImageFileStorageB from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -27,6 +28,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator = SqliteMigrator(db=db) migrator.register_migration(build_migration_1()) migrator.register_migration(build_migration_2(image_files=image_files, logger=logger)) + migrator.register_migration(build_migration_4()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_4.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_4.py new file mode 100644 index 0000000000..702262912e --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_4.py @@ -0,0 +1,94 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration4Callback: + """Callback to do step 4 of migration.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: # noqa D102 + self._create_model_metadata(cursor) + self._create_model_tags(cursor) + self._create_tags(cursor) + self._create_triggers(cursor) + + def _create_model_metadata(self, cursor: sqlite3.Cursor) -> None: + """Create the table used to store model metadata downloaded from remote sources.""" + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS model_metadata ( + id TEXT NOT NULL, + name TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.name')) VIRTUAL NOT NULL, + author TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.author')) VIRTUAL NOT NULL, + -- Serialized JSON representation of the whole metadata object, + -- which will contain additional fields from subclasses + metadata TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + FOREIGN KEY(id) REFERENCES model_config(id) + ); + """ + ) + + def _create_model_tags(self, cursor: sqlite3.Cursor) -> None: + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS model_tags ( + id TEXT NOT NULL, + tag_id INTEGER NOT NULL, + FOREIGN KEY(id) REFERENCES model_config(id), + FOREIGN KEY(tag_id) REFERENCES tags(tag_id), + UNIQUE(id,tag_id) + ); + """ + ) + + def _create_tags(self, cursor: sqlite3.Cursor) -> None: + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS tags ( + tag_id INTEGER NOT NULL PRIMARY KEY, + tag_text TEXT NOT NULL UNIQUE + ); + """ + ) + + def _create_triggers(self, cursor: sqlite3.Cursor) -> None: + cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS model_metadata_updated_at + AFTER UPDATE + ON model_metadata FOR EACH ROW + BEGIN + UPDATE model_metadata SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ) + cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS model_config_deleted + AFTER DELETE + ON model_config + BEGIN + DELETE from model_metadata WHERE id=old.id; + DELETE from model_tags WHERE id=old.id; + END; + """ + ) + + +def build_migration_4() -> Migration: + """ + Build the migration from database version 3 to 4. + + Adds the tables needed to store model metadata and tags. + """ + migration_4 = Migration( + from_version=2, # until migration_3 is merged, pretend we are doing 2-3 + to_version=3, + callback=Migration4Callback(), + ) + + return migration_4 diff --git a/invokeai/backend/model_manager/metadata/__init__.py b/invokeai/backend/model_manager/metadata/__init__.py index e69de29bb2..ee9e087435 100644 --- a/invokeai/backend/model_manager/metadata/__init__.py +++ b/invokeai/backend/model_manager/metadata/__init__.py @@ -0,0 +1,38 @@ +""" +Initialization file for invokeai.backend.model_manager.metadata + +Usage: + +from invokeai.backend.model_manager.metadata import( + AnyModelRepoMetadata, + CommercialUsage, + LicenseRestrictions, + HuggingFaceMetadata, + CivitaiMetadata, +) + +from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch + +data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split") +assert isinstance(data, CivitaiMetadata) +if data.allow_commercial_use: + print("Commercial use of this model is allowed") +""" + +from .metadata_base import ( + AnyModelRepoMetadata, + AnyModelRepoMetadataValidator, + CivitaiMetadata, + CommercialUsage, + HuggingFaceMetadata, + LicenseRestrictions, +) + +__all__ = [ + "AnyModelRepoMetadata", + "AnyModelRepoMetadataValidator", + "CommercialUsage", + "LicenseRestrictions", + "HuggingFaceMetadata", + "CivitaiMetadata", +] diff --git a/invokeai/backend/model_manager/metadata/fetch.py b/invokeai/backend/model_manager/metadata/fetch.py deleted file mode 100644 index 346ecafa1e..0000000000 --- a/invokeai/backend/model_manager/metadata/fetch.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team - -""" -This module fetches metadata objects from HuggingFace and Civitai. - -Usage: - -from invokeai.backend.model_manager.metadata.fetch import MetadataFetch - -metadata = MetadataFetch.from_civitai_url("https://civitai.com/models/58390/detail-tweaker-lora-lora") -print(metadata.description) -""" - -import re -from pathlib import Path -from typing import Optional, Dict, Optional, Any -from datetime import datetime - -import requests -from huggingface_hub import HfApi, configure_http_backend -from huggingface_hub.utils._errors import RepositoryNotFoundError -from pydantic.networks import AnyHttpUrl -from requests.sessions import Session - -from invokeai.app.services.model_records import UnknownModelException - -from .base import ( - CivitaiMetadata, - HuggingFaceMetadata, - LicenseRestrictions, - CommercialUsage, -) - -HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)" -CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)" -CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)" -CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)" - -CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/" -CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/" - -class MetadataFetch: - """Fetch metadata from HuggingFace and Civitai URLs.""" - - _requests: Session - - def __init__(self, session: Optional[Session]=None): - """ - Initialize the fetcher with an optional requests.sessions.Session object. - - By providing a configurable Session object, we can support unit tests on - this module without an internet connection. - """ - self._requests = session or requests.Session() - configure_http_backend(backend_factory = lambda: self._requests) - - def from_huggingface_repoid(self, repo_id: str) -> HuggingFaceMetadata: - """Return a HuggingFaceMetadata object given the model's repo_id.""" - try: - model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True) - except RepositoryNotFoundError as excp: - raise UnknownModelException(f"'{repo_id}' not found. See trace for details.") from excp - - _, name = repo_id.split("/") - return HuggingFaceMetadata( - id = model_info.modelId, - author = model_info.author, - name = name, - last_modified = model_info.lastModified, - tags = model_info.tags, - tag_dict = model_info.card_data.to_dict(), - files = [Path(x.rfilename) for x in model_info.siblings] - ) - - def from_huggingface_url(self, url: AnyHttpUrl) -> HuggingFaceMetadata: - """ - Return a HuggingFaceMetadata object given the model's web page URL. - - In the case of an invalid or missing URL, raises a ModelNotFound exception. - """ - if match := re.match(HF_MODEL_RE, str(url)): - repo_id = match.group(1) - return self.from_huggingface_repoid(repo_id) - else: - raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page") - - def from_civitai_versionid(self, version_id: int, model_metadata: Optional[Dict[str,Any]]=None) -> CivitaiMetadata: - """Return Civitai metadata using a model's version id.""" - version_url = CIVITAI_VERSION_ENDPOINT + str(version_id) - version = self._requests.get(version_url).json() - - model_url = CIVITAI_MODEL_ENDPOINT + str(version['modelId']) - model = model_metadata or self._requests.get(model_url).json() - safe_thumbnails = [x['url'] for x in version['images'] if x['nsfw']=='None'] - - return CivitaiMetadata( - id=version['modelId'], - name=version['model']['name'], - version_id=version['id'], - version_name=version['name'], - created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version['createdAt'])), - base_model_trained_on=version['baseModel'], # note - need a dictionary to turn into a BaseModelType - download_url=version['downloadUrl'], - thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None, - author=model['creator']['username'], - description=model['description'], - version_description=version['description'] or "", - tags=model['tags'], - trained_words=version['trainedWords'], - nsfw=version['model']['nsfw'], - restrictions=LicenseRestrictions( - AllowNoCredit=model['allowNoCredit'], - AllowCommercialUse=CommercialUsage(model['allowCommercialUse']), - AllowDerivatives=model['allowDerivatives'], - AllowDifferentLicense=model['allowDifferentLicense'] - ), - ) - - def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata: - """Return metadata from the default version of the indicated model.""" - model_url = CIVITAI_MODEL_ENDPOINT + str(model_id) - model = self._requests.get(model_url).json() - default_version = model['modelVersions'][0]['id'] - return self.from_civitai_versionid(default_version, model) - - def from_civitai_url(self, url: AnyHttpUrl) -> CivitaiMetadata: - """Parse a Civitai URL that user is likely to pass and return its metadata.""" - if match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)): - model_id = match.group(1) - return self.from_civitai_modelid(int(model_id)) - elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)): - version_id = match.group(1) - return self.from_civitai_versionid(int(version_id)) - elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)): - version_id = match.group(1) - return self.from_civitai_versionid(int(version_id)) - raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns") - - diff --git a/invokeai/backend/model_manager/metadata/fetch/__init__.py b/invokeai/backend/model_manager/metadata/fetch/__init__.py new file mode 100644 index 0000000000..d09f68eb08 --- /dev/null +++ b/invokeai/backend/model_manager/metadata/fetch/__init__.py @@ -0,0 +1,21 @@ +""" +Initialization file for invokeai.backend.model_manager.metadata.fetch + +Usage: +from invokeai.backend.model_manager.metadata.fetch import ( + CivitaiMetadataFetch, + HuggingFaceMetadataFetch, +) +from invokeai.backend.model_manager.metadata import CivitaiMetadata + +data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split") +assert isinstance(data, CivitaiMetadata) +if data.allow_commercial_use: + print("Commercial use of this model is allowed") +""" + +from .fetch_base import ModelMetadataFetchBase +from .fetch_civitai import CivitaiMetadataFetch +from .fetch_huggingface import HuggingFaceMetadataFetch + +__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"] diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py new file mode 100644 index 0000000000..c063685f9a --- /dev/null +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team + +""" +This module is the base class for subclasses that fetch metadata from model repositories + +Usage: + +from invokeai.backend.model_manager.metadata.fetch import CivitAIMetadataFetch + +fetcher = CivitaiMetadataFetch() +metadata = fetcher.from_url("https://civitai.com/models/206883/split") +print(metadata.trained_words) +""" + +from abc import ABC, abstractmethod +from typing import Optional + +from pydantic.networks import AnyHttpUrl +from requests.sessions import Session + +from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator + + +class ModelMetadataFetchBase(ABC): + """Fetch metadata from remote generative model repositories.""" + + @abstractmethod + def __init__(self, session: Optional[Session] = None): + """ + Initialize the fetcher with an optional requests.sessions.Session object. + + By providing a configurable Session object, we can support unit tests on + this module without an internet connection. + """ + pass + + @abstractmethod + def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: + """ + Given a URL to a model repository, return a ModelMetadata object. + + This method will raise an `invokeai.app.services.model_records.UnknownModelException` + in the event that the requested model metadata is not found at the provided location. + """ + pass + + @abstractmethod + def from_id(self, id: str) -> AnyModelRepoMetadata: + """ + Given an ID for a model, return a ModelMetadata object. + + This method will raise an `invokeai.app.services.model_records.UnknownModelException` + in the event that the requested model's metadata is not found at the provided id. + """ + pass + + @classmethod + def from_json(self, json: str) -> AnyModelRepoMetadata: + """Given the JSON representation of the metadata, return the corresponding Pydantic object.""" + metadata = AnyModelRepoMetadataValidator.validate_json(json) + return ( + metadata # mypy complains that metadata is a and issues a type checking error. Why? + ) diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py b/invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py new file mode 100644 index 0000000000..6928765f99 --- /dev/null +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py @@ -0,0 +1,150 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team + +""" +This module fetches model metadata objects from the Civitai model repository. +In addition to the `from_url()` and `from_id()` methods inherited from the +`ModelMetadataFetchBase` base class. + +Civitai has two separate ID spaces: a model ID and a version ID. The +version ID corresponds to a specific model, and is the ID accepted by +`from_id()`. The model ID corresponds to a family of related models, +such as different training checkpoints or 16 vs 32-bit versions. The +`from_civitai_modelid()` method will accept a model ID and return the +metadata from the default version within this model set. The default +version is the same as what the user sees when they click on a model's +thumbnail. + +Usage: + +from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch + +fetcher = CivitaiMetadataFetch() +metadata = fetcher.from_url("https://civitai.com/models/206883/split") +print(metadata.trained_words) +""" + +import re +from datetime import datetime +from typing import Any, Dict, Optional + +import requests +from pydantic.networks import AnyHttpUrl +from requests.sessions import Session + +from invokeai.app.services.model_records import UnknownModelException + +from ..metadata_base import ( + AnyModelRepoMetadata, + AnyModelRepoMetadataValidator, + CivitaiMetadata, + CommercialUsage, + LicenseRestrictions, +) +from .fetch_base import ModelMetadataFetchBase + +CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)" +CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)" +CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)" + +CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/" +CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/" + + +class CivitaiMetadataFetch(ModelMetadataFetchBase): + """Fetch model metadata from Civitai.""" + + _requests: Session + + def __init__(self, session: Optional[Session] = None): + """ + Initialize the fetcher with an optional requests.sessions.Session object. + + By providing a configurable Session object, we can support unit tests on + this module without an internet connection. + """ + self._requests = session or requests.Session() + + def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: + """ + Given a URL to a CivitAI model or version page, return a ModelMetadata object. + + In the event that the URL points to a model page without the particular version + indicated, the default model version is returned. Otherwise, the requested version + is returned. + """ + if match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)): + model_id = match.group(1) + return self.from_civitai_modelid(int(model_id)) + elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)): + version_id = match.group(1) + return self._from_civitai_versionid(int(version_id)) + elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)): + version_id = match.group(1) + return self._from_civitai_versionid(int(version_id)) + raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns") + + def from_id(self, id: str) -> AnyModelRepoMetadata: + """ + Given a Civitai model version ID, return a ModelRepoMetadata object. + + May raise an `UnknownModelException`. + """ + return self._from_civitai_versionid(int(id)) + + def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata: + """ + Return metadata from the default version of the indicated model. + + May raise an `UnknownModelException`. + """ + model_url = CIVITAI_MODEL_ENDPOINT + str(model_id) + model = self._requests.get(model_url).json() + default_version = model["modelVersions"][0]["id"] + return self._from_civitai_versionid(default_version, model) + + def _from_civitai_versionid( + self, version_id: int, model_metadata: Optional[Dict[str, Any]] = None + ) -> CivitaiMetadata: + version_url = CIVITAI_VERSION_ENDPOINT + str(version_id) + version = self._requests.get(version_url).json() + + model_url = CIVITAI_MODEL_ENDPOINT + str(version["modelId"]) + model = model_metadata or self._requests.get(model_url).json() + safe_thumbnails = [x["url"] for x in version["images"] if x["nsfw"] == "None"] + + # It would be more elegant to define a Pydantic BaseModel that matches the Civitai metadata JSON. + # However the contents of the JSON does not exactly match the documentation at + # https://github.com/civitai/civitai/wiki/REST-API-Reference, and it feels safer to cherry pick + # a subset of the fields. + # + # In addition, there are some fields that I want to pick up from the model JSON, such as `tags`, + # that are not present in the version JSON. + return CivitaiMetadata( + id=version["modelId"], + name=version["model"]["name"], + version_id=version["id"], + version_name=version["name"], + created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version["createdAt"])), + base_model_trained_on=version["baseModel"], # note - need a dictionary to turn into a BaseModelType + download_url=version["downloadUrl"], + thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None, + author=model["creator"]["username"], + description=model["description"], + version_description=version["description"] or "", + tags=model["tags"], + trained_words=version["trainedWords"], + nsfw=version["model"]["nsfw"], + restrictions=LicenseRestrictions( + AllowNoCredit=model["allowNoCredit"], + AllowCommercialUse=CommercialUsage(model["allowCommercialUse"]), + AllowDerivatives=model["allowDerivatives"], + AllowDifferentLicense=model["allowDifferentLicense"], + ), + ) + + @classmethod + def from_json(cls, json: str) -> CivitaiMetadata: + """Given the JSON representation of the metadata, return the corresponding Pydantic object.""" + metadata = AnyModelRepoMetadataValidator.validate_json(json) + assert isinstance(metadata, CivitaiMetadata) + return metadata diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_huggingface.py b/invokeai/backend/model_manager/metadata/fetch/fetch_huggingface.py new file mode 100644 index 0000000000..3eb4cf37a5 --- /dev/null +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_huggingface.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team + +""" +This module fetches model metadata objects from the HuggingFace model repository, +using either a `repo_id` or the model page URL. + +Usage: + +from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch + +fetcher = HuggingFaceMetadataFetch() +metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo") +print(metadata.tags) +""" + +import re +from pathlib import Path +from typing import Optional + +import requests +from huggingface_hub import HfApi, configure_http_backend +from huggingface_hub.utils._errors import RepositoryNotFoundError +from pydantic.networks import AnyHttpUrl +from requests.sessions import Session + +from invokeai.app.services.model_records import UnknownModelException + +from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata +from .fetch_base import ModelMetadataFetchBase + + +class HuggingFaceMetadataFetch(ModelMetadataFetchBase): + """Fetch model metadata from HuggingFace.""" + + _requests: Session + + def __init__(self, session: Optional[Session] = None): + """ + Initialize the fetcher with an optional requests.sessions.Session object. + + By providing a configurable Session object, we can support unit tests on + this module without an internet connection. + """ + self._requests = session or requests.Session() + configure_http_backend(backend_factory=lambda: self._requests) + + def from_id(self, id: str) -> AnyModelRepoMetadata: + """Return a HuggingFaceMetadata object given the model's repo_id.""" + try: + model_info = HfApi().model_info(repo_id=id, files_metadata=True) + except RepositoryNotFoundError as excp: + raise UnknownModelException(f"'{id}' not found. See trace for details.") from excp + + _, name = id.split("/") + return HuggingFaceMetadata( + id=model_info.modelId, + author=model_info.author, + name=name, + last_modified=model_info.lastModified, + tag_dict=model_info.card_data.to_dict(), + tags=model_info.tags, + files=[Path(x.rfilename) for x in model_info.siblings], + ) + + def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: + """ + Return a HuggingFaceMetadata object given the model's web page URL. + + In the case of an invalid or missing URL, raises a ModelNotFound exception. + """ + HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)" + if match := re.match(HF_MODEL_RE, str(url)): + repo_id = match.group(1) + return self.from_id(repo_id) + else: + raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page") + + @classmethod + def from_json(cls, json: str) -> HuggingFaceMetadata: + """Given the JSON representation of the metadata, return the corresponding Pydantic object.""" + metadata = AnyModelRepoMetadataValidator.validate_json(json) + assert isinstance(metadata, HuggingFaceMetadata) + return metadata diff --git a/invokeai/backend/model_manager/metadata/base.py b/invokeai/backend/model_manager/metadata/metadata_base.py similarity index 71% rename from invokeai/backend/model_manager/metadata/base.py rename to invokeai/backend/model_manager/metadata/metadata_base.py index e9ef8cd97f..0ef289c43e 100644 --- a/invokeai/backend/model_manager/metadata/base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -15,10 +15,12 @@ This may need reworking. from datetime import datetime from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Set, Optional +from typing import Any, Dict, List, Literal, Optional, Set, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from pydantic.networks import AnyHttpUrl +from typing_extensions import Annotated + class CommercialUsage(str, Enum): """Type of commercial usage allowed.""" @@ -29,13 +31,21 @@ class CommercialUsage(str, Enum): RentCivit = "RentCivit" Sell = "Sell" + class LicenseRestrictions(BaseModel): """Broad categories of licensing restrictions.""" - AllowNoCredit: bool = Field(description="if true, model can be redistributed without crediting author", default=False) + AllowNoCredit: bool = Field( + description="if true, model can be redistributed without crediting author", default=False + ) AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False) - AllowDifferentLicense: bool = Field(description="if true, derivatives of this model be redistributed under a different license", default=False) - AllowCommercialUse: CommercialUsage = Field(description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set) + AllowDifferentLicense: bool = Field( + description="if true, derivatives of this model be redistributed under a different license", default=False + ) + AllowCommercialUse: CommercialUsage = Field( + description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set + ) + class ModelMetadataBase(BaseModel): """Base class for model metadata information.""" @@ -44,32 +54,41 @@ class ModelMetadataBase(BaseModel): author: str = Field(description="model's author") tags: Set[str] = Field(description="tags provided by model source") + class HuggingFaceMetadata(ModelMetadataBase): """Extended metadata fields provided by HuggingFace.""" + type: Literal["huggingface"] = "huggingface" id: str = Field(description="huggingface model id") tag_dict: Dict[str, Any] last_modified: datetime = Field(description="date of last commit to repo") files: List[Path] = Field(description="sibling files that belong to this model", default_factory=list) + class CivitaiMetadata(ModelMetadataBase): """Extended metadata fields provided by Civitai.""" + type: Literal["civitai"] = "civitai" id: int = Field(description="Civitai model identifier") version_name: str = Field(description="Version identifier, such as 'V2-alpha'") version_id: int = Field(description="Civitai model version identifier") created: datetime = Field(description="date the model was posted to CivitAI") description: str = Field(description="text description of model; may contain HTML") - version_description: str = Field(description="text description of the model's reversion; usually change history; may contain HTML") + version_description: str = Field( + description="text description of the model's reversion; usually change history; may contain HTML" + ) nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False) restrictions: LicenseRestrictions = Field(description="license terms", default=LicenseRestrictions) trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set) download_url: AnyHttpUrl = Field(description="download URL for this model") base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)") thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None) - weight_min: float = Field(description="minimum suggested value for a LoRA or other secondary model", default=-1.0) # note: For future use; not currently easily - weight_max: float = Field(description="maximum suggested value for a LoRA or other secondary model", default=+2.0) # recoverable from metadata - + weight_min: float = Field( + description="minimum suggested value for a LoRA or other secondary model", default=-1.0 + ) # note: For future use; not currently easily + weight_max: float = Field( + description="maximum suggested value for a LoRA or other secondary model", default=+2.0 + ) # recoverable from metadata @property def credit_required(self) -> bool: @@ -79,7 +98,7 @@ class CivitaiMetadata(ModelMetadataBase): @property def allow_commercial_use(self) -> bool: """Return True if commercial use is allowed.""" - return self.restrictions.AllowCommercialUse == CommercialUsage('None') + return self.restrictions.AllowCommercialUse == CommercialUsage("None") @property def allow_derivatives(self) -> bool: @@ -90,3 +109,7 @@ class CivitaiMetadata(ModelMetadataBase): def allow_different_license(self) -> bool: """Return true if derivatives of this model can use a different license.""" return self.restrictions.AllowDifferentLicense + + +AnyModelRepoMetadata = Annotated[Union[HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")] +AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata) diff --git a/invokeai/backend/model_manager/metadata/metadata_store.py b/invokeai/backend/model_manager/metadata/metadata_store.py new file mode 100644 index 0000000000..f42058802a --- /dev/null +++ b/invokeai/backend/model_manager/metadata/metadata_store.py @@ -0,0 +1,185 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +SQL Storage for Model Metadata +""" + +import sqlite3 +from typing import Set + +from invokeai.app.services.model_records import UnknownModelException +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase + +from .fetch import ModelMetadataFetchBase +from .metadata_base import AnyModelRepoMetadata + + +class ModelMetadataStore: + """Store, search and fetch model metadata retrieved from remote repositories.""" + + _db: SqliteDatabase + _cursor: sqlite3.Cursor + + def __init__(self, db: SqliteDatabase): + """ + Initialize a new object from preexisting sqlite3 connection and threading lock objects. + + :param conn: sqlite3 connection object + :param lock: threading Lock object + """ + super().__init__() + self._db = db + self._cursor = self._db.conn.cursor() + self._enable_foreign_key_constraints() + + def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: + """ + Add a block of repo metadata to a model record. + + The model record config must already exist in the database with the + same key. Otherwise a FOREIGN KEY constraint exception will be raised. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to store + """ + json_serialized = metadata.model_dump_json() + with self._db.lock: + try: + self._cursor.execute( + """--sql + INSERT INTO model_metadata( + id, + metadata + ) + VALUES (?,?); + """, + ( + model_key, + json_serialized, + ), + ) + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.Error as e: + self._db.conn.rollback() + raise e + + def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: + """Retrieve the ModelRepoMetadata corresponding to model key.""" + with self._db.lock: + self._cursor.execute( + """--sql + SELECT metadata FROM model_metadata + WHERE id=?; + """, + (model_key,), + ) + rows = self._cursor.fetchone() + if not rows: + raise UnknownModelException("model metadata not found") + return ModelMetadataFetchBase.from_json(rows[0]) + + def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: + """ + Update metadata corresponding to the model with the indicated key. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to update + """ + json_serialized = metadata.model_dump_json() # turn it into a json string. + with self._db.lock: + try: + self._cursor.execute( + """--sql + UPDATE model_metadata + SET + metadata=? + WHERE id=?; + """, + (json_serialized, model_key), + ) + if self._cursor.rowcount == 0: + raise UnknownModelException("model not found") + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.Error as e: + self._db.conn.rollback() + raise e + + return self.get_metadata(model_key) + + def search_by_tag(self, tags: Set[str]) -> Set[str]: + """Return the keys of models containing all of the listed tags.""" + with self._db.lock: + try: + matches: Set[str] = set() + for tag in tags: + self._cursor.execute( + """--sql + SELECT a.id FROM model_tags AS a, + tags AS b + WHERE a.tag_id=b.tag_id + AND b.tag_text=?; + """, + (tag,), + ) + model_keys = {x[0] for x in self._cursor.fetchall()} + matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys + except sqlite3.Error as e: + raise e + return matches + + def search_by_author(self, author: str) -> Set[str]: + """Return the keys of models authored by the indicated author.""" + self._cursor.execute( + """--sql + SELECT id FROM model_metadata + WHERE author=?; + """, + (author,), + ) + return {x[0] for x in self._cursor.fetchall()} + + def _update_tags(self, model_key: str, tags: Set[str]) -> None: + """Update tags for the model referenced by model_key.""" + # remove previous tags from this model + self._cursor.execute( + """--sql + DELETE FROM model_tags + WHERE id=?; + """, + (model_key,), + ) + + for tag in tags: + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO tags ( + tag_text + ) + VALUES (?); + """, + (tag,), + ) + self._cursor.execute( + """--sql + SELECT tag_id + FROM tags + WHERE tag_text = ? + LIMIT 1; + """, + (tag,), + ) + tag_id = self._cursor.fetchone()[0] + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO model_tags ( + id, + tag_id + ) + VALUES (?,?); + """, + (model_key, tag_id), + ) + + def _enable_foreign_key_constraints(self) -> None: + self._cursor.execute("PRAGMA foreign_keys = ON;")