From 44c40d7d1a2dd2695016bcdd613c128032a6ca01 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Mon, 4 Mar 2024 21:38:21 +1100
Subject: [PATCH] refactor(mm): remove unused metadata logic, fix tests

- Metadata is merged with the config. We can simplify the MM substantially and
  remove the handling for metadata.
- Per discussion, we don't have an ETA for frontend implementation of tags, and
  with the realization that the tags from CivitAI are largely useless, there's
  no reason to keep tags in the MM right now. When we are ready to implement
  tags on the frontend, we can refer back to the implementation here and use it
  if it supports the design.
- Fix all tests.
---
 invokeai/app/api/dependencies.py              |   4 +-
 invokeai/app/api/routers/model_manager.py     | 136 +----------
 .../model_install/model_install_base.py       |   3 -
 .../model_install/model_install_default.py    |   1 -
 .../app/services/model_metadata/__init__.py   |   9 -
 .../model_metadata/metadata_store_base.py     |  81 -------
 .../model_metadata/metadata_store_sql.py      | 223 ------------------
 .../model_records/model_records_base.py       |  39 +--
 .../model_records/model_records_sql.py        |  62 +----
 .../sqlite_migrator/migrations/migration_7.py |  11 +-
 invokeai/app/util/metadata.py                 |  55 -----
 invokeai/backend/model_manager/config.py      |   8 +-
 .../metadata/fetch/huggingface.py             |   2 +-
 .../model_manager/metadata/metadata_store.py  | 221 -----------------
 invokeai/backend/model_manager/probe.py       |   3 +-
 .../model_records/test_model_records_sql.py   | 165 ++++++-------
 .../model_manager/model_manager_fixtures.py   | 132 ++++++-----
 .../model_metadata/test_model_metadata.py     | 202 ----------------
 18 files changed, 170 insertions(+), 1187 deletions(-)
 delete mode 100644 invokeai/app/services/model_metadata/__init__.py
 delete mode 100644 invokeai/app/services/model_metadata/metadata_store_base.py
 delete mode 100644 invokeai/app/services/model_metadata/metadata_store_sql.py
 delete mode 100644 invokeai/app/util/metadata.py
 delete mode 100644 invokeai/backend/model_manager/metadata/metadata_store.py
 delete mode 100644 tests/backend/model_manager/model_metadata/test_model_metadata.py

diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 95407291ec..c26122cc77 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -26,7 +26,6 @@ from ..services.invocation_services import InvocationServices
 from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
 from ..services.invoker import Invoker
 from ..services.model_manager.model_manager_default import ModelManagerService
-from ..services.model_metadata import ModelMetadataStoreSQL
 from ..services.model_records import ModelRecordServiceSQL
 from ..services.names.names_default import SimpleNameService
 from ..services.session_processor.session_processor_default import DefaultSessionProcessor
@@ -93,10 +92,9 @@ class ApiDependencies:
             ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
         )
         download_queue_service = DownloadQueueService(event_bus=events)
-        model_metadata_service = ModelMetadataStoreSQL(db=db)
         model_manager = ModelManagerService.build_model_manager(
             app_config=configuration,
-            model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
+            model_record_service=ModelRecordServiceSQL(db=db),
             download_queue=download_queue_service,
             events=events,
         )
