diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 95407291ec..c26122cc77 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -26,7 +26,6 @@ from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker from ..services.model_manager.model_manager_default import ModelManagerService -from ..services.model_metadata import ModelMetadataStoreSQL from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService from ..services.session_processor.session_processor_default import DefaultSessionProcessor @@ -93,10 +92,9 @@ class ApiDependencies: ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) download_queue_service = DownloadQueueService(event_bus=events) - model_metadata_service = ModelMetadataStoreSQL(db=db) model_manager = ModelManagerService.build_model_manager( app_config=configuration, - model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service), + model_record_service=ModelRecordServiceSQL(db=db), download_queue=download_queue_service, events=events, ) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 04a7a7ad05..5e12a9923f 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -12,7 +12,6 @@ from starlette.exceptions import HTTPException from typing_extensions import Annotated from invokeai.app.services.model_install import ModelInstallJob -from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges from invokeai.app.services.model_records import ( DuplicateModelException, InvalidModelException, @@ -30,8 +29,6 @@ from invokeai.backend.model_manager.config import ( SubModelType, ) from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata -from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata from invokeai.backend.model_manager.search import ModelSearch from ..dependencies import ApiDependencies @@ -90,44 +87,6 @@ example_model_input = { "variant": "normal", } -example_model_metadata = { - "name": "ip_adapter_sd_image_encoder", - "author": "InvokeAI", - "tags": [ - "transformers", - "safetensors", - "clip_vision_model", - "endpoints_compatible", - "region:us", - "has_space", - "license:apache-2.0", - ], - "files": [ - { - "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md", - "path": "ip_adapter_sd_image_encoder/README.md", - "size": 628, - "sha256": None, - }, - { - "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json", - "path": "ip_adapter_sd_image_encoder/config.json", - "size": 560, - "sha256": None, - }, - { - "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors", - "path": "ip_adapter_sd_image_encoder/model.safetensors", - "size": 2528373448, - "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030", - }, - ], - "type": "huggingface", - "id": "InvokeAI/ip_adapter_sd_image_encoder", - "tag_dict": {"license": "apache-2.0"}, - "last_modified": "2023-09-23T17:33:25Z", -} - ############################################################################## # ROUTES ############################################################################## @@ -219,79 +178,6 @@ async def list_model_summary( return results -@model_manager_router.get( - "/i/{key}/metadata", - operation_id="get_model_metadata", - responses={ - 200: { - "description": "The model metadata was retrieved successfully", - "content": {"application/json": {"example": example_model_metadata}}, - }, - 400: {"description": "Bad request"}, - }, -) -async def get_model_metadata( - key: str = Path(description="Key of the model repo metadata to fetch."), -) -> Optional[AnyModelRepoMetadata]: - """Get a model metadata object.""" - record_store = ApiDependencies.invoker.services.model_manager.store - result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) - - return result - - -@model_manager_router.patch( - "/i/{key}/metadata", - operation_id="update_model_metadata", - responses={ - 201: { - "description": "The model metadata was updated successfully", - "content": {"application/json": {"example": example_model_metadata}}, - }, - 400: {"description": "Bad request"}, - }, -) -async def update_model_metadata( - key: str = Path(description="Key of the model repo metadata to fetch."), - changes: ModelMetadataChanges = Body(description="The changes"), -) -> Optional[AnyModelRepoMetadata]: - """Updates or creates a model metadata object.""" - record_store = ApiDependencies.invoker.services.model_manager.store - metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store - - try: - original_metadata = record_store.get_metadata(key) - if original_metadata: - if changes.default_settings: - original_metadata.default_settings = changes.default_settings - - metadata_store.update_metadata(key, original_metadata) - else: - metadata_store.add_metadata( - key, BaseMetadata(name="", author="", default_settings=changes.default_settings) - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"An error occurred while updating the model metadata: {e}", - ) - - result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) - - return result - - -@model_manager_router.get( - "/tags", - operation_id="list_tags", -) -async def list_tags() -> Set[str]: - """Get a unique set of all the model tags.""" - record_store = ApiDependencies.invoker.services.model_manager.store - result: Set[str] = record_store.list_tags() - return result - - class FoundModel(BaseModel): path: str = Field(description="Path to the model") is_installed: bool = Field(description="Whether or not the model is already installed") @@ -361,19 +247,6 @@ async def scan_for_models( return scan_results -@model_manager_router.get( - "/tags/search", - operation_id="search_by_metadata_tags", -) -async def search_by_metadata_tags( - tags: Set[str] = Query(default=None, description="Tags to search for"), -) -> ModelsList: - """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_manager.store - results = record_store.search_by_metadata_tag(tags) - return ModelsList(models=results) - - @model_manager_router.patch( "/i/{key}", operation_id="update_model_record", @@ -562,9 +435,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: * "cancelled" -- Job was cancelled before completion. Once completed, information about the model such as its size, base - model, type, and metadata can be retrieved from the `config_out` - field. For multi-file models such as diffusers, information on individual files - can be retrieved from `download_parts`. + model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers, + information on individual files can be retrieved from `download_parts`. See the example and schema below for more information. """ @@ -708,10 +580,6 @@ async def convert_model( logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - # get the original metadata - if orig_metadata := store.get_metadata(key): - store.metadata_store.add_metadata(new_key, orig_metadata) - # delete the original safetensors file installer.delete(key) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index f85d4ae9ea..81860f5ad7 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -21,8 +21,6 @@ from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager.config import ModelSourceType from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata -from ..model_metadata import ModelMetadataStoreBase - class InstallStatus(str, Enum): """State of an install job running in the background.""" @@ -268,7 +266,6 @@ class ModelInstallServiceBase(ABC): app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: ModelMetadataStoreBase, event_bus: Optional["EventServiceBase"] = None, ): """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index eaea5e5ff4..e12a499648 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -94,7 +94,6 @@ class ModelInstallService(ModelInstallServiceBase): self._running = False self._session = session self._next_job_id = 0 - self._metadata_store = record_store.metadata_store # for convenience @property def app_config(self) -> InvokeAIAppConfig: # noqa D102 diff --git a/invokeai/app/services/model_metadata/__init__.py b/invokeai/app/services/model_metadata/__init__.py deleted file mode 100644 index 981c96b709..0000000000 --- a/invokeai/app/services/model_metadata/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Init file for ModelMetadataStoreService module.""" - -from .metadata_store_base import ModelMetadataStoreBase -from .metadata_store_sql import ModelMetadataStoreSQL - -__all__ = [ - "ModelMetadataStoreBase", - "ModelMetadataStoreSQL", -] diff --git a/invokeai/app/services/model_metadata/metadata_store_base.py b/invokeai/app/services/model_metadata/metadata_store_base.py deleted file mode 100644 index 882575a4bf..0000000000 --- a/invokeai/app/services/model_metadata/metadata_store_base.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -Storage for Model Metadata -""" - -from abc import ABC, abstractmethod -from typing import List, Optional, Set, Tuple - -from pydantic import Field - -from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata -from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings - - -class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"): - """A set of changes to apply to model metadata. - Only limited changes are valid: - - `default_settings`: the user-configured default settings for this model - """ - - default_settings: Optional[ModelDefaultSettings] = Field( - default=None, description="The user-configured default settings for this model" - ) - """The user-configured default settings for this model""" - - -class ModelMetadataStoreBase(ABC): - """Store, search and fetch model metadata retrieved from remote repositories.""" - - @abstractmethod - def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: - """ - Add a block of repo metadata to a model record. - - The model record config must already exist in the database with the - same key. Otherwise a FOREIGN KEY constraint exception will be raised. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to store - """ - - @abstractmethod - def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: - """Retrieve the ModelRepoMetadata corresponding to model key.""" - - @abstractmethod - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata - """Dump out all the metadata.""" - - @abstractmethod - def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: - """ - Update metadata corresponding to the model with the indicated key. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to update - """ - - @abstractmethod - def list_tags(self) -> Set[str]: - """Return all tags in the tags table.""" - - @abstractmethod - def search_by_tag(self, tags: Set[str]) -> Set[str]: - """Return the keys of models containing all of the listed tags.""" - - @abstractmethod - def search_by_author(self, author: str) -> Set[str]: - """Return the keys of models authored by the indicated author.""" - - @abstractmethod - def search_by_name(self, name: str) -> Set[str]: - """ - Return the keys of models with the indicated name. - - Note that this is the name of the model given to it by - the remote source. The user may have changed the local - name. The local name will be located in the model config - record object. - """ diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py deleted file mode 100644 index 4f8170448f..0000000000 --- a/invokeai/app/services/model_metadata/metadata_store_sql.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -SQL Storage for Model Metadata -""" - -import sqlite3 -from typing import List, Optional, Set, Tuple - -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException -from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase - -from .metadata_store_base import ModelMetadataStoreBase - - -class ModelMetadataStoreSQL(ModelMetadataStoreBase): - """Store, search and fetch model metadata retrieved from remote repositories.""" - - def __init__(self, db: SqliteDatabase): - """ - Initialize a new object from preexisting sqlite3 connection and threading lock objects. - - :param conn: sqlite3 connection object - :param lock: threading Lock object - """ - super().__init__() - self._db = db - self._cursor = self._db.conn.cursor() - - def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: - """ - Add a block of repo metadata to a model record. - - The model record config must already exist in the database with the - same key. Otherwise a FOREIGN KEY constraint exception will be raised. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to store - """ - json_serialized = metadata.model_dump_json() - with self._db.lock: - try: - self._cursor.execute( - """--sql - INSERT INTO model_metadata( - id, - metadata - ) - VALUES (?,?); - """, - ( - model_key, - json_serialized, - ), - ) - self._update_tags(model_key, metadata.tags) - self._db.conn.commit() - except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table - self._db.conn.rollback() - raise UnknownMetadataException from excp - except sqlite3.Error as excp: - self._db.conn.rollback() - raise excp - - def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: - """Retrieve the ModelRepoMetadata corresponding to model key.""" - with self._db.lock: - self._cursor.execute( - """--sql - SELECT metadata FROM model_metadata - WHERE id=?; - """, - (model_key,), - ) - rows = self._cursor.fetchone() - if not rows: - raise UnknownMetadataException("model metadata not found") - return ModelMetadataFetchBase.from_json(rows[0]) - - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata - """Dump out all the metadata.""" - with self._db.lock: - self._cursor.execute( - """--sql - SELECT id,metadata FROM model_metadata; - """, - (), - ) - rows = self._cursor.fetchall() - return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows] - - def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: - """ - Update metadata corresponding to the model with the indicated key. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to update - """ - json_serialized = metadata.model_dump_json() # turn it into a json string. - with self._db.lock: - try: - self._cursor.execute( - """--sql - UPDATE model_metadata - SET - metadata=? - WHERE id=?; - """, - (json_serialized, model_key), - ) - if self._cursor.rowcount == 0: - raise UnknownMetadataException("model metadata not found") - self._update_tags(model_key, metadata.tags) - self._db.conn.commit() - except sqlite3.Error as e: - self._db.conn.rollback() - raise e - - return self.get_metadata(model_key) - - def list_tags(self) -> Set[str]: - """Return all tags in the tags table.""" - self._cursor.execute( - """--sql - select tag_text from tags; - """ - ) - return {x[0] for x in self._cursor.fetchall()} - - def search_by_tag(self, tags: Set[str]) -> Set[str]: - """Return the keys of models containing all of the listed tags.""" - with self._db.lock: - try: - matches: Optional[Set[str]] = None - for tag in tags: - self._cursor.execute( - """--sql - SELECT a.model_id FROM model_tags AS a, - tags AS b - WHERE a.tag_id=b.tag_id - AND b.tag_text=?; - """, - (tag,), - ) - model_keys = {x[0] for x in self._cursor.fetchall()} - if matches is None: - matches = model_keys - matches = matches.intersection(model_keys) - except sqlite3.Error as e: - raise e - return matches if matches else set() - - def search_by_author(self, author: str) -> Set[str]: - """Return the keys of models authored by the indicated author.""" - self._cursor.execute( - """--sql - SELECT id FROM model_metadata - WHERE author=?; - """, - (author,), - ) - return {x[0] for x in self._cursor.fetchall()} - - def search_by_name(self, name: str) -> Set[str]: - """ - Return the keys of models with the indicated name. - - Note that this is the name of the model given to it by - the remote source. The user may have changed the local - name. The local name will be located in the model config - record object. - """ - self._cursor.execute( - """--sql - SELECT id FROM model_metadata - WHERE name=?; - """, - (name,), - ) - return {x[0] for x in self._cursor.fetchall()} - - def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None: - """Update tags for the model referenced by model_key.""" - if tags: - # remove previous tags from this model - self._cursor.execute( - """--sql - DELETE FROM model_tags - WHERE model_id=?; - """, - (model_key,), - ) - - for tag in tags: - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO tags ( - tag_text - ) - VALUES (?); - """, - (tag,), - ) - self._cursor.execute( - """--sql - SELECT tag_id - FROM tags - WHERE tag_text = ? - LIMIT 1; - """, - (tag,), - ) - tag_id = self._cursor.fetchone()[0] - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO model_tags ( - model_id, - tag_id - ) - VALUES (?,?); - """, - (model_key, tag_id), - ) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 9c463ebb45..f376ae3878 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -6,7 +6,7 @@ Abstract base class for storing and retrieving model configuration records. from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Union from pydantic import BaseModel, Field @@ -17,9 +17,6 @@ from invokeai.backend.model_manager import ( ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata - -from ..model_metadata import ModelMetadataStoreBase class DuplicateModelException(Exception): @@ -109,40 +106,6 @@ class ModelRecordServiceBase(ABC): """ pass - @property - @abstractmethod - def metadata_store(self) -> ModelMetadataStoreBase: - """Return a ModelMetadataStore initialized on the same database.""" - pass - - @abstractmethod - def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: - """ - Retrieve metadata (if any) from when model was downloaded from a repo. - - :param key: Model key - """ - pass - - @abstractmethod - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: - """List metadata for all models that have it.""" - pass - - @abstractmethod - def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]: - """ - Search model metadata for ones with all listed tags and return their corresponding configs. - - :param tags: Set of tags to search for. All tags must be present. - """ - pass - - @abstractmethod - def list_tags(self) -> Set[str]: - """Return a unique set of all the model tags in the metadata database.""" - pass - @abstractmethod def list_models( self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 35ddc75567..ce4467d833 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -43,7 +43,7 @@ import json import sqlite3 from math import ceil from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Union from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager.config import ( @@ -53,9 +53,7 @@ from invokeai.backend.model_manager.config import ( ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException -from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( DuplicateModelException, @@ -69,7 +67,7 @@ from .model_records_base import ( class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase): + def __init__(self, db: SqliteDatabase): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. @@ -78,7 +76,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): super().__init__() self._db = db self._cursor = db.conn.cursor() - self._metadata_store = metadata_store @property def db(self) -> SqliteDatabase: @@ -242,9 +239,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): If none of the optional filters are passed, will return all models in the database. """ - results = [] - where_clause = [] - bindings = [] + where_clause: list[str] = [] + bindings: list[str] = [] if model_name: where_clause.append("name=?") bindings.append(model_name) @@ -302,55 +298,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ] return results - @property - def metadata_store(self) -> ModelMetadataStoreBase: - """Return a ModelMetadataStore initialized on the same database.""" - return self._metadata_store - - def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: - """ - Retrieve metadata (if any) from when model was downloaded from a repo. - - :param key: Model key - """ - store = self.metadata_store - try: - metadata = store.get_metadata(key) - return metadata - except UnknownMetadataException: - return None - - def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]: - """ - Search model metadata for ones with all listed tags and return their corresponding configs. - - :param tags: Set of tags to search for. All tags must be present. - """ - store = ModelMetadataStoreSQL(self._db) - keys = store.search_by_tag(tags) - return [self.get_model(x) for x in keys] - - def list_tags(self) -> Set[str]: - """Return a unique set of all the model tags in the metadata database.""" - store = ModelMetadataStoreSQL(self._db) - return store.list_tags() - - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: - """List metadata for all models that have it.""" - store = ModelMetadataStoreSQL(self._db) - return store.list_all_metadata() - def list_models( self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" assert isinstance(order_by, ModelRecordOrderBy) ordering = { - ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name", - ModelRecordOrderBy.Type: "a.type", - ModelRecordOrderBy.Base: "a.base", - ModelRecordOrderBy.Name: "a.name", - ModelRecordOrderBy.Format: "a.format", + ModelRecordOrderBy.Default: "type, base, format, name", + ModelRecordOrderBy.Type: "type", + ModelRecordOrderBy.Base: "base", + ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Format: "format", } # Lock so that the database isn't updated while we're doing the two queries. @@ -364,7 +322,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ) total = int(self._cursor.fetchone()[0]) - # query2: fetch key fields from the join of models and model_metadata + # query2: fetch key fields self._cursor.execute( f"""--sql SELECT config diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py index 60a58c1a38..bb33609c27 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py @@ -6,6 +6,15 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import class Migration7Callback: def __call__(self, cursor: sqlite3.Cursor) -> None: self._create_models_table(cursor) + self._drop_old_models_tables(cursor) + + def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None: + """Drops the old model_records, model_metadata, model_tags and tags tables.""" + + tables = ["model_records", "model_metadata", "model_tags", "tags"] + + for table in tables: + cursor.execute(f"DROP TABLE IF EXISTS {table};") def _create_models_table(self, cursor: sqlite3.Cursor) -> None: """Creates the v4.0.0 models table.""" @@ -67,7 +76,7 @@ def build_migration_7() -> Migration: This migration does the following: - Adds the new models table - - TODO(MM2): Drops the old model_records, model_metadata, model_tags and tags tables. + - Drops the old model_records, model_metadata, model_tags and tags tables. - TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?). """ migration_7 = Migration( diff --git a/invokeai/app/util/metadata.py b/invokeai/app/util/metadata.py deleted file mode 100644 index 52f9750e4f..0000000000 --- a/invokeai/app/util/metadata.py +++ /dev/null @@ -1,55 +0,0 @@ -import json -from typing import Optional - -from pydantic import ValidationError - -from invokeai.app.services.shared.graph import Edge - - -def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]: - """ - Parses raw session string, returning a dict of the graph. - - Only the general graph shape is validated; none of the fields are validated. - - Any `metadata_accumulator` nodes and edges are removed. - - Any validation failure will return None. - """ - - graph = json.loads(session_raw).get("graph", None) - - # sanity check make sure the graph is at least reasonably shaped - if ( - not isinstance(graph, dict) - or "nodes" not in graph - or not isinstance(graph["nodes"], dict) - or "edges" not in graph - or not isinstance(graph["edges"], list) - ): - # something has gone terribly awry, return an empty dict - return None - - try: - # delete the `metadata_accumulator` node - del graph["nodes"]["metadata_accumulator"] - except KeyError: - # no accumulator node, all good - pass - - # delete any edges to or from it - for i, edge in enumerate(graph["edges"]): - try: - # try to parse the edge - Edge(**edge) - except ValidationError: - # something has gone terribly awry, return an empty dict - return None - - if ( - edge["source"]["node_id"] == "metadata_accumulator" - or edge["destination"]["node_id"] == "metadata_accumulator" - ): - del graph["edges"][i] - - return graph diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 676a2a0250..0386eab8ca 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -29,6 +29,8 @@ from diffusers.models.modeling_utils import ModelMixin from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict +from invokeai.app.util.misc import uuid_string + from ..raw_model import RawModel # ModelMixin is the base class for all diffusers and transformers models @@ -132,7 +134,7 @@ class ModelSourceType(str, Enum): class ModelConfigBase(BaseModel): """Base class for model configuration information.""" - key: str = Field(description="A unique key for this model.") + key: str = Field(description="A unique key for this model.", default_factory=uuid_string) hash: str = Field(description="The hash of the model file(s).") path: str = Field( description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory." @@ -142,7 +144,9 @@ class ModelConfigBase(BaseModel): description: Optional[str] = Field(description="Model description", default=None) source: str = Field(description="The original source of the model (path, URL or repo_id).") source_type: ModelSourceType = Field(description="The type of source") - source_api_response: Optional[str] = Field(description="The original API response from the source, as stringified JSON.", default=None) + source_api_response: Optional[str] = Field( + description="The original API response from the source, as stringified JSON.", default=None + ) trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None) model_config = ConfigDict(use_enum_values=False, validate_assignment=True) diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 9d2a52603d..5090f29148 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -91,7 +91,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): ) return HuggingFaceMetadata( - id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__) + id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str) ) def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: diff --git a/invokeai/backend/model_manager/metadata/metadata_store.py b/invokeai/backend/model_manager/metadata/metadata_store.py deleted file mode 100644 index 684409fc3b..0000000000 --- a/invokeai/backend/model_manager/metadata/metadata_store.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -SQL Storage for Model Metadata -""" - -import sqlite3 -from typing import List, Optional, Set, Tuple - -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase - -from .fetch import ModelMetadataFetchBase -from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException - - -class ModelMetadataStore: - """Store, search and fetch model metadata retrieved from remote repositories.""" - - def __init__(self, db: SqliteDatabase): - """ - Initialize a new object from preexisting sqlite3 connection and threading lock objects. - - :param conn: sqlite3 connection object - :param lock: threading Lock object - """ - super().__init__() - self._db = db - self._cursor = self._db.conn.cursor() - - def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: - """ - Add a block of repo metadata to a model record. - - The model record config must already exist in the database with the - same key. Otherwise a FOREIGN KEY constraint exception will be raised. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to store - """ - json_serialized = metadata.model_dump_json() - with self._db.lock: - try: - self._cursor.execute( - """--sql - INSERT INTO model_metadata( - id, - metadata - ) - VALUES (?,?); - """, - ( - model_key, - json_serialized, - ), - ) - self._update_tags(model_key, metadata.tags) - self._db.conn.commit() - except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table - self._db.conn.rollback() - raise UnknownMetadataException from excp - except sqlite3.Error as excp: - self._db.conn.rollback() - raise excp - - def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: - """Retrieve the ModelRepoMetadata corresponding to model key.""" - with self._db.lock: - self._cursor.execute( - """--sql - SELECT metadata FROM model_metadata - WHERE id=?; - """, - (model_key,), - ) - rows = self._cursor.fetchone() - if not rows: - raise UnknownMetadataException("model metadata not found") - return ModelMetadataFetchBase.from_json(rows[0]) - - def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata - """Dump out all the metadata.""" - with self._db.lock: - self._cursor.execute( - """--sql - SELECT id,metadata FROM model_metadata; - """, - (), - ) - rows = self._cursor.fetchall() - return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows] - - def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: - """ - Update metadata corresponding to the model with the indicated key. - - :param model_key: Existing model key in the `model_config` table - :param metadata: ModelRepoMetadata object to update - """ - json_serialized = metadata.model_dump_json() # turn it into a json string. - with self._db.lock: - try: - self._cursor.execute( - """--sql - UPDATE model_metadata - SET - metadata=? - WHERE id=?; - """, - (json_serialized, model_key), - ) - if self._cursor.rowcount == 0: - raise UnknownMetadataException("model metadata not found") - self._update_tags(model_key, metadata.tags) - self._db.conn.commit() - except sqlite3.Error as e: - self._db.conn.rollback() - raise e - - return self.get_metadata(model_key) - - def list_tags(self) -> Set[str]: - """Return all tags in the tags table.""" - self._cursor.execute( - """--sql - select tag_text from tags; - """ - ) - return {x[0] for x in self._cursor.fetchall()} - - def search_by_tag(self, tags: Set[str]) -> Set[str]: - """Return the keys of models containing all of the listed tags.""" - with self._db.lock: - try: - matches: Optional[Set[str]] = None - for tag in tags: - self._cursor.execute( - """--sql - SELECT a.model_id FROM model_tags AS a, - tags AS b - WHERE a.tag_id=b.tag_id - AND b.tag_text=?; - """, - (tag,), - ) - model_keys = {x[0] for x in self._cursor.fetchall()} - if matches is None: - matches = model_keys - matches = matches.intersection(model_keys) - except sqlite3.Error as e: - raise e - return matches if matches else set() - - def search_by_author(self, author: str) -> Set[str]: - """Return the keys of models authored by the indicated author.""" - self._cursor.execute( - """--sql - SELECT id FROM model_metadata - WHERE author=?; - """, - (author,), - ) - return {x[0] for x in self._cursor.fetchall()} - - def search_by_name(self, name: str) -> Set[str]: - """ - Return the keys of models with the indicated name. - - Note that this is the name of the model given to it by - the remote source. The user may have changed the local - name. The local name will be located in the model config - record object. - """ - self._cursor.execute( - """--sql - SELECT id FROM model_metadata - WHERE name=?; - """, - (name,), - ) - return {x[0] for x in self._cursor.fetchall()} - - def _update_tags(self, model_key: str, tags: Set[str]) -> None: - """Update tags for the model referenced by model_key.""" - # remove previous tags from this model - self._cursor.execute( - """--sql - DELETE FROM model_tags - WHERE model_id=?; - """, - (model_key,), - ) - - for tag in tags: - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO tags ( - tag_text - ) - VALUES (?); - """, - (tag,), - ) - self._cursor.execute( - """--sql - SELECT tag_id - FROM tags - WHERE tag_text = ? - LIMIT 1; - """, - (tag,), - ) - tag_id = self._cursor.fetchone()[0] - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO model_tags ( - model_id, - tag_id - ) - VALUES (?,?); - """, - (model_key, tag_id), - ) diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index c837993888..774959f7ef 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -18,6 +18,7 @@ from .config import ( ModelConfigFactory, ModelFormat, ModelRepoVariant, + ModelSourceType, ModelType, ModelVariantType, SchedulerPredictionType, @@ -150,7 +151,7 @@ class ModelProbe(object): probe = probe_class(model_path) - fields["source_type"] = fields.get("source_type") + fields["source_type"] = fields.get("source_type") or ModelSourceType.Path fields["source"] = fields.get("source") or model_path.as_posix() fields["key"] = fields.get("key", uuid_string()) fields["path"] = model_path.as_posix() diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 8263d77d79..1552d4c07d 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -3,29 +3,26 @@ Test the refactored model config classes. """ from hashlib import sha256 -from typing import Any +from typing import Any, Optional import pytest from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ( DuplicateModelException, - ModelRecordOrderBy, ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException, ) from invokeai.backend.model_manager.config import ( BaseModelType, - MainCheckpointConfig, MainDiffusersConfig, ModelFormat, + ModelSourceType, ModelType, TextualInversionFileConfig, VaeDiffusersConfig, ) -from invokeai.backend.model_manager.metadata import BaseMetadata from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.fixtures.sqlite_database import create_mock_sqlite_database @@ -38,11 +35,13 @@ def store( config = InvokeAIAppConfig(root=datadir) logger = InvokeAILogger.get_logger(config=config) db = create_mock_sqlite_database(config, logger) - return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + return ModelRecordServiceSQL(db) -def example_config() -> TextualInversionFileConfig: - return TextualInversionFileConfig( +def example_config(key: Optional[str] = None) -> TextualInversionFileConfig: + config = TextualInversionFileConfig( + source="test/source/", + source_type=ModelSourceType.Path, path="/tmp/pokemon.bin", name="old name", base=BaseModelType.StableDiffusion1, @@ -50,59 +49,45 @@ def example_config() -> TextualInversionFileConfig: format=ModelFormat.EmbeddingFile, hash="ABC123", ) + if key is not None: + config.key = key + return config def test_type(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) config1 = store.get_model("key1") - assert type(config1) == TextualInversionFileConfig + assert isinstance(config1, TextualInversionFileConfig) -def test_add(store: ModelRecordServiceBase): - raw = { - "path": "/tmp/foo.ckpt", - "name": "model1", - "base": BaseModelType.StableDiffusion1, - "type": "main", - "config_path": "/tmp/foo.yaml", - "variant": "normal", - "format": "checkpoint", - "original_hash": "111222333444", - } - store.add_model("key1", raw) - config1 = store.get_model("key1") - assert config1 is not None - assert type(config1) == MainCheckpointConfig - assert config1.base == BaseModelType.StableDiffusion1 - assert config1.name == "model1" - assert config1.hash == "111222333444" - - -def test_dup(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", example_config()) +def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase): + # Models have a uniqueness constraint by their name, base and type + config1 = example_config("key1") + config2 = config1.model_copy(deep=True) + config2.key = "key2" + store.add_model(config1) with pytest.raises(DuplicateModelException): - store.add_model("key1", config) + store.add_model(config1) with pytest.raises(DuplicateModelException): - store.add_model("key2", config) + store.add_model(config2) def test_update(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) config = store.get_model("key1") assert config.name == "old name" config.name = "new name" - store.update_model("key1", config) + store.update_model(config.key, config) new_config = store.get_model("key1") assert new_config.name == "new name" def test_rename(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) config = store.get_model("key1") assert config.name == "old name" @@ -112,15 +97,15 @@ def test_rename(store: ModelRecordServiceBase): def test_unknown_key(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) with pytest.raises(UnknownModelException): store.update_model("unknown_key", config) def test_delete(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) config = store.get_model("key1") store.del_model("key1") with pytest.raises(UnknownModelException): @@ -128,36 +113,45 @@ def test_delete(store: ModelRecordServiceBase): def test_exists(store: ModelRecordServiceBase): - config = example_config() - store.add_model("key1", config) + config = example_config("key1") + store.add_model(config) assert store.exists("key1") assert not store.exists("key2") def test_filter(store: ModelRecordServiceBase): config1 = MainDiffusersConfig( + key="config1", path="/tmp/config1", name="config1", base=BaseModelType.StableDiffusion1, type=ModelType.Main, hash="CONFIG1HASH", + source="test/source", + source_type=ModelSourceType.Path, ) config2 = MainDiffusersConfig( + key="config2", path="/tmp/config2", name="config2", base=BaseModelType.StableDiffusion1, type=ModelType.Main, hash="CONFIG2HASH", + source="test/source", + source_type=ModelSourceType.Path, ) config3 = VaeDiffusersConfig( + key="config3", path="/tmp/config3", name="config3", base=BaseModelType("sd-2"), type=ModelType.Vae, hash="CONFIG3HASH", + source="test/source", + source_type=ModelSourceType.Path, ) for c in config1, config2, config3: - store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c) + store.add_model(c) matches = store.search_by_attr(model_type=ModelType.Main) assert len(matches) == 2 assert matches[0].name in {"config1", "config2"} @@ -165,7 +159,7 @@ def test_filter(store: ModelRecordServiceBase): matches = store.search_by_attr(model_type=ModelType.Vae) assert len(matches) == 1 assert matches[0].name == "config3" - assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest() + assert matches[0].key == "config3" assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back matches = store.search_by_hash("CONFIG1HASH") @@ -183,6 +177,8 @@ def test_unique(store: ModelRecordServiceBase): type=ModelType.Main, name="nonuniquename", hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config2 = MainDiffusersConfig( path="/tmp/config2", @@ -190,6 +186,8 @@ def test_unique(store: ModelRecordServiceBase): type=ModelType.Main, name="nonuniquename", hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config3 = VaeDiffusersConfig( path="/tmp/config3", @@ -197,6 +195,8 @@ def test_unique(store: ModelRecordServiceBase): type=ModelType.Vae, name="nonuniquename", hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config4 = MainDiffusersConfig( path="/tmp/config4", @@ -204,15 +204,19 @@ def test_unique(store: ModelRecordServiceBase): type=ModelType.Main, name="nonuniquename", hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) # config1, config2 and config3 are compatible because they have unique combos # of name, type and base for c in config1, config2, config3: - store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) + c.key = sha256(c.path.encode("utf-8")).hexdigest() + store.add_model(c) # config4 clashes with config1 and should raise an integrity error with pytest.raises(DuplicateModelException): - store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4) + config4.key = sha256(config4.path.encode("utf-8")).hexdigest() + store.add_model(config4) def test_filter_2(store: ModelRecordServiceBase): @@ -222,6 +226,8 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType.StableDiffusion1, type=ModelType.Main, hash="CONFIG1HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config2 = MainDiffusersConfig( path="/tmp/config2", @@ -229,6 +235,8 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType.StableDiffusion1, type=ModelType.Main, hash="CONFIG2HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config3 = MainDiffusersConfig( path="/tmp/config3", @@ -236,6 +244,8 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType("sd-2"), type=ModelType.Main, hash="CONFIG3HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config4 = MainDiffusersConfig( path="/tmp/config4", @@ -243,6 +253,8 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType("sdxl"), type=ModelType.Main, hash="CONFIG3HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) config5 = VaeDiffusersConfig( path="/tmp/config5", @@ -250,9 +262,11 @@ def test_filter_2(store: ModelRecordServiceBase): base=BaseModelType.StableDiffusion1, type=ModelType.Vae, hash="CONFIG3HASH", + source="test/source/", + source_type=ModelSourceType.Path, ) for c in config1, config2, config3, config4, config5: - store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) + store.add_model(c) matches = store.search_by_attr( model_type=ModelType.Main, @@ -272,50 +286,3 @@ def test_filter_2(store: ModelRecordServiceBase): model_name="dup_name1", ) assert len(matches) == 1 - - -def test_summary(mm2_record_store: ModelRecordServiceSQL) -> None: - # The fixture provides us with five configs. - for x in range(1, 5): - key = f"test_config_{x}" - name = f"name_{x}" - author = f"author_{x}" - tags = {f"tag{y}" for y in range(1, x)} - mm2_record_store.metadata_store.add_metadata( - model_key=key, metadata=BaseMetadata(name=name, author=author, tags=tags) - ) - # sanity check that the tags sent in all right - assert mm2_record_store.get_metadata("test_config_3").tags == {"tag1", "tag2"} - assert mm2_record_store.get_metadata("test_config_4").tags == {"tag1", "tag2", "tag3"} - - # get summary - summary1 = mm2_record_store.list_models(page=0, per_page=100) - assert summary1.page == 0 - assert summary1.pages == 1 - assert summary1.per_page == 100 - assert summary1.total == 5 - assert len(summary1.items) == 5 - assert summary1.items[0].name == "test5" # lora / sd-1 / diffusers / test5 - - # find test_config_3 - config3 = [x for x in summary1.items if x.key == "test_config_3"][0] - assert config3.description == "This is test 3" - assert config3.tags == {"tag1", "tag2"} - - # find test_config_5 - config5 = [x for x in summary1.items if x.key == "test_config_5"][0] - assert config5.tags == set() - assert config5.description == "" - - # test paging - summary2 = mm2_record_store.list_models(page=1, per_page=2) - assert summary2.page == 1 - assert summary2.per_page == 2 - assert summary2.pages == 3 - assert summary1.items[2].name == summary2.items[0].name - - # test sorting - summary = mm2_record_store.list_models(page=0, per_page=100, order_by=ModelRecordOrderBy.Name) - print(summary.items) - assert summary.items[0].name == "model1" - assert summary.items[-1].name == "test5" diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 4c0226f0cb..112b3765ff 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -18,12 +18,17 @@ from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase -from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( BaseModelType, + LoRADiffusersConfig, + MainCheckpointConfig, + MainDiffusersConfig, ModelFormat, + ModelSourceType, ModelType, + ModelVariantType, + VaeDiffusersConfig, ) from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.util.logging import InvokeAILogger @@ -107,11 +112,6 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> Downloa return download_queue -@pytest.fixture -def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: - return mm2_record_store.metadata_store - - @pytest.fixture def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: ram_cache = ModelCache( @@ -137,7 +137,7 @@ def mm2_installer( logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) events = DummyEventService() - store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + store = ModelRecordServiceSQL(db) installer = ModelInstallService( app_config=mm2_app_config, @@ -160,61 +160,71 @@ def mm2_installer( def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: logger = InvokeAILogger.get_logger(config=mm2_app_config) db = create_mock_sqlite_database(mm2_app_config, logger) - store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + store = ModelRecordServiceSQL(db) # add five simple config records to the database - raw1 = { - "path": "/tmp/foo1", - "format": ModelFormat("diffusers"), - "name": "test2", - "base": BaseModelType("sd-2"), - "type": ModelType("vae"), - "original_hash": "111222333444", - "source": "stabilityai/sdxl-vae", - } - raw2 = { - "path": "/tmp/foo2.ckpt", - "name": "model1", - "format": ModelFormat("checkpoint"), - "base": BaseModelType("sd-1"), - "type": "main", - "config_path": "/tmp/foo.yaml", - "variant": "normal", - "original_hash": "111222333444", - "source": "https://civitai.com/models/206883/split", - } - raw3 = { - "path": "/tmp/foo3", - "format": ModelFormat("diffusers"), - "name": "test3", - "base": BaseModelType("sdxl"), - "type": ModelType("main"), - "original_hash": "111222333444", - "source": "author3/model3", - "description": "This is test 3", - } - raw4 = { - "path": "/tmp/foo4", - "format": ModelFormat("diffusers"), - "name": "test4", - "base": BaseModelType("sdxl"), - "type": ModelType("lora"), - "original_hash": "111222333444", - "source": "author4/model4", - } - raw5 = { - "path": "/tmp/foo5", - "format": ModelFormat("diffusers"), - "name": "test5", - "base": BaseModelType("sd-1"), - "type": ModelType("lora"), - "original_hash": "111222333444", - "source": "author4/model5", - } - store.add_model("test_config_1", raw1) - store.add_model("test_config_2", raw2) - store.add_model("test_config_3", raw3) - store.add_model("test_config_4", raw4) - store.add_model("test_config_5", raw5) + config1 = VaeDiffusersConfig( + key="test_config_1", + path="/tmp/foo1", + format=ModelFormat.Diffusers, + name="test2", + base=BaseModelType.StableDiffusion2, + type=ModelType.Vae, + hash="111222333444", + source="stabilityai/sdxl-vae", + source_type=ModelSourceType.HFRepoID, + ) + config2 = MainCheckpointConfig( + key="test_config_2", + path="/tmp/foo2.ckpt", + name="model1", + format=ModelFormat.Checkpoint, + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, + config_path="/tmp/foo.yaml", + variant=ModelVariantType.Normal, + hash="111222333444", + source="https://civitai.com/models/206883/split", + source_type=ModelSourceType.CivitAI, + ) + config3 = MainDiffusersConfig( + key="test_config_3", + path="/tmp/foo3", + format=ModelFormat.Diffusers, + name="test3", + base=BaseModelType.StableDiffusionXL, + type=ModelType.Main, + hash="111222333444", + source="author3/model3", + description="This is test 3", + source_type=ModelSourceType.HFRepoID, + ) + config4 = LoRADiffusersConfig( + key="test_config_4", + path="/tmp/foo4", + format=ModelFormat.Diffusers, + name="test4", + base=BaseModelType.StableDiffusionXL, + type=ModelType.Lora, + hash="111222333444", + source="author4/model4", + source_type=ModelSourceType.HFRepoID, + ) + config5 = LoRADiffusersConfig( + key="test_config_5", + path="/tmp/foo5", + format=ModelFormat.Diffusers, + name="test5", + base=BaseModelType.StableDiffusion1, + type=ModelType.Lora, + hash="111222333444", + source="author4/model5", + source_type=ModelSourceType.HFRepoID, + ) + store.add_model(config1) + store.add_model(config2) + store.add_model(config3) + store.add_model(config4) + store.add_model(config5) return store diff --git a/tests/backend/model_manager/model_metadata/test_model_metadata.py b/tests/backend/model_manager/model_metadata/test_model_metadata.py deleted file mode 100644 index 0c6b1a6e93..0000000000 --- a/tests/backend/model_manager/model_metadata/test_model_metadata.py +++ /dev/null @@ -1,202 +0,0 @@ -""" -Test model metadata fetching and storage. -""" - -import datetime -from pathlib import Path - -import pytest -from pydantic.networks import HttpUrl -from requests.sessions import Session - -from invokeai.app.services.model_metadata import ModelMetadataStoreBase -from invokeai.backend.model_manager.config import ModelRepoVariant -from invokeai.backend.model_manager.metadata import ( - CivitaiMetadata, - CivitaiMetadataFetch, - CommercialUsage, - HuggingFaceMetadata, - HuggingFaceMetadataFetch, - UnknownMetadataException, -) -from invokeai.backend.model_manager.util import select_hf_files -from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 - - -def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None: - tags = {"text-to-image", "diffusers"} - input_metadata = HuggingFaceMetadata( - name="sdxl-vae", - author="stabilityai", - tags=tags, - id="stabilityai/sdxl-vae", - tag_dict={"license": "other"}, - last_modified=datetime.datetime.now(), - ) - mm2_metadata_store.add_metadata("test_config_1", input_metadata) - output_metadata = mm2_metadata_store.get_metadata("test_config_1") - assert input_metadata == output_metadata - with pytest.raises(UnknownMetadataException): - mm2_metadata_store.add_metadata("unknown_key", input_metadata) - assert mm2_metadata_store.list_tags() == tags - - -def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None: - input_metadata = HuggingFaceMetadata( - name="sdxl-vae", - author="stabilityai", - tags={"text-to-image", "diffusers"}, - id="stabilityai/sdxl-vae", - tag_dict={"license": "other"}, - last_modified=datetime.datetime.now(), - ) - mm2_metadata_store.add_metadata("test_config_1", input_metadata) - input_metadata.name = "new-name" - mm2_metadata_store.update_metadata("test_config_1", input_metadata) - output_metadata = mm2_metadata_store.get_metadata("test_config_1") - assert output_metadata.name == "new-name" - assert input_metadata == output_metadata - - -def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None: - metadata1 = HuggingFaceMetadata( - name="sdxl-vae", - author="stabilityai", - tags={"text-to-image", "diffusers"}, - id="stabilityai/sdxl-vae", - tag_dict={"license": "other"}, - last_modified=datetime.datetime.now(), - ) - metadata2 = HuggingFaceMetadata( - name="model2", - author="stabilityai", - tags={"text-to-image", "diffusers", "community-contributed"}, - id="author2/model2", - tag_dict={"license": "other"}, - last_modified=datetime.datetime.now(), - ) - metadata3 = HuggingFaceMetadata( - name="model3", - author="author3", - tags={"text-to-image", "checkpoint", "community-contributed"}, - id="author3/model3", - tag_dict={"license": "other"}, - last_modified=datetime.datetime.now(), - ) - mm2_metadata_store.add_metadata("test_config_1", metadata1) - mm2_metadata_store.add_metadata("test_config_2", metadata2) - mm2_metadata_store.add_metadata("test_config_3", metadata3) - - matches = mm2_metadata_store.search_by_author("stabilityai") - assert len(matches) == 2 - assert "test_config_1" in matches - assert "test_config_2" in matches - matches = mm2_metadata_store.search_by_author("Sherlock Holmes") - assert not matches - - matches = mm2_metadata_store.search_by_name("model3") - assert len(matches) == 1 - assert "test_config_3" in matches - - matches = mm2_metadata_store.search_by_tag({"text-to-image"}) - assert len(matches) == 3 - - matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"}) - assert len(matches) == 2 - assert "test_config_1" in matches - assert "test_config_2" in matches - - matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"}) - assert len(matches) == 1 - assert "test_config_3" in matches - - # does the tag table update correctly? - matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"}) - assert not matches - assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"} - metadata3.tags.add("licensed-for-commercial-use") - mm2_metadata_store.update_metadata("test_config_3", metadata3) - assert mm2_metadata_store.list_tags() == { - "text-to-image", - "diffusers", - "community-contributed", - "checkpoint", - "licensed-for-commercial-use", - } - matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"}) - assert len(matches) == 1 - - -def test_metadata_civitai_fetch(mm2_session: Session) -> None: - fetcher = CivitaiMetadataFetch(mm2_session) - metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo")) - assert isinstance(metadata, CivitaiMetadata) - assert metadata.id == 215485 - assert metadata.author == "test_author" # note that this is not the same as the original from Civitai - assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely - assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse - assert metadata.version_id == 242807 - assert metadata.tags == {"tool", "turbo", "sdxl turbo"} - - -def test_metadata_hf_fetch(mm2_session: Session) -> None: - fetcher = HuggingFaceMetadataFetch(mm2_session) - metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo")) - assert isinstance(metadata, HuggingFaceMetadata) - assert metadata.author == "test_author" # this is not the same as the original - assert metadata.files - assert metadata.tags == { - "diffusers", - "onnx", - "safetensors", - "text-to-image", - "license:other", - "has_space", - "diffusers:StableDiffusionXLPipeline", - "region:us", - } - - -def test_metadata_hf_filter(mm2_session: Session) -> None: - metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo")) - assert isinstance(metadata, HuggingFaceMetadata) - files = [x.path for x in metadata.files] - fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16")) - assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files - assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files - - fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32")) - assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files - assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files - - onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx")) - assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files - assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files - - default_files = select_hf_files.filter_files(files) - assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files - assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files - - openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino")) - print(openvino_files) - assert len(openvino_files) == 0 - - flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax")) - print(flax_files) - assert not flax_files - - metadata = HuggingFaceMetadataFetch(mm2_session).from_url( - HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16") - ) - assert isinstance(metadata, HuggingFaceMetadata) - files = [x.path for x in metadata.files] - filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16")) - assert ( - Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files - ) # confirm that default is returned - assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files - - -def test_metadata_hf_urls(mm2_session: Session) -> None: - metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo")) - assert isinstance(metadata, HuggingFaceMetadata)