Fix issues identified during PR review by RyanjDick and brandonrising

- ModelMetadataStoreService is now injected into ModelRecordStoreService
  (these two services are really joined at the hip, and should someday be merged)
- ModelRecordStoreService is now injected into ModelManagerService
- Reduced timeout value for the various installer and download wait*() methods
- Introduced a Mock modelmanager for testing
- Removed bare print() statement with _logger in the install helper backend.
- Removed unused code from model loader init file
- Made `locker` a private variable in the `LoadedModel` object.
- Fixed up model merge frontend (will be deprecated anyway!)
This commit is contained in:
Lincoln Stein 2024-02-15 22:41:29 -05:00
parent 9758082dc5
commit 09e7d35b55
22 changed files with 449 additions and 131 deletions

View File

@ -25,6 +25,8 @@ from ..services.invoker import Invoker
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@ -83,8 +85,12 @@ class ApiDependencies:
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(event_bus=events)
model_metadata_service = ModelMetadataStoreSQL(db=db)
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, db=db, download_queue=download_queue_service, events=events app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
download_queue=download_queue_service,
events=events,
) )
names = SimpleNameService() names = SimpleNameService()
performance_statistics = InvocationStatsService() performance_statistics = InvocationStatsService()

View File

@ -194,7 +194,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
"""Block until the indicated job has reached terminal state, or when timeout limit reached.""" """Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time() start = time.time()
while not job.in_terminal_state: while not job.in_terminal_state:
if self._job_completed_event.wait(timeout=5): # in case we miss an event if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
self._job_completed_event.clear() self._job_completed_event.clear()
if timeout > 0 and time.time() - start > timeout: if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded") raise TimeoutError("Timeout exceeded")

View File

@ -46,8 +46,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
# This is to handle case of the model manager not being initialized, which happens # This is to handle case of the model manager not being initialized, which happens
# during some tests. # during some tests.
services = self._invoker.services services = self._invoker.services
if services.model_manager is None or services.model_manager.load is None:
yield None
if not self._stats.get(graph_execution_state_id): if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id. # First time we're seeing this graph_execution_state_id.
self._stats[graph_execution_state_id] = GraphExecutionStats() self._stats[graph_execution_state_id] = GraphExecutionStats()

View File

@ -18,7 +18,9 @@ from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class InstallStatus(str, Enum): class InstallStatus(str, Enum):
@ -243,7 +245,7 @@ class ModelInstallServiceBase(ABC):
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase, record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase, download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStore, metadata_store: ModelMetadataStoreBase,
event_bus: Optional["EventServiceBase"] = None, event_bus: Optional["EventServiceBase"] = None,
): ):
""" """

View File

@ -20,7 +20,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
@ -33,7 +33,6 @@ from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata, AnyModelRepoMetadata,
CivitaiMetadataFetch, CivitaiMetadataFetch,
HuggingFaceMetadataFetch, HuggingFaceMetadataFetch,
ModelMetadataStore,
ModelMetadataWithFiles, ModelMetadataWithFiles,
RemoteModelFile, RemoteModelFile,
) )
@ -65,7 +64,6 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase, record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase, download_queue: DownloadQueueServiceBase,
metadata_store: Optional[ModelMetadataStore] = None,
event_bus: Optional[EventServiceBase] = None, event_bus: Optional[EventServiceBase] = None,
session: Optional[Session] = None, session: Optional[Session] = None,
): ):
@ -93,14 +91,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._running = False self._running = False
self._session = session self._session = session
self._next_job_id = 0 self._next_job_id = 0
# There may not necessarily be a metadata store initialized self._metadata_store = record_store.metadata_store # for convenience
# so we create one and initialize it with the same sql database
# used by the record store service.
if metadata_store:
self._metadata_store = metadata_store
else:
assert isinstance(record_store, ModelRecordServiceSQL)
self._metadata_store = ModelMetadataStore(record_store.db)
@property @property
def app_config(self) -> InvokeAIAppConfig: # noqa D102 def app_config(self) -> InvokeAIAppConfig: # noqa D102
@ -259,7 +250,7 @@ class ModelInstallService(ModelInstallServiceBase):
"""Block until all installation jobs are done.""" """Block until all installation jobs are done."""
start = time.time() start = time.time()
while len(self._download_cache) > 0: while len(self._download_cache) > 0:
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
self._downloads_changed_event.clear() self._downloads_changed_event.clear()
if timeout > 0 and time.time() - start > timeout: if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded") raise TimeoutError("Timeout exceeded")

View File

@ -5,7 +5,6 @@ from typing_extensions import Self
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig from ..config import InvokeAIAppConfig
@ -13,8 +12,7 @@ from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL from ..model_records import ModelRecordServiceBase
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_manager_base import ModelManagerServiceBase from .model_manager_base import ModelManagerServiceBase
@ -64,7 +62,7 @@ class ModelManagerService(ModelManagerServiceBase):
def build_model_manager( def build_model_manager(
cls, cls,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
db: SqliteDatabase, model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase, download_queue: DownloadQueueServiceBase,
events: EventServiceBase, events: EventServiceBase,
) -> Self: ) -> Self:
@ -82,19 +80,16 @@ class ModelManagerService(ModelManagerServiceBase):
convert_cache = ModelConvertCache( convert_cache = ModelConvertCache(
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
) )
record_store = ModelRecordServiceSQL(db=db)
loader = ModelLoadService( loader = ModelLoadService(
app_config=app_config, app_config=app_config,
record_store=record_store, record_store=model_record_service,
ram_cache=ram_cache, ram_cache=ram_cache,
convert_cache=convert_cache, convert_cache=convert_cache,
) )
record_store._loader = loader # yeah, there is a circular reference here
installer = ModelInstallService( installer = ModelInstallService(
app_config=app_config, app_config=app_config,
record_store=record_store, record_store=model_record_service,
download_queue=download_queue, download_queue=download_queue,
metadata_store=ModelMetadataStore(db=db),
event_bus=events, event_bus=events,
) )
return cls(store=record_store, install=installer, load=loader) return cls(store=model_record_service, install=installer, load=loader)

View File

@ -0,0 +1,9 @@
"""Init file for ModelMetadataStoreService module."""
from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL
__all__ = [
"ModelMetadataStoreBase",
"ModelMetadataStoreSQL",
]

View File

@ -0,0 +1,65 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""
@abstractmethod
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
@abstractmethod
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
@abstractmethod
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
@abstractmethod
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
@abstractmethod
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""

View File

@ -0,0 +1,222 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
from .metadata_store_base import ModelMetadataStoreBase
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -17,7 +17,9 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class DuplicateModelException(Exception): class DuplicateModelException(Exception):
@ -109,7 +111,7 @@ class ModelRecordServiceBase(ABC):
@property @property
@abstractmethod @abstractmethod
def metadata_store(self) -> ModelMetadataStore: def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database.""" """Return a ModelMetadataStore initialized on the same database."""
pass pass