diff --git a/invokeai/app/api/routers/model_manager.py
b/invokeai/app/api/routers/model_manager.py index 04a7a7ad05..5e12a9923f 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -12,7 +12,6 @@ from starlette.exceptions import HTTPException from typing_extensions import Annotated from invokeai.app.services.model_install import ModelInstallJob -from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges from invokeai.app.services.model_records import ( DuplicateModelException, InvalidModelException, @@ -30,8 +29,6 @@ from invokeai.backend.model_manager.config import ( SubModelType, ) from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata -from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata from invokeai.backend.model_manager.search import ModelSearch from ..dependencies import ApiDependencies @@ -90,44 +87,6 @@ example_model_input = { "variant": "normal", } -example_model_metadata = { - "name": "ip_adapter_sd_image_encoder", - "author": "InvokeAI", - "tags": [ - "transformers", - "safetensors", - "clip_vision_model", - "endpoints_compatible", - "region:us", - "has_space", - "license:apache-2.0", - ], - "files": [ - { - "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md", - "path": "ip_adapter_sd_image_encoder/README.md", - "size": 628, - "sha256": None, - }, - { - "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json", - "path": "ip_adapter_sd_image_encoder/config.json", - "size": 560, - "sha256": None, - }, - { - "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors", - "path": "ip_adapter_sd_image_encoder/model.safetensors", - "size": 2528373448, - "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030", - }, - ], - "type": "huggingface", - "id": "InvokeAI/ip_adapter_sd_image_encoder", - "tag_dict": {"license": "apache-2.0"}, - "last_modified": "2023-09-23T17:33:25Z", -} - ############################################################################## # ROUTES ############################################################################## @@ -219,79 +178,6 @@ async def list_model_summary( return results -@model_manager_router.get( - "/i/{key}/metadata", - operation_id="get_model_metadata", - responses={ - 200: { - "description": "The model metadata was retrieved successfully", - "content": {"application/json": {"example": example_model_metadata}}, - }, - 400: {"description": "Bad request"}, - }, -) -async def get_model_metadata( - key: str = Path(description="Key of the model repo metadata to fetch."), -) -> Optional[AnyModelRepoMetadata]: - """Get a model metadata object.""" - record_store = ApiDependencies.invoker.services.model_manager.store - result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) - - return result - - -@model_manager_router.patch( - "/i/{key}/metadata", - operation_id="update_model_metadata", - responses={ - 201: { - "description": "The model metadata was updated successfully", - "content": {"application/json": {"example": example_model_metadata}}, - }, - 400: {"description": "Bad request"}, - }, -) -async def update_model_metadata( - key: str = Path(description="Key of the model repo metadata to fetch."), - changes: ModelMetadataChanges = Body(description="The changes"), -) -> Optional[AnyModelRepoMetadata]: - """Updates or creates a model metadata 
object.""" - record_store = ApiDependencies.invoker.services.model_manager.store - metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store - - try: - original_metadata = record_store.get_metadata(key) - if original_metadata: - if changes.default_settings: - original_metadata.default_settings = changes.default_settings - - metadata_store.update_metadata(key, original_metadata) - else: - metadata_store.add_metadata( - key, BaseMetadata(name="", author="", default_settings=changes.default_settings) - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"An error occurred while updating the model metadata: {e}", - ) - - result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) - - return result - - -@model_manager_router.get( - "/tags", - operation_id="list_tags", -) -async def list_tags() -> Set[str]: - """Get a unique set of all the model tags.""" - record_store = ApiDependencies.invoker.services.model_manager.store - result: Set[str] = record_store.list_tags() - return result - - class FoundModel(BaseModel): path: str = Field(description="Path to the model") is_installed: bool = Field(description="Whether or not the model is already installed") @@ -361,19 +247,6 @@ async def scan_for_models( return scan_results -@model_manager_router.get( - "/tags/search", - operation_id="search_by_metadata_tags", -) -async def search_by_metadata_tags( - tags: Set[str] = Query(default=None, description="Tags to search for"), -) -> ModelsList: - """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_manager.store - results = record_store.search_by_metadata_tag(tags) - return ModelsList(models=results) - - @model_manager_router.patch( "/i/{key}", operation_id="update_model_record", @@ -562,9 +435,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: * "cancelled" -- Job was cancelled before completion. Once completed, information about the model such as its size, base - model, type, and metadata can be retrieved from the `config_out` - field. For multi-file models such as diffusers, information on individual files - can be retrieved from `download_parts`. + model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers, + information on individual files can be retrieved from `download_parts`. See the example and schema below for more information. 
""" @@ -708,10 +580,6 @@ async def convert_model( logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - # get the original metadata - if orig_metadata := store.get_metadata(key): - store.metadata_store.add_metadata(new_key, orig_metadata) - # delete the original safetensors file installer.delete(key) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index f85d4ae9ea..81860f5ad7 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -21,8 +21,6 @@ from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager.config import ModelSourceType from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata -from ..model_metadata import ModelMetadataStoreBase - class InstallStatus(str, Enum): """State of an install job running in the background.""" @@ -268,7 +266,6 @@ class ModelInstallServiceBase(ABC): app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: ModelMetadataStoreBase, event_bus: Optional["EventServiceBase"] = None, ): """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index eaea5e5ff4..e12a499648 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -94,7 +94,6 @@ class ModelInstallService(ModelInstallServiceBase): self._running = False self._session = session self._next_job_id = 0 - self._metadata_store = record_store.metadata_store # for convenience @property def app_config(self) -> InvokeAIAppConfig: # noqa D102 diff --git a/invokeai/app/services/model_metadata/__init__.py b/invokeai/app/services/model_metadata/__init__.py deleted file mode 100644 index 981c96b709..0000000000 --- a/invokeai/app/services/model_metadata/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Init file for ModelMetadataStoreService module.""" - -from .metadata_store_base import ModelMetadataStoreBase -from .metadata_store_sql import ModelMetadataStoreSQL - -__all__ = [ - "ModelMetadataStoreBase", - "ModelMetadataStoreSQL", -] diff --git a/invokeai/app/services/model_metadata/metadata_store_base.py b/invokeai/app/services/model_metadata/metadata_store_base.py deleted file mode 100644 index 882575a4bf..0000000000 --- a/invokeai/app/services/model_metadata/metadata_store_base.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -Storage for Model Metadata -""" - -from abc import ABC, abstractmethod -from typing import List, Optional, Set, Tuple - -from pydantic import Field - -from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata -from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings - - -class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"): - """A set of changes to apply to model metadata. 
- Only limited changes are valid: - - `default_settings`: the user-configured default settings for this model - """ - - default_settings: Optional[ModelDefaultSettings] = Field( - default=None, description="The user-configured default settings for this model" - ) - """The user-configured default settings for this model""" - - -class ModelMetadataStoreBase(ABC): - """Store, search and fetch model metadata retrieved from remote repositories.""" - - @abstractmethod - def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: - """ - Add a block of repo metadata to a model record. - - The model record config must already exist in the database with the - same key. Otherwise a FOREIGN KEY constraint exception will be raised. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to store - """ - - @abstractmethod - def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: - """Retrieve the ModelRepoMetadata corresponding to model key.""" - - @abstractmethod - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata - """Dump out all the metadata.""" - - @abstractmethod - def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: - """ - Update metadata corresponding to the model with the indicated key. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to update - """ - - @abstractmethod - def list_tags(self) -> Set[str]: - """Return all tags in the tags table.""" - - @abstractmethod - def search_by_tag(self, tags: Set[str]) -> Set[str]: - """Return the keys of models containing all of the listed tags.""" - - @abstractmethod - def search_by_author(self, author: str) -> Set[str]: - """Return the keys of models authored by the indicated author.""" - - @abstractmethod - def search_by_name(self, name: str) -> Set[str]: - """ - Return the keys of models with the indicated name. - - Note that this is the name of the model given to it by - the remote source. The user may have changed the local - name. The local name will be located in the model config - record object. - """ diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py deleted file mode 100644 index 4f8170448f..0000000000 --- a/invokeai/app/services/model_metadata/metadata_store_sql.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -SQL Storage for Model Metadata -""" - -import sqlite3 -from typing import List, Optional, Set, Tuple - -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException -from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase - -from .metadata_store_base import ModelMetadataStoreBase - - -class ModelMetadataStoreSQL(ModelMetadataStoreBase): - """Store, search and fetch model metadata retrieved from remote repositories.""" - - def __init__(self, db: SqliteDatabase): - """ - Initialize a new object from preexisting sqlite3 connection and threading lock objects. 
- - :param conn: sqlite3 connection object - :param lock: threading Lock object - """ - super().__init__() - self._db = db - self._cursor = self._db.conn.cursor() - - def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: - """ - Add a block of repo metadata to a model record. - - The model record config must already exist in the database with the - same key. Otherwise a FOREIGN KEY constraint exception will be raised. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to store - """ - json_serialized = metadata.model_dump_json() - with self._db.lock: - try: - self._cursor.execute( - """--sql - INSERT INTO model_metadata( - id, - metadata - ) - VALUES (?,?); - """, - ( - model_key, - json_serialized, - ), - ) - self._update_tags(model_key, metadata.tags) - self._db.conn.commit() - except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table - self._db.conn.rollback() - raise UnknownMetadataException from excp - except sqlite3.Error as excp: - self._db.conn.rollback() - raise excp - - def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: - """Retrieve the ModelRepoMetadata corresponding to model key.""" - with self._db.lock: - self._cursor.execute( - """--sql - SELECT metadata FROM model_metadata - WHERE id=?; - """, - (model_key,), - ) - rows = self._cursor.fetchone() - if not rows: - raise UnknownMetadataException("model metadata not found") - return ModelMetadataFetchBase.from_json(rows[0]) - - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata - """Dump out all the metadata.""" - with self._db.lock: - self._cursor.execute( - """--sql - SELECT id,metadata FROM model_metadata; - """, - (), - ) - rows = self._cursor.fetchall() - return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows] - - def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: - """ - Update metadata corresponding to the model with the indicated key. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to update - """ - json_serialized = metadata.model_dump_json() # turn it into a json string. - with self._db.lock: - try: - self._cursor.execute( - """--sql - UPDATE model_metadata - SET - metadata=? 
- WHERE id=?; - """, - (json_serialized, model_key), - ) - if self._cursor.rowcount == 0: - raise UnknownMetadataException("model metadata not found") - self._update_tags(model_key, metadata.tags) - self._db.conn.commit() - except sqlite3.Error as e: - self._db.conn.rollback() - raise e - - return self.get_metadata(model_key) - - def list_tags(self) -> Set[str]: - """Return all tags in the tags table.""" - self._cursor.execute( - """--sql - select tag_text from tags; - """ - ) - return {x[0] for x in self._cursor.fetchall()} - - def search_by_tag(self, tags: Set[str]) -> Set[str]: - """Return the keys of models containing all of the listed tags.""" - with self._db.lock: - try: - matches: Optional[Set[str]] = None - for tag in tags: - self._cursor.execute( - """--sql - SELECT a.model_id FROM model_tags AS a, - tags AS b - WHERE a.tag_id=b.tag_id - AND b.tag_text=?; - """, - (tag,), - ) - model_keys = {x[0] for x in self._cursor.fetchall()} - if matches is None: - matches = model_keys - matches = matches.intersection(model_keys) - except sqlite3.Error as e: - raise e - return matches if matches else set() - - def search_by_author(self, author: str) -> Set[str]: - """Return the keys of models authored by the indicated author.""" - self._cursor.execute( - """--sql - SELECT id FROM model_metadata - WHERE author=?; - """, - (author,), - ) - return {x[0] for x in self._cursor.fetchall()} - - def search_by_name(self, name: str) -> Set[str]: - """ - Return the keys of models with the indicated name. - - Note that this is the name of the model given to it by - the remote source. The user may have changed the local - name. The local name will be located in the model config - record object. - """ - self._cursor.execute( - """--sql - SELECT id FROM model_metadata - WHERE name=?; - """, - (name,), - ) - return {x[0] for x in self._cursor.fetchall()} - - def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None: - """Update tags for the model referenced by model_key.""" - if tags: - # remove previous tags from this model - self._cursor.execute( - """--sql - DELETE FROM model_tags - WHERE model_id=?; - """, - (model_key,), - ) - - for tag in tags: - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO tags ( - tag_text - ) - VALUES (?); - """, - (tag,), - ) - self._cursor.execute( - """--sql - SELECT tag_id - FROM tags - WHERE tag_text = ? - LIMIT 1; - """, - (tag,), - ) - tag_id = self._cursor.fetchone()[0] - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO model_tags ( - model_id, - tag_id - ) - VALUES (?,?); - """, - (model_key, tag_id), - ) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 9c463ebb45..f376ae3878 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -6,7 +6,7 @@ Abstract base class for storing and retrieving model configuration records. 
from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Union from pydantic import BaseModel, Field @@ -17,9 +17,6 @@ from invokeai.backend.model_manager import ( ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata - -from ..model_metadata import ModelMetadataStoreBase class DuplicateModelException(Exception): @@ -109,40 +106,6 @@ class ModelRecordServiceBase(ABC): """ pass - @property - @abstractmethod - def metadata_store(self) -> ModelMetadataStoreBase: - """Return a ModelMetadataStore initialized on the same database.""" - pass - - @abstractmethod - def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: - """ - Retrieve metadata (if any) from when model was downloaded from a repo. - - :param key: Model key - """ - pass - - @abstractmethod - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: - """List metadata for all models that have it.""" - pass - - @abstractmethod - def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]: - """ - Search model metadata for ones with all listed tags and return their corresponding configs. - - :param tags: Set of tags to search for. All tags must be present. - """ - pass - - @abstractmethod - def list_tags(self) -> Set[str]: - """Return a unique set of all the model tags in the metadata database.""" - pass - @abstractmethod def list_models( self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 35ddc75567..ce4467d833 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -43,7 +43,7 @@ import json import sqlite3 from math import ceil from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Union from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager.config import ( @@ -53,9 +53,7 @@ from invokeai.backend.model_manager.config import ( ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException -from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( DuplicateModelException, @@ -69,7 +67,7 @@ from .model_records_base import ( class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase): + def __init__(self, db: SqliteDatabase): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. @@ -78,7 +76,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): super().__init__() self._db = db self._cursor = db.conn.cursor() - self._metadata_store = metadata_store @property def db(self) -> SqliteDatabase: @@ -242,9 +239,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): If none of the optional filters are passed, will return all models in the database. 
""" - results = [] - where_clause = [] - bindings = [] + where_clause: list[str] = [] + bindings: list[str] = [] if model_name: where_clause.append("name=?") bindings.append(model_name) @@ -302,55 +298,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ] return results - @property - def metadata_store(self) -> ModelMetadataStoreBase: - """Return a ModelMetadataStore initialized on the same database.""" - return self._metadata_store - - def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: - """ - Retrieve metadata (if any) from when model was downloaded from a repo. - - :param key: Model key - """ - store = self.metadata_store - try: - metadata = store.get_metadata(key) - return metadata - except UnknownMetadataException: - return None - - def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]: - """ - Search model metadata for ones with all listed tags and return their corresponding configs. - - :param tags: Set of tags to search for. All tags must be present. - """ - store = ModelMetadataStoreSQL(self._db) - keys = store.search_by_tag(tags) - return [self.get_model(x) for x in keys] - - def list_tags(self) -> Set[str]: - """Return a unique set of all the model tags in the metadata database.""" - store = ModelMetadataStoreSQL(self._db) - return store.list_tags() - - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: - """List metadata for all models that have it.""" - store = ModelMetadataStoreSQL(self._db) - return store.list_all_metadata() - def list_models( self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" assert isinstance(order_by, ModelRecordOrderBy) ordering = { - ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name", - ModelRecordOrderBy.Type: "a.type", - ModelRecordOrderBy.Base: "a.base", - ModelRecordOrderBy.Name: "a.name", - ModelRecordOrderBy.Format: "a.format", + ModelRecordOrderBy.Default: "type, base, format, name", + ModelRecordOrderBy.Type: "type", + ModelRecordOrderBy.Base: "base", + ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Format: "format", } # Lock so that the database isn't updated while we're doing the two queries. 
@@ -364,7 +322,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ) total = int(self._cursor.fetchone()[0]) - # query2: fetch key fields from the join of models and model_metadata + # query2: fetch key fields self._cursor.execute( f"""--sql SELECT config diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py index 60a58c1a38..bb33609c27 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py @@ -6,6 +6,15 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import class Migration7Callback: def __call__(self, cursor: sqlite3.Cursor) -> None: self._create_models_table(cursor) + self._drop_old_models_tables(cursor) + + def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None: + """Drops the old model_records, model_metadata, model_tags and tags tables.""" + + tables = ["model_records", "model_metadata", "model_tags", "tags"] + + for table in tables: + cursor.execute(f"DROP TABLE IF EXISTS {table};") def _create_models_table(self, cursor: sqlite3.Cursor) -> None: """Creates the v4.0.0 models table.""" @@ -67,7 +76,7 @@ def build_migration_7() -> Migration: This migration does the following: - Adds the new models table - - TODO(MM2): Drops the old model_records, model_metadata, model_tags and tags tables. + - Drops the old model_records, model_metadata, model_tags and tags tables. - TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?). """ migration_7 = Migration( diff --git a/invokeai/app/util/metadata.py b/invokeai/app/util/metadata.py deleted file mode 100644 index 52f9750e4f..0000000000 --- a/invokeai/app/util/metadata.py +++ /dev/null @@ -1,55 +0,0 @@ -import json -from typing import Optional - -from pydantic import ValidationError - -from invokeai.app.services.shared.graph import Edge - - -def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]: - """ - Parses raw session string, returning a dict of the graph. - - Only the general graph shape is validated; none of the fields are validated. - - Any `metadata_accumulator` nodes and edges are removed. - - Any validation failure will return None. 
- """ - - graph = json.loads(session_raw).get("graph", None) - - # sanity check make sure the graph is at least reasonably shaped - if ( - not isinstance(graph, dict) - or "nodes" not in graph - or not isinstance(graph["nodes"], dict) - or "edges" not in graph - or not isinstance(graph["edges"], list) - ): - # something has gone terribly awry, return an empty dict - return None - - try: - # delete the `metadata_accumulator` node - del graph["nodes"]["metadata_accumulator"] - except KeyError: - # no accumulator node, all good - pass - - # delete any edges to or from it - for i, edge in enumerate(graph["edges"]): - try: - # try to parse the edge - Edge(**edge) - except ValidationError: - # something has gone terribly awry, return an empty dict - return None - - if ( - edge["source"]["node_id"] == "metadata_accumulator" - or edge["destination"]["node_id"] == "metadata_accumulator" - ): - del graph["edges"][i] - - return graph diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 676a2a0250..0386eab8ca 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -29,6 +29,8 @@ from diffusers.models.modeling_utils import ModelMixin from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict +from invokeai.app.util.misc import uuid_string + from ..raw_model import RawModel # ModelMixin is the base class for all diffusers and transformers models @@ -132,7 +134,7 @@ class ModelSourceType(str, Enum): class ModelConfigBase(BaseModel): """Base class for model configuration information.""" - key: str = Field(description="A unique key for this model.") + key: str = Field(description="A unique key for this model.", default_factory=uuid_string) hash: str = Field(description="The hash of the model file(s).") path: str = Field( description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory." 
@@ -142,7 +144,9 @@ class ModelConfigBase(BaseModel): description: Optional[str] = Field(description="Model description", default=None) source: str = Field(description="The original source of the model (path, URL or repo_id).") source_type: ModelSourceType = Field(description="The type of source") - source_api_response: Optional[str] = Field(description="The original API response from the source, as stringified JSON.", default=None) + source_api_response: Optional[str] = Field( + description="The original API response from the source, as stringified JSON.", default=None + ) trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None) model_config = ConfigDict(use_enum_values=False, validate_assignment=True) diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 9d2a52603d..5090f29148 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -91,7 +91,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): ) return HuggingFaceMetadata( - id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__) + id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str) ) def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: diff --git a/invokeai/backend/model_manager/metadata/metadata_store.py b/invokeai/backend/model_manager/metadata/metadata_store.py deleted file mode 100644 index 684409fc3b..0000000000 --- a/invokeai/backend/model_manager/metadata/metadata_store.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -SQL Storage for Model Metadata -""" - -import sqlite3 -from typing import List, Optional, Set, Tuple - -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase - -from .fetch import ModelMetadataFetchBase -from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException - - -class ModelMetadataStore: - """Store, search and fetch model metadata retrieved from remote repositories.""" - - def __init__(self, db: SqliteDatabase): - """ - Initialize a new object from preexisting sqlite3 connection and threading lock objects. - - :param conn: sqlite3 connection object - :param lock: threading Lock object - """ - super().__init__() - self._db = db - self._cursor = self._db.conn.cursor() - - def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: - """ - Add a block of repo metadata to a model record. - - The model record config must already exist in the database with the - same key. Otherwise a FOREIGN KEY constraint exception will be raised. 
- - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to store - """ - json_serialized = metadata.model_dump_json() - with self._db.lock: - try: - self._cursor.execute( - """--sql - INSERT INTO model_metadata( - id, - metadata - ) - VALUES (?,?); - """, - ( - model_key, - json_serialized, - ), - ) - self._update_tags(model_key, metadata.tags) - self._db.conn.commit() - except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table - self._db.conn.rollback() - raise UnknownMetadataException from excp - except sqlite3.Error as excp: - self._db.conn.rollback() - raise excp - - def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: - """Retrieve the ModelRepoMetadata corresponding to model key.""" - with self._db.lock: - self._cursor.execute( - """--sql - SELECT metadata FROM model_metadata - WHERE id=?; - """, - (model_key,), - ) - rows = self._cursor.fetchone() - if not rows: - raise UnknownMetadataException("model metadata not found") - return ModelMetadataFetchBase.from_json(rows[0]) - - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata - """Dump out all the metadata.""" - with self._db.lock: - self._cursor.execute( - """--sql - SELECT id,metadata FROM model_metadata; - """, - (), - ) - rows = self._cursor.fetchall() - return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows] - - def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: - """ - Update metadata corresponding to the model with the indicated key. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to update - """ - json_serialized = metadata.model_dump_json() # turn it into a json string. - with self._db.lock: - try: - self._cursor.execute( - """--sql - UPDATE model_metadata - SET - metadata=? - WHERE id=?; - """, - (json_serialized, model_key), - ) - if self._cursor.rowcount == 0: - raise UnknownMetadataException("model metadata not found") - self._update_tags(model_key, metadata.tags) - self._db.conn.commit() - except sqlite3.Error as e: - self._db.conn.rollback() - raise e - - return self.get_metadata(model_key) - - def list_tags(self) -> Set[str]: - """Return all tags in the tags table.""" - self._cursor.execute( - """--sql - select tag_text from tags; - """ - ) - return {x[0] for x in self._cursor.fetchall()} - - def search_by_tag(self, tags: Set[str]) -> Set[str]: - """Return the keys of models containing all of the listed tags.""" - with self._db.lock: - try: - matches: Optional[Set[str]] = None - for tag in tags: - self._cursor.execute( - """--sql - SELECT a.model_id FROM model_tags AS a, - tags AS b - WHERE a.tag_id=b.tag_id - AND b.tag_text=?; - """, - (tag,), - ) - model_keys = {x[0] for x in self._cursor.fetchall()} - if matches is None: - matches = model_keys - matches = matches.intersection(model_keys) - except sqlite3.Error as e: - raise e - return matches if matches else set() - - def search_by_author(self, author: str) -> Set[str]: - """Return the keys of models authored by the indicated author.""" - self._cursor.execute( - """--sql - SELECT id FROM model_metadata - WHERE author=?; - """, - (author,), - ) - return {x[0] for x in self._cursor.fetchall()} - - def search_by_name(self, name: str) -> Set[str]: - """ - Return the keys of models with the indicated name. - - Note that this is the name of the model given to it by - the remote source. 
The user may have changed the local - name. The local name will be located in the model config - record object. - """ - self._cursor.execute( - """--sql - SELECT id FROM model_metadata - WHERE name=?; - """, - (name,), - ) - return {x[0] for x in self._cursor.fetchall()} - - def _update_tags(self, model_key: str, tags: Set[str]) -> None: - """Update tags for the model referenced by model_key.""" - # remove previous tags from this model - self._cursor.execute( - """--sql - DELETE FROM model_tags - WHERE model_id=?; - """, - (model_key,), - ) - - for tag in tags: - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO tags ( - tag_text - ) - VALUES (?); - """, - (tag,), - ) - self._cursor.execute( - """--sql - SELECT tag_id - FROM tags - WHERE tag_text = ? - LIMIT 1; - """, - (tag,), - ) - tag_id = self._cursor.fetchone()[0] - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO model_tags ( - model_id, - tag_id - ) - VALUES (?,?); - """, - (model_key, tag_id), - ) diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index c837993888..774959f7ef 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -18,6 +18,7 @@ from .config import ( ModelConfigFactory, ModelFormat, ModelRepoVariant, + ModelSourceType, ModelType, ModelVariantType, SchedulerPredictionType, @@ -150,7 +151,7 @@ class ModelProbe(object): probe = probe_class(model_path) - fields["source_type"] = fields.get("source_type") + fields["source_type"] = fields.get("source_type") or ModelSourceType.Path fields["source"] = fields.get("source") or model_path.as_posix() fields["key"] = fields.get("key", uuid_string()) fields["path"] = model_path.as_posix() diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 8263d77d79..1552d4c07d 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -3,29 +3,26 @@ Test the refactored model config classes. 
""" from hashlib import sha256 -from typing import Any +from typing import Any, Optional import pytest from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ( DuplicateModelException, - ModelRecordOrderBy, ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException, ) from invokeai.backend.model_manager.config import ( BaseModelType, - MainCheckpointConfig, MainDiffusersConfig, ModelFormat, + ModelSourceType, ModelType, TextualInversionFileConfig, VaeDiffusersConfig, ) -from invokeai.backend.model_manager.metadata import BaseMetadata from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.fixtures.sqlite_database import create_mock_sqlite_database @@ -38,11 +35,13 @@ def store( config = InvokeAIAppConfig(root=datadir) logger = InvokeAILogger.get_logger(config=config) db = create_mock_sqlite_database(config, logger) - return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + return ModelRecordServiceSQL(db) -def example_config() -> TextualInversionFileConfig: - return TextualInversionFileConfig( +def example_config(key: Optional[str] = None) -> TextualInversionFileConfig: + config = TextualInversionFileConfig( + source="test/source/", + source_type=ModelSourceType.Path, path="/tmp/pokemon.bin", name="old name", base=BaseModelType.StableDiffusion1, @@ -50,59 +49,45 @@ def example_config() -> TextualInversionFileConfig: format=ModelFormat.EmbeddingFile, hash="ABC123", ) + if key is not None: + config.key = key + return config def test_type(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) config1 = store.get_model("key1") - assert type(config1) == TextualInversionFileConfig + assert isinstance(config1, TextualInversionFileConfig) -def test_add(store: ModelRecordServiceBase): - raw = { - "path": "/tmp/foo.ckpt", - "name": "model1", - "base": BaseModelType.StableDiffusion1, - "type": "main", - "config_path": "/tmp/foo.yaml", - "variant": "normal", - "format": "checkpoint", - "original_hash": "111222333444", - } - store.add_model("key1", raw) - config1 = store.get_model("key1") - assert config1 is not None - assert type(config1) == MainCheckpointConfig - assert config1.base == BaseModelType.StableDiffusion1 - assert config1.name == "model1" - assert config1.hash == "111222333444" - - -def test_dup(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", example_config()) +def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase): + # Models have a uniqueness constraint by their name, base and type + config1 = example_config("key1") + config2 = config1.model_copy(deep=True) + config2.key = "key2" + store.add_model(config1) with pytest.raises(DuplicateModelException): - store.add_model("key1", config) + store.add_model(config1) with pytest.raises(DuplicateModelException): - store.add_model("key2", config) + store.add_model(config2) def test_update(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) config = store.get_model("key1") assert config.name == "old name" config.name = "new name" - store.update_model("key1", config) + store.update_model(config.key, config) new_config = store.get_model("key1") assert new_config.name == "new 
name" def test_rename(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) config = store.get_model("key1") assert config.name == "old name" @@ -112,15 +97,15 @@ def test_rename(store: ModelRecordServiceBase): def test_unknown_key(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) with pytest.raises(UnknownModelException): store.update_model("unknown_key", config) def test_delete(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) config = store.get_model("key1") store.del_model("key1") with pytest.raises(UnknownModelException): @@ -128,36 +113,45 @@ def test_delete(store: ModelRecordServiceBase): def test_exists(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) assert store.exists("key1") assert not store.exists("key2") def test_filter(store: ModelRecordServiceBase): config1 = MainDiffusersConfig( + key="config1", path="/tmp/config1", name="config1", base=BaseModelType.StableDiffusion1, type=ModelType.Main, hash="CONFIG1HASH", + source="test/source", + source_type=ModelSourceType.Path, ) config2 = MainDiffusersConfig( + key="config2", path="/tmp/config2", name="config2", base=BaseModelType.StableDiffusion1, type=ModelType.Main, hash="CONFIG2HASH", + source="test/source", + source_type=ModelSourceType.Path, ) config3 = VaeDiffusersConfig( + key="config3", path="/tmp/config3", name="config3", base=BaseModelType("sd-2"), type=ModelType.Vae, hash="CONFIG3HASH", + source="test/source", + source_type=ModelSourceType.Path, ) for c in config1, config2, config3: - store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c) + store.add_model(c) matches = store.search_by_attr(model_type=ModelType.Main) assert len(matches) == 2 assert matches[0].name in {"config1", "config2"} @@ -165,7 +159,7 @@ def test_filter(store: ModelRecordServiceBase): matches = store.search_by_attr(model_type=ModelType.Vae) assert len(matches) == 1 assert matches[0].name == "config3" - assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest() + assert matches[0].key == "config3" assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back matches = store.search_by_hash("CONFIG1HASH") @@ -183,6 +177,8 @@ def test_unique(store: ModelRecordServiceBase): type=ModelType.Main, name="nonuniquename", hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config2 = MainDiffusersConfig( path="/tmp/config2", @@ -190,6 +186,8 @@ def test_unique(store: ModelRecordServiceBase): type=ModelType.Main, name="nonuniquename", hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config3 = VaeDiffusersConfig( path="/tmp/config3", @@ -197,6 +195,8 @@ def test_unique(store: ModelRecordServiceBase): type=ModelType.Vae, name="nonuniquename", hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config4 = MainDiffusersConfig( path="/tmp/config4", @@ -204,15 +204,19 @@ def test_unique(store: ModelRecordServiceBase): type=ModelType.Main, name="nonuniquename", hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) # config1, config2 and config3 are compatible because they have unique combos # of 
name, type and base for c in config1, config2, config3: - store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) + c.key = sha256(c.path.encode("utf-8")).hexdigest() + store.add_model(c) # config4 clashes with config1 and should raise an integrity error with pytest.raises(DuplicateModelException): - store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4) + config4.key = sha256(config4.path.encode("utf-8")).hexdigest() + store.add_model(config4) def test_filter_2(store: ModelRecordServiceBase): @@ -222,6 +226,8 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType.StableDiffusion1, type=ModelType.Main, hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config2 = MainDiffusersConfig( path="/tmp/config2", @@ -229,6 +235,8 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType.StableDiffusion1, type=ModelType.Main, hash="CONFIG2HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config3 = MainDiffusersConfig( path="/tmp/config3", @@ -236,6 +244,8 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType("sd-2"), type=ModelType.Main, hash="CONFIG3HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config4 = MainDiffusersConfig( path="/tmp/config4", @@ -243,6 +253,8 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType("sdxl"), type=ModelType.Main, hash="CONFIG3HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config5 = VaeDiffusersConfig( path="/tmp/config5", @@ -250,9 +262,11 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType.StableDiffusion1, type=ModelType.Vae, hash="CONFIG3HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) for c in config1, config2, config3, config4, config5: - store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) + store.add_model(c) matches = store.search_by_attr( model_type=ModelType.Main, @@ -272,50 +286,3 @@ def test_filter_2(store: ModelRecordServiceBase): model_name="dup_name1", ) assert len(matches) == 1 - - -def test_summary(mm2_record_store: ModelRecordServiceSQL) -> None: - # The fixture provides us with five configs. 
- for x in range(1, 5): - key = f"test_config_{x}" - name = f"name_{x}" - author = f"author_{x}" - tags = {f"tag{y}" for y in range(1, x)} - mm2_record_store.metadata_store.add_metadata( - model_key=key, metadata=BaseMetadata(name=name, author=author, tags=tags) - ) - # sanity check that the tags sent in all right - assert mm2_record_store.get_metadata("test_config_3").tags == {"tag1", "tag2"} - assert mm2_record_store.get_metadata("test_config_4").tags == {"tag1", "tag2", "tag3"} - - # get summary - summary1 = mm2_record_store.list_models(page=0, per_page=100) - assert summary1.page == 0 - assert summary1.pages == 1 - assert summary1.per_page == 100 - assert summary1.total == 5 - assert len(summary1.items) == 5 - assert summary1.items[0].name == "test5" # lora / sd-1 / diffusers / test5 - - # find test_config_3 - config3 = [x for x in summary1.items if x.key == "test_config_3"][0] - assert config3.description == "This is test 3" - assert config3.tags == {"tag1", "tag2"} - - # find test_config_5 - config5 = [x for x in summary1.items if x.key == "test_config_5"][0] - assert config5.tags == set() - assert config5.description == "" - - # test paging - summary2 = mm2_record_store.list_models(page=1, per_page=2) - assert summary2.page == 1 - assert summary2.per_page == 2 - assert summary2.pages == 3 - assert summary1.items[2].name == summary2.items[0].name - - # test sorting - summary = mm2_record_store.list_models(page=0, per_page=100, order_by=ModelRecordOrderBy.Name) - print(summary.items) - assert summary.items[0].name == "model1" - assert summary.items[-1].name == "test5" diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 4c0226f0cb..112b3765ff 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -18,12 +18,17 @@ from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase -from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( BaseModelType, + LoRADiffusersConfig, + MainCheckpointConfig, + MainDiffusersConfig, ModelFormat, + ModelSourceType, ModelType, + ModelVariantType, + VaeDiffusersConfig, ) from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.util.logging import InvokeAILogger @@ -107,11 +112,6 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> Downloa return download_queue -@pytest.fixture -def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: - return mm2_record_store.metadata_store - - @pytest.fixture def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: ram_cache = ModelCache( @@ -137,7 +137,7 @@ def mm2_installer( logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) events = DummyEventService() - store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + store = ModelRecordServiceSQL(db) installer = ModelInstallService( app_config=mm2_app_config, @@ -160,61 
+160,71 @@ def mm2_installer( def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: logger = InvokeAILogger.get_logger(config=mm2_app_config) db = create_mock_sqlite_database(mm2_app_config, logger) - store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + store = ModelRecordServiceSQL(db) # add five simple config records to the database - raw1 = { - "path": "/tmp/foo1", - "format": ModelFormat("diffusers"), - "name": "test2", - "base": BaseModelType("sd-2"), - "type": ModelType("vae"), - "original_hash": "111222333444", - "source": "stabilityai/sdxl-vae", - } - raw2 = { - "path": "/tmp/foo2.ckpt", - "name": "model1", - "format": ModelFormat("checkpoint"), - "base": BaseModelType("sd-1"), - "type": "main", - "config_path": "/tmp/foo.yaml", - "variant": "normal", - "original_hash": "111222333444", - "source": "https://civitai.com/models/206883/split", - } - raw3 = { - "path": "/tmp/foo3", - "format": ModelFormat("diffusers"), - "name": "test3", - "base": BaseModelType("sdxl"), - "type": ModelType("main"), - "original_hash": "111222333444", - "source": "author3/model3", - "description": "This is test 3", - } - raw4 = { - "path": "/tmp/foo4", - "format": ModelFormat("diffusers"), - "name": "test4", - "base": BaseModelType("sdxl"), - "type": ModelType("lora"), - "original_hash": "111222333444", - "source": "author4/model4", - } - raw5 = { - "path": "/tmp/foo5", - "format": ModelFormat("diffusers"), - "name": "test5", - "base": BaseModelType("sd-1"), - "type": ModelType("lora"), - "original_hash": "111222333444", - "source": "author4/model5", - } - store.add_model("test_config_1", raw1) - store.add_model("test_config_2", raw2) - store.add_model("test_config_3", raw3) - store.add_model("test_config_4", raw4) - store.add_model("test_config_5", raw5) + config1 = VaeDiffusersConfig( + key="test_config_1", + path="/tmp/foo1", + format=ModelFormat.Diffusers, + name="test2", + base=BaseModelType.StableDiffusion2, + type=ModelType.Vae, + hash="111222333444", + source="stabilityai/sdxl-vae", + source_type=ModelSourceType.HFRepoID, + ) + config2 = MainCheckpointConfig( + key="test_config_2", + path="/tmp/foo2.ckpt", + name="model1", + format=ModelFormat.Checkpoint, + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, + config_path="/tmp/foo.yaml", + variant=ModelVariantType.Normal, + hash="111222333444", + source="https://civitai.com/models/206883/split", + source_type=ModelSourceType.CivitAI, + ) + config3 = MainDiffusersConfig( + key="test_config_3", + path="/tmp/foo3", + format=ModelFormat.Diffusers, + name="test3", + base=BaseModelType.StableDiffusionXL, + type=ModelType.Main, + hash="111222333444", + source="author3/model3", + description="This is test 3", + source_type=ModelSourceType.HFRepoID, + ) + config4 = LoRADiffusersConfig( + key="test_config_4", + path="/tmp/foo4", + format=ModelFormat.Diffusers, + name="test4", + base=BaseModelType.StableDiffusionXL, + type=ModelType.Lora, + hash="111222333444", + source="author4/model4", + source_type=ModelSourceType.HFRepoID, + ) + config5 = LoRADiffusersConfig( + key="test_config_5", + path="/tmp/foo5", + format=ModelFormat.Diffusers, + name="test5", + base=BaseModelType.StableDiffusion1, + type=ModelType.Lora, + hash="111222333444", + source="author4/model5", + source_type=ModelSourceType.HFRepoID, + ) + store.add_model(config1) + store.add_model(config2) + store.add_model(config3) + store.add_model(config4) + store.add_model(config5) return store diff --git 
deleted file mode 100644
index 0c6b1a6e93..0000000000
--- a/tests/backend/model_manager/model_metadata/test_model_metadata.py
+++ /dev/null
@@ -1,202 +0,0 @@
-"""
-Test model metadata fetching and storage.
-"""
-
-import datetime
-from pathlib import Path
-
-import pytest
-from pydantic.networks import HttpUrl
-from requests.sessions import Session
-
-from invokeai.app.services.model_metadata import ModelMetadataStoreBase
-from invokeai.backend.model_manager.config import ModelRepoVariant
-from invokeai.backend.model_manager.metadata import (
-    CivitaiMetadata,
-    CivitaiMetadataFetch,
-    CommercialUsage,
-    HuggingFaceMetadata,
-    HuggingFaceMetadataFetch,
-    UnknownMetadataException,
-)
-from invokeai.backend.model_manager.util import select_hf_files
-from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
-
-
-def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
-    tags = {"text-to-image", "diffusers"}
-    input_metadata = HuggingFaceMetadata(
-        name="sdxl-vae",
-        author="stabilityai",
-        tags=tags,
-        id="stabilityai/sdxl-vae",
-        tag_dict={"license": "other"},
-        last_modified=datetime.datetime.now(),
-    )
-    mm2_metadata_store.add_metadata("test_config_1", input_metadata)
-    output_metadata = mm2_metadata_store.get_metadata("test_config_1")
-    assert input_metadata == output_metadata
-    with pytest.raises(UnknownMetadataException):
-        mm2_metadata_store.add_metadata("unknown_key", input_metadata)
-    assert mm2_metadata_store.list_tags() == tags
-
-
-def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
-    input_metadata = HuggingFaceMetadata(
-        name="sdxl-vae",
-        author="stabilityai",
-        tags={"text-to-image", "diffusers"},
-        id="stabilityai/sdxl-vae",
-        tag_dict={"license": "other"},
-        last_modified=datetime.datetime.now(),
-    )
-    mm2_metadata_store.add_metadata("test_config_1", input_metadata)
-    input_metadata.name = "new-name"
-    mm2_metadata_store.update_metadata("test_config_1", input_metadata)
-    output_metadata = mm2_metadata_store.get_metadata("test_config_1")
-    assert output_metadata.name == "new-name"
-    assert input_metadata == output_metadata
-
-
-def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
-    metadata1 = HuggingFaceMetadata(
-        name="sdxl-vae",
-        author="stabilityai",
-        tags={"text-to-image", "diffusers"},
-        id="stabilityai/sdxl-vae",
-        tag_dict={"license": "other"},
-        last_modified=datetime.datetime.now(),
-    )
-    metadata2 = HuggingFaceMetadata(
-        name="model2",
-        author="stabilityai",
-        tags={"text-to-image", "diffusers", "community-contributed"},
-        id="author2/model2",
-        tag_dict={"license": "other"},
-        last_modified=datetime.datetime.now(),
-    )
-    metadata3 = HuggingFaceMetadata(
-        name="model3",
-        author="author3",
-        tags={"text-to-image", "checkpoint", "community-contributed"},
-        id="author3/model3",
-        tag_dict={"license": "other"},
-        last_modified=datetime.datetime.now(),
-    )
-    mm2_metadata_store.add_metadata("test_config_1", metadata1)
-    mm2_metadata_store.add_metadata("test_config_2", metadata2)
-    mm2_metadata_store.add_metadata("test_config_3", metadata3)
-
-    matches = mm2_metadata_store.search_by_author("stabilityai")
-    assert len(matches) == 2
-    assert "test_config_1" in matches
-    assert "test_config_2" in matches
-    matches = mm2_metadata_store.search_by_author("Sherlock Holmes")
-    assert not matches
-
-    matches = mm2_metadata_store.search_by_name("model3")
-    assert len(matches) == 1
-    assert "test_config_3" in matches
-
-    matches = mm2_metadata_store.search_by_tag({"text-to-image"})
-    assert len(matches) == 3
-
-    matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"})
-    assert len(matches) == 2
-    assert "test_config_1" in matches
-    assert "test_config_2" in matches
-
-    matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"})
-    assert len(matches) == 1
-    assert "test_config_3" in matches
-
-    # does the tag table update correctly?
-    matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
-    assert not matches
-    assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"}
-    metadata3.tags.add("licensed-for-commercial-use")
-    mm2_metadata_store.update_metadata("test_config_3", metadata3)
-    assert mm2_metadata_store.list_tags() == {
-        "text-to-image",
-        "diffusers",
-        "community-contributed",
-        "checkpoint",
-        "licensed-for-commercial-use",
-    }
-    matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
-    assert len(matches) == 1
-
-
-def test_metadata_civitai_fetch(mm2_session: Session) -> None:
-    fetcher = CivitaiMetadataFetch(mm2_session)
-    metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
-    assert isinstance(metadata, CivitaiMetadata)
-    assert metadata.id == 215485
-    assert metadata.author == "test_author"  # note that this is not the same as the original from Civitai
-    assert metadata.allow_commercial_use  # changed to make sure we are reading locally not remotely
-    assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
-    assert metadata.version_id == 242807
-    assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
-
-
-def test_metadata_hf_fetch(mm2_session: Session) -> None:
-    fetcher = HuggingFaceMetadataFetch(mm2_session)
-    metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
-    assert isinstance(metadata, HuggingFaceMetadata)
-    assert metadata.author == "test_author"  # this is not the same as the original
-    assert metadata.files
-    assert metadata.tags == {
-        "diffusers",
-        "onnx",
-        "safetensors",
-        "text-to-image",
-        "license:other",
-        "has_space",
-        "diffusers:StableDiffusionXLPipeline",
-        "region:us",
-    }
-
-
-def test_metadata_hf_filter(mm2_session: Session) -> None:
-    metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
-    assert isinstance(metadata, HuggingFaceMetadata)
-    files = [x.path for x in metadata.files]
-    fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
-    assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files
-    assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files
-
-    fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32"))
-    assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files
-    assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files
-
-    onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx"))
-    assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files
-    assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files
-
-    default_files = select_hf_files.filter_files(files)
-    assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files
-    assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files
-
-    openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino"))
-    print(openvino_files)
-    assert len(openvino_files) == 0
-
-    flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax"))
-    print(flax_files)
-    assert not flax_files
-
-    metadata = HuggingFaceMetadataFetch(mm2_session).from_url(
-        HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16")
-    )
-    assert isinstance(metadata, HuggingFaceMetadata)
-    files = [x.path for x in metadata.files]
-    filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
-    assert (
-        Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files
-    )  # confirm that default is returned
-    assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files
-
-
-def test_metadata_hf_urls(mm2_session: Session) -> None:
-    metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
-    assert isinstance(metadata, HuggingFaceMetadata)