refactor(mm): remove unused metadata logic, fix tests

- Metadata is merged with the config. We can simplify the MM substantially and remove the handling for metadata.
- Per discussion, we don't have an ETA for frontend implementation of tags, and with the realization that the tags from CivitAI are largely useless, there's no reason to keep tags in the MM right now. When we are ready to implement tags on the frontend, we can refer back to the implementation here and use it if it supports the design.
- Fix all tests.
This commit is contained in:
psychedelicious 2024-03-04 21:38:21 +11:00
parent 0b9a212363
commit 44c40d7d1a
18 changed files with 170 additions and 1187 deletions

View File

@ -26,7 +26,6 @@ from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker from ..services.invoker import Invoker
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_processor.session_processor_default import DefaultSessionProcessor
@ -93,10 +92,9 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
) )
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(event_bus=events)
model_metadata_service = ModelMetadataStoreSQL(db=db)
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service), model_record_service=ModelRecordServiceSQL(db=db),
download_queue=download_queue_service, download_queue=download_queue_service,
events=events, events=events,
) )

View File

@ -12,7 +12,6 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
InvalidModelException, InvalidModelException,
@ -30,8 +29,6 @@ from invokeai.backend.model_manager.config import (
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -90,44 +87,6 @@ example_model_input = {
"variant": "normal", "variant": "normal",
} }
example_model_metadata = {
"name": "ip_adapter_sd_image_encoder",
"author": "InvokeAI",
"tags": [
"transformers",
"safetensors",
"clip_vision_model",
"endpoints_compatible",
"region:us",
"has_space",
"license:apache-2.0",
],
"files": [
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
"path": "ip_adapter_sd_image_encoder/README.md",
"size": 628,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
"path": "ip_adapter_sd_image_encoder/config.json",
"size": 560,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
"path": "ip_adapter_sd_image_encoder/model.safetensors",
"size": 2528373448,
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
},
],
"type": "huggingface",
"id": "InvokeAI/ip_adapter_sd_image_encoder",
"tag_dict": {"license": "apache-2.0"},
"last_modified": "2023-09-23T17:33:25Z",
}
############################################################################## ##############################################################################
# ROUTES # ROUTES
############################################################################## ##############################################################################
@ -219,79 +178,6 @@ async def list_model_summary(
return results return results
@model_manager_router.get(
"/i/{key}/metadata",
operation_id="get_model_metadata",
responses={
200: {
"description": "The model metadata was retrieved successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes"),
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
if original_metadata:
if changes.default_settings:
original_metadata.default_settings = changes.default_settings
metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(
key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Set[str] = record_store.list_tags()
return result
class FoundModel(BaseModel): class FoundModel(BaseModel):
path: str = Field(description="Path to the model") path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed") is_installed: bool = Field(description="Whether or not the model is already installed")
@ -361,19 +247,6 @@ async def scan_for_models(
return scan_results return scan_results
@model_manager_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_manager_router.patch( @model_manager_router.patch(
"/i/{key}", "/i/{key}",
operation_id="update_model_record", operation_id="update_model_record",
@ -562,9 +435,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
* "cancelled" -- Job was cancelled before completion. * "cancelled" -- Job was cancelled before completion.
Once completed, information about the model such as its size, base Once completed, information about the model such as its size, base
model, type, and metadata can be retrieved from the `config_out` model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
field. For multi-file models such as diffusers, information on individual files information on individual files can be retrieved from `download_parts`.
can be retrieved from `download_parts`.
See the example and schema below for more information. See the example and schema below for more information.
""" """
@ -708,10 +580,6 @@ async def convert_model(
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
# get the original metadata
if orig_metadata := store.get_metadata(key):
store.metadata_store.add_metadata(new_key, orig_metadata)
# delete the original safetensors file # delete the original safetensors file
installer.delete(key) installer.delete(key)

View File

@ -21,8 +21,6 @@ from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class InstallStatus(str, Enum): class InstallStatus(str, Enum):
"""State of an install job running in the background.""" """State of an install job running in the background."""
@ -268,7 +266,6 @@ class ModelInstallServiceBase(ABC):
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase, record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase, download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStoreBase,
event_bus: Optional["EventServiceBase"] = None, event_bus: Optional["EventServiceBase"] = None,
): ):
""" """

View File

@ -94,7 +94,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._running = False self._running = False
self._session = session self._session = session
self._next_job_id = 0 self._next_job_id = 0
self._metadata_store = record_store.metadata_store # for convenience
@property @property
def app_config(self) -> InvokeAIAppConfig: # noqa D102 def app_config(self) -> InvokeAIAppConfig: # noqa D102

View File

@ -1,9 +0,0 @@
"""Init file for ModelMetadataStoreService module."""
from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL
__all__ = [
"ModelMetadataStoreBase",
"ModelMetadataStoreSQL",
]

View File

@ -1,81 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `default_settings`: the user-configured default settings for this model
"""
default_settings: Optional[ModelDefaultSettings] = Field(
default=None, description="The user-configured default settings for this model"
)
"""The user-configured default settings for this model"""
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""
@abstractmethod
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
@abstractmethod
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
@abstractmethod
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
@abstractmethod
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
@abstractmethod
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""

View File

@ -1,223 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
from .metadata_store_base import ModelMetadataStoreBase
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
"""Update tags for the model referenced by model_key."""
if tags:
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -6,7 +6,7 @@ Abstract base class for storing and retrieving model configuration records.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any, Dict, List, Optional, Set, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -17,9 +17,6 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class DuplicateModelException(Exception): class DuplicateModelException(Exception):
@ -109,40 +106,6 @@ class ModelRecordServiceBase(ABC):
""" """
pass pass
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
pass
@abstractmethod
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
pass
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
pass
@abstractmethod
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
pass
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
pass
@abstractmethod @abstractmethod
def list_models( def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default

View File

@ -43,7 +43,7 @@ import json
import sqlite3 import sqlite3
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any, Dict, List, Optional, Union
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -53,9 +53,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import ( from .model_records_base import (
DuplicateModelException, DuplicateModelException,
@ -69,7 +67,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase): def __init__(self, db: SqliteDatabase):
""" """
Initialize a new object from preexisting sqlite3 connection and threading lock objects. Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -78,7 +76,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__() super().__init__()
self._db = db self._db = db
self._cursor = db.conn.cursor() self._cursor = db.conn.cursor()
self._metadata_store = metadata_store
@property @property
def db(self) -> SqliteDatabase: def db(self) -> SqliteDatabase:
@ -242,9 +239,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all If none of the optional filters are passed, will return all
models in the database. models in the database.
""" """
results = [] where_clause: list[str] = []
where_clause = [] bindings: list[str] = []
bindings = []
if model_name: if model_name:
where_clause.append("name=?") where_clause.append("name=?")
bindings.append(model_name) bindings.append(model_name)
@ -302,55 +298,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
] ]
return results return results
@property
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
return self._metadata_store
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
store = self.metadata_store
try:
metadata = store.get_metadata(key)
return metadata
except UnknownMetadataException:
return None
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
store = ModelMetadataStoreSQL(self._db)
keys = store.search_by_tag(tags)
return [self.get_model(x) for x in keys]
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
store = ModelMetadataStoreSQL(self._db)
return store.list_tags()
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
store = ModelMetadataStoreSQL(self._db)
return store.list_all_metadata()
def list_models( def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]: ) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database.""" """Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy) assert isinstance(order_by, ModelRecordOrderBy)
ordering = { ordering = {
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name", ModelRecordOrderBy.Default: "type, base, format, name",
ModelRecordOrderBy.Type: "a.type", ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "a.base", ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "a.name", ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "a.format", ModelRecordOrderBy.Format: "format",
} }
# Lock so that the database isn't updated while we're doing the two queries. # Lock so that the database isn't updated while we're doing the two queries.
@ -364,7 +322,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
) )
total = int(self._cursor.fetchone()[0]) total = int(self._cursor.fetchone()[0])
# query2: fetch key fields from the join of models and model_metadata # query2: fetch key fields
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
SELECT config SELECT config

View File

@ -6,6 +6,15 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import
class Migration7Callback: class Migration7Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None: def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_models_table(cursor) self._create_models_table(cursor)
self._drop_old_models_tables(cursor)
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
tables = ["model_records", "model_metadata", "model_tags", "tags"]
for table in tables:
cursor.execute(f"DROP TABLE IF EXISTS {table};")
def _create_models_table(self, cursor: sqlite3.Cursor) -> None: def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the v4.0.0 models table.""" """Creates the v4.0.0 models table."""
@ -67,7 +76,7 @@ def build_migration_7() -> Migration:
This migration does the following: This migration does the following:
- Adds the new models table - Adds the new models table
- TODO(MM2): Drops the old model_records, model_metadata, model_tags and tags tables. - Drops the old model_records, model_metadata, model_tags and tags tables.
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?). - TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
""" """
migration_7 = Migration( migration_7 = Migration(

View File

@ -1,55 +0,0 @@
import json
from typing import Optional
from pydantic import ValidationError
from invokeai.app.services.shared.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
"""
Parses raw session string, returning a dict of the graph.
Only the general graph shape is validated; none of the fields are validated.
Any `metadata_accumulator` nodes and edges are removed.
Any validation failure will return None.
"""
graph = json.loads(session_raw).get("graph", None)
# sanity check make sure the graph is at least reasonably shaped
if (
not isinstance(graph, dict)
or "nodes" not in graph
or not isinstance(graph["nodes"], dict)
or "edges" not in graph
or not isinstance(graph["edges"], list)
):
# something has gone terribly awry, return an empty dict
return None
try:
# delete the `metadata_accumulator` node
del graph["nodes"]["metadata_accumulator"]
except KeyError:
# no accumulator node, all good
pass
# delete any edges to or from it
for i, edge in enumerate(graph["edges"]):
try:
# try to parse the edge
Edge(**edge)
except ValidationError:
# something has gone terribly awry, return an empty dict
return None
if (
edge["source"]["node_id"] == "metadata_accumulator"
or edge["destination"]["node_id"] == "metadata_accumulator"
):
del graph["edges"][i]
return graph

View File

@ -29,6 +29,8 @@ from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string
from ..raw_model import RawModel from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models # ModelMixin is the base class for all diffusers and transformers models
@ -132,7 +134,7 @@ class ModelSourceType(str, Enum):
class ModelConfigBase(BaseModel): class ModelConfigBase(BaseModel):
"""Base class for model configuration information.""" """Base class for model configuration information."""
key: str = Field(description="A unique key for this model.") key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
hash: str = Field(description="The hash of the model file(s).") hash: str = Field(description="The hash of the model file(s).")
path: str = Field( path: str = Field(
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory." description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
@ -142,7 +144,9 @@ class ModelConfigBase(BaseModel):
description: Optional[str] = Field(description="Model description", default=None) description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The original source of the model (path, URL or repo_id).") source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source") source_type: ModelSourceType = Field(description="The type of source")
source_api_response: Optional[str] = Field(description="The original API response from the source, as stringified JSON.", default=None) source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None) trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
model_config = ConfigDict(use_enum_values=False, validate_assignment=True) model_config = ConfigDict(use_enum_values=False, validate_assignment=True)

View File

@ -91,7 +91,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
) )
return HuggingFaceMetadata( return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__) id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
) )
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:

View File

@ -1,221 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
class ModelMetadataStore:
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -18,6 +18,7 @@ from .config import (
ModelConfigFactory, ModelConfigFactory,
ModelFormat, ModelFormat,
ModelRepoVariant, ModelRepoVariant,
ModelSourceType,
ModelType, ModelType,
ModelVariantType, ModelVariantType,
SchedulerPredictionType, SchedulerPredictionType,
@ -150,7 +151,7 @@ class ModelProbe(object):
probe = probe_class(model_path) probe = probe_class(model_path)
fields["source_type"] = fields.get("source_type") fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["source"] = fields.get("source") or model_path.as_posix() fields["source"] = fields.get("source") or model_path.as_posix()
fields["key"] = fields.get("key", uuid_string()) fields["key"] = fields.get("key", uuid_string())
fields["path"] = model_path.as_posix() fields["path"] = model_path.as_posix()

View File

@ -3,29 +3,26 @@ Test the refactored model config classes.
""" """
from hashlib import sha256 from hashlib import sha256
from typing import Any from typing import Any, Optional
import pytest import pytest
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
ModelRecordOrderBy,
ModelRecordServiceBase, ModelRecordServiceBase,
ModelRecordServiceSQL, ModelRecordServiceSQL,
UnknownModelException, UnknownModelException,
) )
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
MainCheckpointConfig,
MainDiffusersConfig, MainDiffusersConfig,
ModelFormat, ModelFormat,
ModelSourceType,
ModelType, ModelType,
TextualInversionFileConfig, TextualInversionFileConfig,
VaeDiffusersConfig, VaeDiffusersConfig,
) )
from invokeai.backend.model_manager.metadata import BaseMetadata
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.fixtures.sqlite_database import create_mock_sqlite_database from tests.fixtures.sqlite_database import create_mock_sqlite_database
@ -38,11 +35,13 @@ def store(
config = InvokeAIAppConfig(root=datadir) config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger) db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) return ModelRecordServiceSQL(db)
def example_config() -> TextualInversionFileConfig: def example_config(key: Optional[str] = None) -> TextualInversionFileConfig:
return TextualInversionFileConfig( config = TextualInversionFileConfig(
source="test/source/",
source_type=ModelSourceType.Path,
path="/tmp/pokemon.bin", path="/tmp/pokemon.bin",
name="old name", name="old name",
base=BaseModelType.StableDiffusion1, base=BaseModelType.StableDiffusion1,
@ -50,59 +49,45 @@ def example_config() -> TextualInversionFileConfig:
format=ModelFormat.EmbeddingFile, format=ModelFormat.EmbeddingFile,
hash="ABC123", hash="ABC123",
) )
if key is not None:
config.key = key
return config
def test_type(store: ModelRecordServiceBase): def test_type(store: ModelRecordServiceBase):
config = example_config() config = example_config("key1")
store.add_model("key1", config) store.add_model(config)
config1 = store.get_model("key1") config1 = store.get_model("key1")
assert type(config1) == TextualInversionFileConfig assert isinstance(config1, TextualInversionFileConfig)
def test_add(store: ModelRecordServiceBase): def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
raw = { # Models have a uniqueness constraint by their name, base and type
"path": "/tmp/foo.ckpt", config1 = example_config("key1")
"name": "model1", config2 = config1.model_copy(deep=True)
"base": BaseModelType.StableDiffusion1, config2.key = "key2"
"type": "main", store.add_model(config1)
"config_path": "/tmp/foo.yaml",
"variant": "normal",
"format": "checkpoint",
"original_hash": "111222333444",
}
store.add_model("key1", raw)
config1 = store.get_model("key1")
assert config1 is not None
assert type(config1) == MainCheckpointConfig
assert config1.base == BaseModelType.StableDiffusion1
assert config1.name == "model1"
assert config1.hash == "111222333444"
def test_dup(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", example_config())
with pytest.raises(DuplicateModelException): with pytest.raises(DuplicateModelException):
store.add_model("key1", config) store.add_model(config1)
with pytest.raises(DuplicateModelException): with pytest.raises(DuplicateModelException):
store.add_model("key2", config) store.add_model(config2)
def test_update(store: ModelRecordServiceBase): def test_update(store: ModelRecordServiceBase):
config = example_config() config = example_config("key1")
store.add_model("key1", config) store.add_model(config)
config = store.get_model("key1") config = store.get_model("key1")
assert config.name == "old name" assert config.name == "old name"
config.name = "new name" config.name = "new name"
store.update_model("key1", config) store.update_model(config.key, config)
new_config = store.get_model("key1") new_config = store.get_model("key1")
assert new_config.name == "new name" assert new_config.name == "new name"
def test_rename(store: ModelRecordServiceBase): def test_rename(store: ModelRecordServiceBase):
config = example_config() config = example_config("key1")
store.add_model("key1", config) store.add_model(config)
config = store.get_model("key1") config = store.get_model("key1")
assert config.name == "old name" assert config.name == "old name"
@ -112,15 +97,15 @@ def test_rename(store: ModelRecordServiceBase):
def test_unknown_key(store: ModelRecordServiceBase): def test_unknown_key(store: ModelRecordServiceBase):
config = example_config() config = example_config("key1")
store.add_model("key1", config) store.add_model(config)
with pytest.raises(UnknownModelException): with pytest.raises(UnknownModelException):
store.update_model("unknown_key", config) store.update_model("unknown_key", config)
def test_delete(store: ModelRecordServiceBase): def test_delete(store: ModelRecordServiceBase):
config = example_config() config = example_config("key1")
store.add_model("key1", config) store.add_model(config)
config = store.get_model("key1") config = store.get_model("key1")
store.del_model("key1") store.del_model("key1")
with pytest.raises(UnknownModelException): with pytest.raises(UnknownModelException):
@ -128,36 +113,45 @@ def test_delete(store: ModelRecordServiceBase):
def test_exists(store: ModelRecordServiceBase): def test_exists(store: ModelRecordServiceBase):
config = example_config() config = example_config("key1")
store.add_model("key1", config) store.add_model(config)
assert store.exists("key1") assert store.exists("key1")
assert not store.exists("key2") assert not store.exists("key2")
def test_filter(store: ModelRecordServiceBase): def test_filter(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig( config1 = MainDiffusersConfig(
key="config1",
path="/tmp/config1", path="/tmp/config1",
name="config1", name="config1",
base=BaseModelType.StableDiffusion1, base=BaseModelType.StableDiffusion1,
type=ModelType.Main, type=ModelType.Main,
hash="CONFIG1HASH", hash="CONFIG1HASH",
source="test/source",
source_type=ModelSourceType.Path,
) )
config2 = MainDiffusersConfig( config2 = MainDiffusersConfig(
key="config2",
path="/tmp/config2", path="/tmp/config2",
name="config2", name="config2",
base=BaseModelType.StableDiffusion1, base=BaseModelType.StableDiffusion1,
type=ModelType.Main, type=ModelType.Main,
hash="CONFIG2HASH", hash="CONFIG2HASH",
source="test/source",
source_type=ModelSourceType.Path,
) )
config3 = VaeDiffusersConfig( config3 = VaeDiffusersConfig(
key="config3",
path="/tmp/config3", path="/tmp/config3",
name="config3", name="config3",
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType.Vae, type=ModelType.Vae,
hash="CONFIG3HASH", hash="CONFIG3HASH",
source="test/source",
source_type=ModelSourceType.Path,
) )
for c in config1, config2, config3: for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c) store.add_model(c)
matches = store.search_by_attr(model_type=ModelType.Main) matches = store.search_by_attr(model_type=ModelType.Main)
assert len(matches) == 2 assert len(matches) == 2
assert matches[0].name in {"config1", "config2"} assert matches[0].name in {"config1", "config2"}
@ -165,7 +159,7 @@ def test_filter(store: ModelRecordServiceBase):
matches = store.search_by_attr(model_type=ModelType.Vae) matches = store.search_by_attr(model_type=ModelType.Vae)
assert len(matches) == 1 assert len(matches) == 1
assert matches[0].name == "config3" assert matches[0].name == "config3"
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest() assert matches[0].key == "config3"
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
matches = store.search_by_hash("CONFIG1HASH") matches = store.search_by_hash("CONFIG1HASH")
@ -183,6 +177,8 @@ def test_unique(store: ModelRecordServiceBase):
type=ModelType.Main, type=ModelType.Main,
name="nonuniquename", name="nonuniquename",
hash="CONFIG1HASH", hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
) )
config2 = MainDiffusersConfig( config2 = MainDiffusersConfig(
path="/tmp/config2", path="/tmp/config2",
@ -190,6 +186,8 @@ def test_unique(store: ModelRecordServiceBase):
type=ModelType.Main, type=ModelType.Main,
name="nonuniquename", name="nonuniquename",
hash="CONFIG1HASH", hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
) )
config3 = VaeDiffusersConfig( config3 = VaeDiffusersConfig(
path="/tmp/config3", path="/tmp/config3",
@ -197,6 +195,8 @@ def test_unique(store: ModelRecordServiceBase):
type=ModelType.Vae, type=ModelType.Vae,
name="nonuniquename", name="nonuniquename",
hash="CONFIG1HASH", hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
) )
config4 = MainDiffusersConfig( config4 = MainDiffusersConfig(
path="/tmp/config4", path="/tmp/config4",
@ -204,15 +204,19 @@ def test_unique(store: ModelRecordServiceBase):
type=ModelType.Main, type=ModelType.Main,
name="nonuniquename", name="nonuniquename",
hash="CONFIG1HASH", hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
) )
# config1, config2 and config3 are compatible because they have unique combos # config1, config2 and config3 are compatible because they have unique combos
# of name, type and base # of name, type and base
for c in config1, config2, config3: for c in config1, config2, config3:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) c.key = sha256(c.path.encode("utf-8")).hexdigest()
store.add_model(c)
# config4 clashes with config1 and should raise an integrity error # config4 clashes with config1 and should raise an integrity error
with pytest.raises(DuplicateModelException): with pytest.raises(DuplicateModelException):
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4) config4.key = sha256(config4.path.encode("utf-8")).hexdigest()
store.add_model(config4)
def test_filter_2(store: ModelRecordServiceBase): def test_filter_2(store: ModelRecordServiceBase):
@ -222,6 +226,8 @@ def test_filter_2(store: ModelRecordServiceBase):
base=BaseModelType.StableDiffusion1, base=BaseModelType.StableDiffusion1,
type=ModelType.Main, type=ModelType.Main,
hash="CONFIG1HASH", hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
) )
config2 = MainDiffusersConfig( config2 = MainDiffusersConfig(
path="/tmp/config2", path="/tmp/config2",
@ -229,6 +235,8 @@ def test_filter_2(store: ModelRecordServiceBase):
base=BaseModelType.StableDiffusion1, base=BaseModelType.StableDiffusion1,
type=ModelType.Main, type=ModelType.Main,
hash="CONFIG2HASH", hash="CONFIG2HASH",
source="test/source/",
source_type=ModelSourceType.Path,
) )
config3 = MainDiffusersConfig( config3 = MainDiffusersConfig(
path="/tmp/config3", path="/tmp/config3",
@ -236,6 +244,8 @@ def test_filter_2(store: ModelRecordServiceBase):
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType.Main, type=ModelType.Main,
hash="CONFIG3HASH", hash="CONFIG3HASH",
source="test/source/",
source_type=ModelSourceType.Path,
) )
config4 = MainDiffusersConfig( config4 = MainDiffusersConfig(
path="/tmp/config4", path="/tmp/config4",
@ -243,6 +253,8 @@ def test_filter_2(store: ModelRecordServiceBase):
base=BaseModelType("sdxl"), base=BaseModelType("sdxl"),
type=ModelType.Main, type=ModelType.Main,
hash="CONFIG3HASH", hash="CONFIG3HASH",
source="test/source/",
source_type=ModelSourceType.Path,
) )
config5 = VaeDiffusersConfig( config5 = VaeDiffusersConfig(
path="/tmp/config5", path="/tmp/config5",
@ -250,9 +262,11 @@ def test_filter_2(store: ModelRecordServiceBase):
base=BaseModelType.StableDiffusion1, base=BaseModelType.StableDiffusion1,
type=ModelType.Vae, type=ModelType.Vae,
hash="CONFIG3HASH", hash="CONFIG3HASH",
source="test/source/",
source_type=ModelSourceType.Path,
) )
for c in config1, config2, config3, config4, config5: for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) store.add_model(c)
matches = store.search_by_attr( matches = store.search_by_attr(
model_type=ModelType.Main, model_type=ModelType.Main,
@ -272,50 +286,3 @@ def test_filter_2(store: ModelRecordServiceBase):
model_name="dup_name1", model_name="dup_name1",
) )
assert len(matches) == 1 assert len(matches) == 1
def test_summary(mm2_record_store: ModelRecordServiceSQL) -> None:
# The fixture provides us with five configs.
for x in range(1, 5):
key = f"test_config_{x}"
name = f"name_{x}"
author = f"author_{x}"
tags = {f"tag{y}" for y in range(1, x)}
mm2_record_store.metadata_store.add_metadata(
model_key=key, metadata=BaseMetadata(name=name, author=author, tags=tags)
)
# sanity check that the tags sent in all right
assert mm2_record_store.get_metadata("test_config_3").tags == {"tag1", "tag2"}
assert mm2_record_store.get_metadata("test_config_4").tags == {"tag1", "tag2", "tag3"}
# get summary
summary1 = mm2_record_store.list_models(page=0, per_page=100)
assert summary1.page == 0
assert summary1.pages == 1
assert summary1.per_page == 100
assert summary1.total == 5
assert len(summary1.items) == 5
assert summary1.items[0].name == "test5" # lora / sd-1 / diffusers / test5
# find test_config_3
config3 = [x for x in summary1.items if x.key == "test_config_3"][0]
assert config3.description == "This is test 3"
assert config3.tags == {"tag1", "tag2"}
# find test_config_5
config5 = [x for x in summary1.items if x.key == "test_config_5"][0]
assert config5.tags == set()
assert config5.description == ""
# test paging
summary2 = mm2_record_store.list_models(page=1, per_page=2)
assert summary2.page == 1
assert summary2.per_page == 2
assert summary2.pages == 3
assert summary1.items[2].name == summary2.items[0].name
# test sorting
summary = mm2_record_store.list_models(page=0, per_page=100, order_by=ModelRecordOrderBy.Name)
print(summary.items)
assert summary.items[0].name == "model1"
assert summary.items[-1].name == "test5"