View File

@ -54,8 +54,9 @@ from invokeai.backend.model_manager.config import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import ( from .model_records_base import (
DuplicateModelException, DuplicateModelException,
@ -69,7 +70,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase): def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
""" """
Initialize a new object from preexisting sqlite3 connection and threading lock objects. Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -78,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__() super().__init__()
self._db = db self._db = db
self._cursor = db.conn.cursor() self._cursor = db.conn.cursor()
self._metadata_store = metadata_store
@property @property
def db(self) -> SqliteDatabase: def db(self) -> SqliteDatabase:
@ -157,7 +159,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback() self._db.conn.rollback()
raise e raise e
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
""" """
Update the model, returning the updated version. Update the model, returning the updated version.
@ -307,9 +309,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return results return results
@property @property
def metadata_store(self) -> ModelMetadataStore: def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database.""" """Return a ModelMetadataStore initialized on the same database."""
return ModelMetadataStore(self._db) return self._metadata_store
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
""" """
@ -330,18 +332,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param tags: Set of tags to search for. All tags must be present. :param tags: Set of tags to search for. All tags must be present.
""" """
store = ModelMetadataStore(self._db) store = ModelMetadataStoreSQL(self._db)
keys = store.search_by_tag(tags) keys = store.search_by_tag(tags)
return [self.get_model(x) for x in keys] return [self.get_model(x) for x in keys]
def list_tags(self) -> Set[str]: def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database.""" """Return a unique set of all the model tags in the metadata database."""
store = ModelMetadataStore(self._db) store = ModelMetadataStoreSQL(self._db)
return store.list_tags() return store.list_tags()
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it.""" """List metadata for all models that have it."""
store = ModelMetadataStore(self._db) store = ModelMetadataStoreSQL(self._db)
return store.list_all_metadata() return store.list_all_metadata()
def list_models( def list_models(

View File

@ -25,6 +25,7 @@ from invokeai.app.services.model_install import (
ModelSource, ModelSource,
URLModelSource, URLModelSource,
) )
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
@ -45,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
logger = InvokeAILogger.get_logger(config=app_config) logger = InvokeAILogger.get_logger(config=app_config)
image_files = DiskImageFileStorage(f"{app_config.output_path}/images") image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
db = init_db(config=app_config, logger=logger, image_files=image_files) db = init_db(config=app_config, logger=logger, image_files=image_files)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db) obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
return obj return obj
@ -54,12 +55,10 @@ def initialize_installer(
) -> ModelInstallServiceBase: ) -> ModelInstallServiceBase:
"""Return an initialized ModelInstallService object.""" """Return an initialized ModelInstallService object."""
record_store = initialize_record_store(app_config) record_store = initialize_record_store(app_config)
metadata_store = record_store.metadata_store
download_queue = DownloadQueueService() download_queue = DownloadQueueService()
installer = ModelInstallService( installer = ModelInstallService(
app_config=app_config, app_config=app_config,
record_store=record_store, record_store=record_store,
metadata_store=metadata_store,
download_queue=download_queue, download_queue=download_queue,
event_bus=event_bus, event_bus=event_bus,
) )
@ -287,14 +286,14 @@ class InstallHelper(object):
model_name=model_name, model_name=model_name,
) )
if len(matches) > 1: if len(matches) > 1:
print( self._logger.error(
f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate." "{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate"
) )
elif not matches: elif not matches:
print(f"{model_to_remove}: unknown model") self._logger.error(f"{model_to_remove}: unknown model")
else: else:
for m in matches: for m in matches:
print(f"Deleting {m.type}:{m.name}") self._logger.info(f"Deleting {m.type}:{m.name}")
installer.delete(m.key) installer.delete(m.key)
installer.wait_for_installs() installer.wait_for_installs()

View File

@ -4,10 +4,6 @@ Init file for the model loader.
""" """
from importlib import import_module from importlib import import_module
from pathlib import Path from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .convert_cache.convert_cache_default import ModelConvertCache from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import AnyModelLoader, LoadedModel from .load_base import AnyModelLoader, LoadedModel
@ -19,16 +15,3 @@ for module in loaders:
import_module(f"{__package__}.model_loaders.{module}") import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"] __all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
app_config = app_config or InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=app_config)
return AnyModelLoader(
app_config=app_config,
logger=logger,
ram_cache=ModelCache(
logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size
),
convert_cache=ModelConvertCache(app_config.models_convert_cache_path),
)

View File

@ -39,21 +39,21 @@ class LoadedModel:
"""Context manager object that mediates transfer from RAM<->VRAM.""" """Context manager object that mediates transfer from RAM<->VRAM."""
config: AnyModelConfig config: AnyModelConfig
locker: ModelLockerBase _locker: ModelLockerBase
def __enter__(self) -> AnyModel: def __enter__(self) -> AnyModel:
"""Context entry.""" """Context entry."""
self.locker.lock() self._locker.lock()
return self.model return self.model
def __exit__(self, *args: Any, **kwargs: Any) -> None: def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit.""" """Context exit."""
self.locker.unlock() self._locker.unlock()
@property @property
def model(self) -> AnyModel: def model(self) -> AnyModel:
"""Return the model without locking it.""" """Return the model without locking it."""
return self.locker.model return self._locker.model
class ModelLoaderBase(ABC): class ModelLoaderBase(ABC):

View File

@ -75,7 +75,7 @@ class ModelLoader(ModelLoaderBase):
model_path = self._convert_if_needed(model_config, model_path, submodel_type) model_path = self._convert_if_needed(model_config, model_path, submodel_type)
locker = self._load_if_needed(model_config, model_path, submodel_type) locker = self._load_if_needed(model_config, model_path, submodel_type)
return LoadedModel(config=model_config, locker=locker) return LoadedModel(config=model_config, _locker=locker)
def _get_model_path( def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None

View File

@ -39,10 +39,7 @@ class ModelMerger(object):
def __init__(self, installer: ModelInstallServiceBase): def __init__(self, installer: ModelInstallServiceBase):
""" """
Initialize a ModelMerger object. Initialize a ModelMerger object with the model installer.
:param store: Underlying storage manager for the running process.
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
""" """
self._installer = installer self._installer = installer

View File

@ -18,7 +18,7 @@ assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use: if data.allow_commercial_use:
print("Commercial use of this model is allowed") print("Commercial use of this model is allowed")
""" """
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
from .metadata_base import ( from .metadata_base import (
AnyModelRepoMetadata, AnyModelRepoMetadata,
AnyModelRepoMetadataValidator, AnyModelRepoMetadataValidator,
@ -31,7 +31,6 @@ from .metadata_base import (
RemoteModelFile, RemoteModelFile,
UnknownMetadataException, UnknownMetadataException,
) )
from .metadata_store import ModelMetadataStore
__all__ = [ __all__ = [
"AnyModelRepoMetadata", "AnyModelRepoMetadata",
@ -42,7 +41,7 @@ __all__ = [
"HuggingFaceMetadata", "HuggingFaceMetadata",
"HuggingFaceMetadataFetch", "HuggingFaceMetadataFetch",
"LicenseRestrictions", "LicenseRestrictions",
"ModelMetadataStore", "ModelMetadataFetchBase",
"BaseMetadata", "BaseMetadata",
"ModelMetadataWithFiles", "ModelMetadataWithFiles",
"RemoteModelFile", "RemoteModelFile",

View File

@ -6,20 +6,40 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
""" """
import argparse import argparse
import curses import curses
import re
import sys import sys
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from typing import List from typing import List, Optional, Tuple
import npyscreen import npyscreen
from npyscreen import widget from npyscreen import widget
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
BaseModelType,
ModelFormat,
ModelType,
ModelVariantType,
)
from invokeai.backend.model_manager.merge import ModelMerger
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger()
BASE_TYPES = [
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
]
def _parse_args() -> Namespace: def _parse_args() -> Namespace:
@ -48,7 +68,7 @@ def _parse_args() -> Namespace:
parser.add_argument( parser.add_argument(
"--base_model", "--base_model",
type=str, type=str,
choices=[x.value for x in BaseModelType], choices=[x[0].value for x in BASE_TYPES],
help="The base model shared by the models to be merged", help="The base model shared by the models to be merged",
) )
parser.add_argument( parser.add_argument(
@ -98,17 +118,17 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
super().__init__(parentApp, name) super().__init__(parentApp, name)
@property @property
def model_manager(self): def record_store(self):
return self.parentApp.model_manager return self.parentApp.record_store
def afterEditing(self): def afterEditing(self):
self.parentApp.setNextForm(None) self.parentApp.setNextForm(None)
def create(self): def create(self):
window_height, window_width = curses.initscr().getmaxyx() window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
self.current_base = 0 self.current_base = 0
self.models = self.get_models(BASE_TYPES[self.current_base][0])
self.model_names = [x[1] for x in self.models]
max_width = max([len(x) for x in self.model_names]) max_width = max([len(x) for x in self.model_names])
max_width += 6 max_width += 6
horizontal_layout = max_width * 3 < window_width horizontal_layout = max_width * 3 < window_width
@ -128,11 +148,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.nextrely += 1 self.nextrely += 1
self.base_select = self.add_widget_intelligent( self.base_select = self.add_widget_intelligent(
SingleSelectColumns, SingleSelectColumns,
values=[ values=[x[1] for x in BASE_TYPES],
"Models Built on SD-1.x",
"Models Built on SD-2.x",
"Models Built on SDXL",
],
value=[self.current_base], value=[self.current_base],
columns=4, columns=4,
max_height=2, max_height=2,
@ -263,21 +279,20 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
sys.exit(0) sys.exit(0)
def marshall_arguments(self) -> dict: def marshall_arguments(self) -> dict:
model_names = self.model_names model_keys = [x[0] for x in self.models]
models = [ models = [
model_names[self.model1.value[0]], model_keys[self.model1.value[0]],
model_names[self.model2.value[0]], model_keys[self.model2.value[0]],
] ]
if self.model3.value[0] > 0: if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0] - 1]) models.append(model_keys[self.model3.value[0] - 1])
interp = "add_difference" interp = "add_difference"
else: else:
interp = self.interpolations[self.merge_method.value[0]] interp = self.interpolations[self.merge_method.value[0]]
bases = ["sd-1", "sd-2", "sdxl"]
args = { args = {
"model_names": models, "model_keys": models,
"base_model": BaseModelType(bases[self.base_select.value[0]]), "base_model": tuple(BaseModelType)[self.base_select.value[0]],
"alpha": self.alpha.value, "alpha": self.alpha.value,
"interp": interp, "interp": interp,
"force": self.force.value, "force": self.force.value,
@ -311,18 +326,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else: else:
return True return True
def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]: def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
model_names = [ models = [
info["model_name"] (x.key, x.name)
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers" if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
] ]
return sorted(model_names) return sorted(models, key=lambda x: x[1])
def _populate_models(self, value=None): def _populate_models(self, value: List[int]):
bases = ["sd-1", "sd-2", "sdxl"] base_model = BASE_TYPES[value[0]][0]
base_model = BaseModelType(bases[value[0]]) self.models = self.get_models(base_model)
self.model_names = self.get_model_names(base_model) self.model_names = [x[1] for x in self.models]
models_plus_none = self.model_names.copy() models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None") models_plus_none.insert(0, "None")
@ -334,24 +349,24 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
class Mergeapp(npyscreen.NPSAppManaged): class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self, model_manager: ModelManager): def __init__(self, record_store: ModelRecordServiceBase):
super().__init__() super().__init__()
self.model_manager = model_manager self.record_store = record_store
def onStart(self): def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme) npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
def run_gui(args: Namespace): def run_gui(args: Namespace) -> None:
model_manager = ModelManager(config.model_conf_path) record_store: ModelRecordServiceBase = get_config_store()
mergeapp = Mergeapp(model_manager) mergeapp = Mergeapp(record_store)
mergeapp.run() mergeapp.run()
args = mergeapp.merge_arguments args = mergeapp.merge_arguments
merger = ModelMerger(model_manager) merger = get_model_merger(record_store)
merger.merge_diffusion_models_and_save(**args) merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".') merged_model_name = args["merged_model_name"]
logger.info(f'Models merged into new model: "{merged_model_name}".')
def run_cli(args: Namespace): def run_cli(args: Namespace):
@ -364,20 +379,54 @@ def run_cli(args: Namespace):
args.merged_model_name = "+".join(args.model_names) args.merged_model_name = "+".join(args.model_names)
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
model_manager = ModelManager(config.model_conf_path) record_store: ModelRecordServiceBase = get_config_store()
assert ( assert (
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merger = ModelMerger(model_manager) merger = get_model_merger(record_store)
merger.merge_diffusion_models_and_save(**vars(args)) model_keys = []
for name in args.model_names:
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
model_keys.append(name)
else:
models = record_store.search_by_attr(
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
)
assert len(models) > 0, f"{name}: Unknown model"
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
model_keys.append(models[0].key)
merger.merge_diffusion_models_and_save(
alpha=args.alpha,
model_keys=model_keys,
merged_model_name=args.merged_model_name,
interp=args.interp,
force=args.force,
)
logger.info(f'Models merged into new model: "{args.merged_model_name}".') logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def get_config_store() -> ModelRecordServiceSQL:
output_path = config.output_path
assert output_path is not None
image_files = DiskImageFileStorage(output_path / "images")
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService())
installer.start()
return ModelMerger(installer)
def main(): def main():
args = _parse_args() args = _parse_args()
if args.root_dir: if args.root_dir:
config.parse_args(["--root", str(args.root_dir)]) config.parse_args(["--root", str(args.root_dir)])
else:
config.parse_args([])
try: try:
if args.front_end: if args.front_end:

View File

@ -1,4 +1,5 @@
import logging import logging
from unittest.mock import Mock
import pytest import pytest
@ -65,7 +66,7 @@ def mock_services() -> InvocationServices:
invocation_cache=MemoryInvocationCache(max_cache_size=0), invocation_cache=MemoryInvocationCache(max_cache_size=0),
latents=None, # type: ignore latents=None, # type: ignore
logger=logging, # type: ignore logger=logging, # type: ignore
model_manager=None, # type: ignore model_manager=Mock(), # type: ignore
download_queue=None, # type: ignore download_queue=None, # type: ignore
names=None, # type: ignore names=None, # type: ignore
performance_statistics=InvocationStatsService(), performance_statistics=InvocationStatsService(),

View File

@ -8,6 +8,7 @@ from typing import Any
import pytest import pytest
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
ModelRecordOrderBy, ModelRecordOrderBy,
@ -36,7 +37,7 @@ def store(
config = InvokeAIAppConfig(root=datadir) config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger) db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db) return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def example_config() -> TextualInversionConfig: def example_config() -> TextualInversionConfig:

View File

@ -14,6 +14,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceSQL from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
@ -21,7 +22,6 @@ from invokeai.backend.model_manager.config import (
ModelType, ModelType,
) )
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import ( from tests.backend.model_manager_2.model_metadata.metadata_examples import (
RepoCivitaiModelMetadata1, RepoCivitaiModelMetadata1,
@ -104,7 +104,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordS
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
logger = InvokeAILogger.get_logger(config=mm2_app_config) logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db) store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
# add five simple config records to the database # add five simple config records to the database
raw1 = { raw1 = {
"path": "/tmp/foo1", "path": "/tmp/foo1",
@ -163,15 +163,14 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
@pytest.fixture @pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore: def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
db = mm2_record_store._db # to ensure we are sharing the same database return mm2_record_store.metadata_store
return ModelMetadataStore(db)
@pytest.fixture @pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
"""This fixtures defines a series of mock URLs for testing download and installation.""" """This fixtures defines a series of mock URLs for testing download and installation."""
sess = TestSession() sess: Session = TestSession()
sess.mount( sess.mount(
"https://test.com/missing_model.safetensors", "https://test.com/missing_model.safetensors",
TestAdapter( TestAdapter(
@ -258,8 +257,7 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService() events = DummyEventService()
store = ModelRecordServiceSQL(db) store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
metadata_store = ModelMetadataStore(db)
download_queue = DownloadQueueService(requests_session=mm2_session) download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start() download_queue.start()
@ -268,7 +266,6 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo
app_config=mm2_app_config, app_config=mm2_app_config,
record_store=store, record_store=store,
download_queue=download_queue, download_queue=download_queue,
metadata_store=metadata_store,
event_bus=events, event_bus=events,
session=mm2_session, session=mm2_session,
) )

View File

@ -8,6 +8,7 @@ import pytest
from pydantic.networks import HttpUrl from pydantic.networks import HttpUrl
from requests.sessions import Session from requests.sessions import Session
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
from invokeai.backend.model_manager.config import ModelRepoVariant from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.metadata import ( from invokeai.backend.model_manager.metadata import (
CivitaiMetadata, CivitaiMetadata,
@ -15,14 +16,13 @@ from invokeai.backend.model_manager.metadata import (
CommercialUsage, CommercialUsage,
HuggingFaceMetadata, HuggingFaceMetadata,
HuggingFaceMetadataFetch, HuggingFaceMetadataFetch,
ModelMetadataStore,
UnknownMetadataException, UnknownMetadataException,
) )
from invokeai.backend.model_manager.util import select_hf_files from invokeai.backend.model_manager.util import select_hf_files
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None: def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
tags = {"text-to-image", "diffusers"} tags = {"text-to-image", "diffusers"}
input_metadata = HuggingFaceMetadata( input_metadata = HuggingFaceMetadata(
name="sdxl-vae", name="sdxl-vae",
@ -40,7 +40,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
assert mm2_metadata_store.list_tags() == tags assert mm2_metadata_store.list_tags() == tags
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None: def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
input_metadata = HuggingFaceMetadata( input_metadata = HuggingFaceMetadata(
name="sdxl-vae", name="sdxl-vae",
author="stabilityai", author="stabilityai",
@ -57,7 +57,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
assert input_metadata == output_metadata assert input_metadata == output_metadata
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None: def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
metadata1 = HuggingFaceMetadata( metadata1 = HuggingFaceMetadata(
name="sdxl-vae", name="sdxl-vae",
author="stabilityai", author="stabilityai",