View File

@ -18,12 +18,17 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
LoRADiffusersConfig,
MainCheckpointConfig,
MainDiffusersConfig,
ModelFormat, ModelFormat,
ModelSourceType,
ModelType, ModelType,
ModelVariantType,
VaeDiffusersConfig,
) )
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -107,11 +112,6 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> Downloa
return download_queue return download_queue
@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
return mm2_record_store.metadata_store
@pytest.fixture @pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
ram_cache = ModelCache( ram_cache = ModelCache(
@ -137,7 +137,7 @@ def mm2_installer(
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService() events = DummyEventService()
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) store = ModelRecordServiceSQL(db)
installer = ModelInstallService( installer = ModelInstallService(
app_config=mm2_app_config, app_config=mm2_app_config,
@ -160,61 +160,71 @@ def mm2_installer(
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=mm2_app_config) logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) store = ModelRecordServiceSQL(db)
# add five simple config records to the database # add five simple config records to the database
raw1 = { config1 = VaeDiffusersConfig(
"path": "/tmp/foo1", key="test_config_1",
"format": ModelFormat("diffusers"), path="/tmp/foo1",
"name": "test2", format=ModelFormat.Diffusers,
"base": BaseModelType("sd-2"), name="test2",
"type": ModelType("vae"), base=BaseModelType.StableDiffusion2,
"original_hash": "111222333444", type=ModelType.Vae,
"source": "stabilityai/sdxl-vae", hash="111222333444",
} source="stabilityai/sdxl-vae",
raw2 = { source_type=ModelSourceType.HFRepoID,
"path": "/tmp/foo2.ckpt", )
"name": "model1", config2 = MainCheckpointConfig(
"format": ModelFormat("checkpoint"), key="test_config_2",
"base": BaseModelType("sd-1"), path="/tmp/foo2.ckpt",
"type": "main", name="model1",
"config_path": "/tmp/foo.yaml", format=ModelFormat.Checkpoint,
"variant": "normal", base=BaseModelType.StableDiffusion1,
"original_hash": "111222333444", type=ModelType.Main,
"source": "https://civitai.com/models/206883/split", config_path="/tmp/foo.yaml",
} variant=ModelVariantType.Normal,
raw3 = { hash="111222333444",
"path": "/tmp/foo3", source="https://civitai.com/models/206883/split",
"format": ModelFormat("diffusers"), source_type=ModelSourceType.CivitAI,
"name": "test3", )
"base": BaseModelType("sdxl"), config3 = MainDiffusersConfig(
"type": ModelType("main"), key="test_config_3",
"original_hash": "111222333444", path="/tmp/foo3",
"source": "author3/model3", format=ModelFormat.Diffusers,
"description": "This is test 3", name="test3",
} base=BaseModelType.StableDiffusionXL,
raw4 = { type=ModelType.Main,
"path": "/tmp/foo4", hash="111222333444",
"format": ModelFormat("diffusers"), source="author3/model3",
"name": "test4", description="This is test 3",
"base": BaseModelType("sdxl"), source_type=ModelSourceType.HFRepoID,
"type": ModelType("lora"), )
"original_hash": "111222333444", config4 = LoRADiffusersConfig(
"source": "author4/model4", key="test_config_4",
} path="/tmp/foo4",
raw5 = { format=ModelFormat.Diffusers,
"path": "/tmp/foo5", name="test4",
"format": ModelFormat("diffusers"), base=BaseModelType.StableDiffusionXL,
"name": "test5", type=ModelType.Lora,
"base": BaseModelType("sd-1"), hash="111222333444",
"type": ModelType("lora"), source="author4/model4",
"original_hash": "111222333444", source_type=ModelSourceType.HFRepoID,
"source": "author4/model5", )
} config5 = LoRADiffusersConfig(
store.add_model("test_config_1", raw1) key="test_config_5",
store.add_model("test_config_2", raw2) path="/tmp/foo5",
store.add_model("test_config_3", raw3) format=ModelFormat.Diffusers,
store.add_model("test_config_4", raw4) name="test5",
store.add_model("test_config_5", raw5) base=BaseModelType.StableDiffusion1,
type=ModelType.Lora,
hash="111222333444",
source="author4/model5",
source_type=ModelSourceType.HFRepoID,
)
store.add_model(config1)
store.add_model(config2)
store.add_model(config3)
store.add_model(config4)
store.add_model(config5)
return store return store

View File

@ -1,202 +0,0 @@
"""
Test model metadata fetching and storage.
"""
import datetime
from pathlib import Path
import pytest
from pydantic.networks import HttpUrl
from requests.sessions import Session
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.metadata import (
CivitaiMetadata,
CivitaiMetadataFetch,
CommercialUsage,
HuggingFaceMetadata,
HuggingFaceMetadataFetch,
UnknownMetadataException,
)
from invokeai.backend.model_manager.util import select_hf_files
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
tags = {"text-to-image", "diffusers"}
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags=tags,
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
assert input_metadata == output_metadata
with pytest.raises(UnknownMetadataException):
mm2_metadata_store.add_metadata("unknown_key", input_metadata)
assert mm2_metadata_store.list_tags() == tags
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
input_metadata.name = "new-name"
mm2_metadata_store.update_metadata("test_config_1", input_metadata)
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
assert output_metadata.name == "new-name"
assert input_metadata == output_metadata
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
metadata1 = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata2 = HuggingFaceMetadata(
name="model2",
author="stabilityai",
tags={"text-to-image", "diffusers", "community-contributed"},
id="author2/model2",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata3 = HuggingFaceMetadata(
name="model3",
author="author3",
tags={"text-to-image", "checkpoint", "community-contributed"},
id="author3/model3",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
mm2_metadata_store.add_metadata("test_config_1", metadata1)
mm2_metadata_store.add_metadata("test_config_2", metadata2)
mm2_metadata_store.add_metadata("test_config_3", metadata3)
matches = mm2_metadata_store.search_by_author("stabilityai")
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = mm2_metadata_store.search_by_author("Sherlock Holmes")
assert not matches
matches = mm2_metadata_store.search_by_name("model3")
assert len(matches) == 1
assert "test_config_3" in matches
matches = mm2_metadata_store.search_by_tag({"text-to-image"})
assert len(matches) == 3
matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"})
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"})
assert len(matches) == 1
assert "test_config_3" in matches
# does the tag table update correctly?
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert not matches
assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"}
metadata3.tags.add("licensed-for-commercial-use")
mm2_metadata_store.update_metadata("test_config_3", metadata3)
assert mm2_metadata_store.list_tags() == {
"text-to-image",
"diffusers",
"community-contributed",
"checkpoint",
"licensed-for-commercial-use",
}
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert len(matches) == 1
def test_metadata_civitai_fetch(mm2_session: Session) -> None:
fetcher = CivitaiMetadataFetch(mm2_session)
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
assert isinstance(metadata, CivitaiMetadata)
assert metadata.id == 215485
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
assert metadata.version_id == 242807
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
def test_metadata_hf_fetch(mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)
assert metadata.author == "test_author" # this is not the same as the original
assert metadata.files
assert metadata.tags == {
"diffusers",
"onnx",
"safetensors",
"text-to-image",
"license:other",
"has_space",
"diffusers:StableDiffusionXLPipeline",
"region:us",
}
def test_metadata_hf_filter(mm2_session: Session) -> None:
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)
files = [x.path for x in metadata.files]
fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files
fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32"))
assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files
onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx"))
assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files
default_files = select_hf_files.filter_files(files)
assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files
openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino"))
print(openvino_files)
assert len(openvino_files) == 0
flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax"))
print(flax_files)
assert not flax_files
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(
HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16")
)
assert isinstance(metadata, HuggingFaceMetadata)
files = [x.path for x in metadata.files]
filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
assert (
Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files
) # confirm that default is returned
assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files
def test_metadata_hf_urls(mm2_session: Session) -> None:
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)