diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 378961a055..8e79b26e2d 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -28,6 +28,8 @@ from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker from ..services.model_manager.model_manager_default import ModelManagerService +from ..services.model_metadata import ModelMetadataStoreSQL +from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue @@ -94,8 +96,12 @@ class ApiDependencies: ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) download_queue_service = DownloadQueueService(event_bus=events) + model_metadata_service = ModelMetadataStoreSQL(db=db) model_manager = ModelManagerService.build_model_manager( - app_config=configuration, db=db, download_queue=download_queue_service, events=events + app_config=configuration, + model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service), + download_queue=download_queue_service, + events=events, ) names = SimpleNameService() performance_statistics = InvocationStatsService() diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 6d5cedbcad..50cac80d09 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -194,7 +194,7 @@ class DownloadQueueService(DownloadQueueServiceBase): """Block until the indicated job has reached terminal state, or when timeout limit reached.""" start = time.time() while not job.in_terminal_state: - if self._job_completed_event.wait(timeout=5): # in case we miss an event + if self._job_completed_event.wait(timeout=0.25): # in case we miss an event self._job_completed_event.clear() if timeout > 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 6c893021de..486a1ca5b3 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -46,8 +46,6 @@ class InvocationStatsService(InvocationStatsServiceBase): # This is to handle case of the model manager not being initialized, which happens # during some tests. services = self._invoker.services - if services.model_manager is None or services.model_manager.load is None: - yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. 
self._stats[graph_execution_state_id] = GraphExecutionStats() diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 39ea8c4a0d..2f03db0af7 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -18,7 +18,9 @@ from invokeai.app.services.events import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..model_metadata import ModelMetadataStoreBase class InstallStatus(str, Enum): @@ -243,7 +245,7 @@ class ModelInstallServiceBase(ABC): app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: ModelMetadataStore, + metadata_store: ModelMetadataStoreBase, event_bus: Optional["EventServiceBase"] = None, ): """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 20a85a82a1..7dee8bfd8c 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -20,7 +20,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker -from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL +from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, @@ -33,7 +33,6 @@ from invokeai.backend.model_manager.metadata import ( AnyModelRepoMetadata, CivitaiMetadataFetch, HuggingFaceMetadataFetch, - ModelMetadataStore, ModelMetadataWithFiles, RemoteModelFile, ) @@ -65,7 +64,6 @@ class ModelInstallService(ModelInstallServiceBase): app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: Optional[ModelMetadataStore] = None, event_bus: Optional[EventServiceBase] = None, session: Optional[Session] = None, ): @@ -93,14 +91,7 @@ class ModelInstallService(ModelInstallServiceBase): self._running = False self._session = session self._next_job_id = 0 - # There may not necessarily be a metadata store initialized - # so we create one and initialize it with the same sql database - # used by the record store service. 
- if metadata_store: - self._metadata_store = metadata_store - else: - assert isinstance(record_store, ModelRecordServiceSQL) - self._metadata_store = ModelMetadataStore(record_store.db) + self._metadata_store = record_store.metadata_store # for convenience @property def app_config(self) -> InvokeAIAppConfig: # noqa D102 @@ -259,7 +250,7 @@ class ModelInstallService(ModelInstallServiceBase): """Block until all installation jobs are done.""" start = time.time() while len(self._download_cache) > 0: - if self._downloads_changed_event.wait(timeout=5): # in case we miss an event + if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event self._downloads_changed_event.clear() if timeout > 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 028d4af615..b96341be69 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -5,7 +5,6 @@ from typing_extensions import Self from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache -from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -13,8 +12,7 @@ from ..download import DownloadQueueServiceBase from ..events.events_base import EventServiceBase from ..model_install import ModelInstallService, ModelInstallServiceBase from ..model_load import ModelLoadService, ModelLoadServiceBase -from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL -from ..shared.sqlite.sqlite_database import SqliteDatabase +from ..model_records import ModelRecordServiceBase from .model_manager_base import ModelManagerServiceBase @@ -64,7 +62,7 @@ class ModelManagerService(ModelManagerServiceBase): def build_model_manager( cls, app_config: InvokeAIAppConfig, - db: SqliteDatabase, + model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, ) -> Self: @@ -82,19 +80,16 @@ class ModelManagerService(ModelManagerServiceBase): convert_cache = ModelConvertCache( cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size ) - record_store = ModelRecordServiceSQL(db=db) loader = ModelLoadService( app_config=app_config, - record_store=record_store, + record_store=model_record_service, ram_cache=ram_cache, convert_cache=convert_cache, ) - record_store._loader = loader # yeah, there is a circular reference here installer = ModelInstallService( app_config=app_config, - record_store=record_store, + record_store=model_record_service, download_queue=download_queue, - metadata_store=ModelMetadataStore(db=db), event_bus=events, ) - return cls(store=record_store, install=installer, load=loader) + return cls(store=model_record_service, install=installer, load=loader) diff --git a/invokeai/app/services/model_metadata/__init__.py b/invokeai/app/services/model_metadata/__init__.py new file mode 100644 index 0000000000..981c96b709 --- /dev/null +++ b/invokeai/app/services/model_metadata/__init__.py @@ -0,0 +1,9 @@ +"""Init file for ModelMetadataStoreService module.""" + +from .metadata_store_base import ModelMetadataStoreBase +from .metadata_store_sql import ModelMetadataStoreSQL + +__all__ = [ + "ModelMetadataStoreBase", + "ModelMetadataStoreSQL", +] 
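The hunks above rework how the model-manager stack is wired: the metadata store is now its own service that is injected into the record service, and `ModelManagerService.build_model_manager()` takes the record service instead of a raw database handle. A minimal sketch of the construction order implied by `dependencies.py` and `model_manager_default.py` (illustrative only; `db`, `configuration`, `download_queue_service` and `events` are assumed to already exist, as in `ApiDependencies.initialize()`):

```python
# Illustrative wiring only -- mirrors the order used in ApiDependencies above.
# `db`, `configuration`, `download_queue_service` and `events` are assumed
# to be initialized elsewhere.
model_metadata_service = ModelMetadataStoreSQL(db=db)
model_record_service = ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service)
model_manager = ModelManagerService.build_model_manager(
    app_config=configuration,
    model_record_service=model_record_service,
    download_queue=download_queue_service,
    events=events,
)
```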
diff --git a/invokeai/app/services/model_metadata/metadata_store_base.py b/invokeai/app/services/model_metadata/metadata_store_base.py new file mode 100644 index 0000000000..e0e4381b09 --- /dev/null +++ b/invokeai/app/services/model_metadata/metadata_store_base.py @@ -0,0 +1,65 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +Storage for Model Metadata +""" + +from abc import ABC, abstractmethod +from typing import List, Set, Tuple + +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + + +class ModelMetadataStoreBase(ABC): + """Store, search and fetch model metadata retrieved from remote repositories.""" + + @abstractmethod + def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: + """ + Add a block of repo metadata to a model record. + + The model record config must already exist in the database with the + same key. Otherwise a FOREIGN KEY constraint exception will be raised. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to store + """ + + @abstractmethod + def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: + """Retrieve the ModelRepoMetadata corresponding to model key.""" + + @abstractmethod + def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata + """Dump out all the metadata.""" + + @abstractmethod + def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: + """ + Update metadata corresponding to the model with the indicated key. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to update + """ + + @abstractmethod + def list_tags(self) -> Set[str]: + """Return all tags in the tags table.""" + + @abstractmethod + def search_by_tag(self, tags: Set[str]) -> Set[str]: + """Return the keys of models containing all of the listed tags.""" + + @abstractmethod + def search_by_author(self, author: str) -> Set[str]: + """Return the keys of models authored by the indicated author.""" + + @abstractmethod + def search_by_name(self, name: str) -> Set[str]: + """ + Return the keys of models with the indicated name. + + Note that this is the name of the model given to it by + the remote source. The user may have changed the local + name. The local name will be located in the model config + record object. + """ diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py new file mode 100644 index 0000000000..afe9d2c8c6 --- /dev/null +++ b/invokeai/app/services/model_metadata/metadata_store_sql.py @@ -0,0 +1,222 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +SQL Storage for Model Metadata +""" + +import sqlite3 +from typing import List, Optional, Set, Tuple + +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException +from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase + +from .metadata_store_base import ModelMetadataStoreBase + + +class ModelMetadataStoreSQL(ModelMetadataStoreBase): + """Store, search and fetch model metadata retrieved from remote repositories.""" + + def __init__(self, db: SqliteDatabase): + """ + Initialize a new object from preexisting sqlite3 connection and threading lock objects. 
+ + :param conn: sqlite3 connection object + :param lock: threading Lock object + """ + super().__init__() + self._db = db + self._cursor = self._db.conn.cursor() + + def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: + """ + Add a block of repo metadata to a model record. + + The model record config must already exist in the database with the + same key. Otherwise a FOREIGN KEY constraint exception will be raised. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to store + """ + json_serialized = metadata.model_dump_json() + with self._db.lock: + try: + self._cursor.execute( + """--sql + INSERT INTO model_metadata( + id, + metadata + ) + VALUES (?,?); + """, + ( + model_key, + json_serialized, + ), + ) + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table + self._db.conn.rollback() + raise UnknownMetadataException from excp + except sqlite3.Error as excp: + self._db.conn.rollback() + raise excp + + def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: + """Retrieve the ModelRepoMetadata corresponding to model key.""" + with self._db.lock: + self._cursor.execute( + """--sql + SELECT metadata FROM model_metadata + WHERE id=?; + """, + (model_key,), + ) + rows = self._cursor.fetchone() + if not rows: + raise UnknownMetadataException("model metadata not found") + return ModelMetadataFetchBase.from_json(rows[0]) + + def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata + """Dump out all the metadata.""" + with self._db.lock: + self._cursor.execute( + """--sql + SELECT id,metadata FROM model_metadata; + """, + (), + ) + rows = self._cursor.fetchall() + return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows] + + def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: + """ + Update metadata corresponding to the model with the indicated key. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to update + """ + json_serialized = metadata.model_dump_json() # turn it into a json string. + with self._db.lock: + try: + self._cursor.execute( + """--sql + UPDATE model_metadata + SET + metadata=? 
+ WHERE id=?; + """, + (json_serialized, model_key), + ) + if self._cursor.rowcount == 0: + raise UnknownMetadataException("model metadata not found") + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.Error as e: + self._db.conn.rollback() + raise e + + return self.get_metadata(model_key) + + def list_tags(self) -> Set[str]: + """Return all tags in the tags table.""" + self._cursor.execute( + """--sql + select tag_text from tags; + """ + ) + return {x[0] for x in self._cursor.fetchall()} + + def search_by_tag(self, tags: Set[str]) -> Set[str]: + """Return the keys of models containing all of the listed tags.""" + with self._db.lock: + try: + matches: Optional[Set[str]] = None + for tag in tags: + self._cursor.execute( + """--sql + SELECT a.model_id FROM model_tags AS a, + tags AS b + WHERE a.tag_id=b.tag_id + AND b.tag_text=?; + """, + (tag,), + ) + model_keys = {x[0] for x in self._cursor.fetchall()} + if matches is None: + matches = model_keys + matches = matches.intersection(model_keys) + except sqlite3.Error as e: + raise e + return matches if matches else set() + + def search_by_author(self, author: str) -> Set[str]: + """Return the keys of models authored by the indicated author.""" + self._cursor.execute( + """--sql + SELECT id FROM model_metadata + WHERE author=?; + """, + (author,), + ) + return {x[0] for x in self._cursor.fetchall()} + + def search_by_name(self, name: str) -> Set[str]: + """ + Return the keys of models with the indicated name. + + Note that this is the name of the model given to it by + the remote source. The user may have changed the local + name. The local name will be located in the model config + record object. + """ + self._cursor.execute( + """--sql + SELECT id FROM model_metadata + WHERE name=?; + """, + (name,), + ) + return {x[0] for x in self._cursor.fetchall()} + + def _update_tags(self, model_key: str, tags: Set[str]) -> None: + """Update tags for the model referenced by model_key.""" + # remove previous tags from this model + self._cursor.execute( + """--sql + DELETE FROM model_tags + WHERE model_id=?; + """, + (model_key,), + ) + + for tag in tags: + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO tags ( + tag_text + ) + VALUES (?); + """, + (tag,), + ) + self._cursor.execute( + """--sql + SELECT tag_id + FROM tags + WHERE tag_text = ? 
+ LIMIT 1; + """, + (tag,), + ) + tag_id = self._cursor.fetchone()[0] + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO model_tags ( + model_id, + tag_id + ) + VALUES (?,?); + """, + (model_key, tag_id), + ) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index b2eacc524b..d6014db448 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -17,7 +17,9 @@ from invokeai.backend.model_manager import ( ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..model_metadata import ModelMetadataStoreBase class DuplicateModelException(Exception): @@ -109,7 +111,7 @@ class ModelRecordServiceBase(ABC): @property @abstractmethod - def metadata_store(self) -> ModelMetadataStore: + def metadata_store(self) -> ModelMetadataStoreBase: """Return a ModelMetadataStore initialized on the same database.""" pass diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 84a1412383..dcd1114655 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -54,8 +54,9 @@ from invokeai.backend.model_manager.config import ( ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException +from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( DuplicateModelException, @@ -69,7 +70,7 @@ from .model_records_base import ( class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase): + def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. @@ -78,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): super().__init__() self._db = db self._cursor = db.conn.cursor() + self._metadata_store = metadata_store @property def db(self) -> SqliteDatabase: @@ -157,7 +159,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): self._db.conn.rollback() raise e - def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: + def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: """ Update the model, returning the updated version. @@ -307,9 +309,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): return results @property - def metadata_store(self) -> ModelMetadataStore: + def metadata_store(self) -> ModelMetadataStoreBase: """Return a ModelMetadataStore initialized on the same database.""" - return ModelMetadataStore(self._db) + return self._metadata_store def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: """ @@ -330,18 +332,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): :param tags: Set of tags to search for. All tags must be present. 
""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) keys = store.search_by_tag(tags) return [self.get_model(x) for x in keys] def list_tags(self) -> Set[str]: """Return a unique set of all the model tags in the metadata database.""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) return store.list_tags() def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: """List metadata for all models that have it.""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) return store.list_all_metadata() def list_models( diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 9c386c209c..3623b623a9 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -25,6 +25,7 @@ from invokeai.app.services.model_install import ( ModelSource, URLModelSource, ) +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager import ( @@ -45,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService logger = InvokeAILogger.get_logger(config=app_config) image_files = DiskImageFileStorage(f"{app_config.output_path}/images") db = init_db(config=app_config, logger=logger, image_files=image_files) - obj: ModelRecordServiceBase = ModelRecordServiceSQL(db) + obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) return obj @@ -54,12 +55,10 @@ def initialize_installer( ) -> ModelInstallServiceBase: """Return an initialized ModelInstallService object.""" record_store = initialize_record_store(app_config) - metadata_store = record_store.metadata_store download_queue = DownloadQueueService() installer = ModelInstallService( app_config=app_config, record_store=record_store, - metadata_store=metadata_store, download_queue=download_queue, event_bus=event_bus, ) @@ -287,14 +286,14 @@ class InstallHelper(object): model_name=model_name, ) if len(matches) > 1: - print( - f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate." + self._logger.error( + "{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate" ) elif not matches: - print(f"{model_to_remove}: unknown model") + self._logger.error(f"{model_to_remove}: unknown model") else: for m in matches: - print(f"Deleting {m.type}:{m.name}") + self._logger.info(f"Deleting {m.type}:{m.name}") installer.delete(m.key) installer.wait_for_installs() diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index 966a739237..a3a840b625 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -4,10 +4,6 @@ Init file for the model loader. 
""" from importlib import import_module from pathlib import Path -from typing import Optional - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger from .convert_cache.convert_cache_default import ModelConvertCache from .load_base import AnyModelLoader, LoadedModel @@ -19,16 +15,3 @@ for module in loaders: import_module(f"{__package__}.model_loaders.{module}") __all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"] - - -def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: - app_config = app_config or InvokeAIAppConfig.get_config() - logger = InvokeAILogger.get_logger(config=app_config) - return AnyModelLoader( - app_config=app_config, - logger=logger, - ram_cache=ModelCache( - logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size - ), - convert_cache=ModelConvertCache(app_config.models_convert_cache_path), - ) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 7649dee762..4c5e899aa3 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -39,21 +39,21 @@ class LoadedModel: """Context manager object that mediates transfer from RAM<->VRAM.""" config: AnyModelConfig - locker: ModelLockerBase + _locker: ModelLockerBase def __enter__(self) -> AnyModel: """Context entry.""" - self.locker.lock() + self._locker.lock() return self.model def __exit__(self, *args: Any, **kwargs: Any) -> None: """Context exit.""" - self.locker.unlock() + self._locker.unlock() @property def model(self) -> AnyModel: """Return the model without locking it.""" - return self.locker.model + return self._locker.model class ModelLoaderBase(ABC): diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 1dac121a30..79c9311de1 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -75,7 +75,7 @@ class ModelLoader(ModelLoaderBase): model_path = self._convert_if_needed(model_config, model_path, submodel_type) locker = self._load_if_needed(model_config, model_path, submodel_type) - return LoadedModel(config=model_config, locker=locker) + return LoadedModel(config=model_config, _locker=locker) def _get_model_path( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 2c94af4af3..108f1f0e6f 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -39,10 +39,7 @@ class ModelMerger(object): def __init__(self, installer: ModelInstallServiceBase): """ - Initialize a ModelMerger object. - - :param store: Underlying storage manager for the running process. - :param config: InvokeAIAppConfig object (if not provided, default will be selected). + Initialize a ModelMerger object with the model installer. 
""" self._installer = installer diff --git a/invokeai/backend/model_manager/metadata/__init__.py b/invokeai/backend/model_manager/metadata/__init__.py index 672e378c7f..a35e55f3d2 100644 --- a/invokeai/backend/model_manager/metadata/__init__.py +++ b/invokeai/backend/model_manager/metadata/__init__.py @@ -18,7 +18,7 @@ assert isinstance(data, CivitaiMetadata) if data.allow_commercial_use: print("Commercial use of this model is allowed") """ -from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch +from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase from .metadata_base import ( AnyModelRepoMetadata, AnyModelRepoMetadataValidator, @@ -31,7 +31,6 @@ from .metadata_base import ( RemoteModelFile, UnknownMetadataException, ) -from .metadata_store import ModelMetadataStore __all__ = [ "AnyModelRepoMetadata", @@ -42,7 +41,7 @@ __all__ = [ "HuggingFaceMetadata", "HuggingFaceMetadataFetch", "LicenseRestrictions", - "ModelMetadataStore", + "ModelMetadataFetchBase", "BaseMetadata", "ModelMetadataWithFiles", "RemoteModelFile", diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 92b98b52f9..5484040674 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -6,20 +6,40 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team """ import argparse import curses +import re import sys from argparse import Namespace from pathlib import Path -from typing import List +from typing import List, Optional, Tuple import npyscreen from npyscreen import widget -import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType +from invokeai.app.services.download import DownloadQueueService +from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage +from invokeai.app.services.model_install import ModelInstallService +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL +from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL +from invokeai.app.services.shared.sqlite.sqlite_util import init_db +from invokeai.backend.model_manager import ( + BaseModelType, + ModelFormat, + ModelType, + ModelVariantType, +) +from invokeai.backend.model_manager.merge import ModelMerger +from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox config = InvokeAIAppConfig.get_config() +logger = InvokeAILogger.get_logger() + +BASE_TYPES = [ + (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), + (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), + (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), +] def _parse_args() -> Namespace: @@ -48,7 +68,7 @@ def _parse_args() -> Namespace: parser.add_argument( "--base_model", type=str, - choices=[x.value for x in BaseModelType], + choices=[x[0].value for x in BASE_TYPES], help="The base model shared by the models to be merged", ) parser.add_argument( @@ -98,17 +118,17 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): super().__init__(parentApp, name) @property - def model_manager(self): - return self.parentApp.model_manager + def record_store(self): + return self.parentApp.record_store def afterEditing(self): self.parentApp.setNextForm(None) def create(self): window_height, window_width = 
curses.initscr().getmaxyx() - - self.model_names = self.get_model_names() self.current_base = 0 + self.models = self.get_models(BASE_TYPES[self.current_base][0]) + self.model_names = [x[1] for x in self.models] max_width = max([len(x) for x in self.model_names]) max_width += 6 horizontal_layout = max_width * 3 < window_width @@ -128,11 +148,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): self.nextrely += 1 self.base_select = self.add_widget_intelligent( SingleSelectColumns, - values=[ - "Models Built on SD-1.x", - "Models Built on SD-2.x", - "Models Built on SDXL", - ], + values=[x[1] for x in BASE_TYPES], value=[self.current_base], columns=4, max_height=2, @@ -263,21 +279,20 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): sys.exit(0) def marshall_arguments(self) -> dict: - model_names = self.model_names + model_keys = [x[0] for x in self.models] models = [ - model_names[self.model1.value[0]], - model_names[self.model2.value[0]], + model_keys[self.model1.value[0]], + model_keys[self.model2.value[0]], ] if self.model3.value[0] > 0: - models.append(model_names[self.model3.value[0] - 1]) + models.append(model_keys[self.model3.value[0] - 1]) interp = "add_difference" else: interp = self.interpolations[self.merge_method.value[0]] - bases = ["sd-1", "sd-2", "sdxl"] args = { - "model_names": models, - "base_model": BaseModelType(bases[self.base_select.value[0]]), + "model_keys": models, + "base_model": tuple(BaseModelType)[self.base_select.value[0]], "alpha": self.alpha.value, "interp": interp, "force": self.force.value, @@ -311,18 +326,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): else: return True - def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]: - model_names = [ - info["model_name"] - for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) - if info["model_format"] == "diffusers" + def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name + models = [ + (x.key, x.name) + for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model) + if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal") ] - return sorted(model_names) + return sorted(models, key=lambda x: x[1]) - def _populate_models(self, value=None): - bases = ["sd-1", "sd-2", "sdxl"] - base_model = BaseModelType(bases[value[0]]) - self.model_names = self.get_model_names(base_model) + def _populate_models(self, value: List[int]): + base_model = BASE_TYPES[value[0]][0] + self.models = self.get_models(base_model) + self.model_names = [x[1] for x in self.models] models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") @@ -334,24 +349,24 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): class Mergeapp(npyscreen.NPSAppManaged): - def __init__(self, model_manager: ModelManager): + def __init__(self, record_store: ModelRecordServiceBase): super().__init__() - self.model_manager = model_manager + self.record_store = record_store def onStart(self): npyscreen.setTheme(npyscreen.Themes.ElegantTheme) self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") -def run_gui(args: Namespace): - model_manager = ModelManager(config.model_conf_path) - mergeapp = Mergeapp(model_manager) +def run_gui(args: Namespace) -> None: + record_store: ModelRecordServiceBase = get_config_store() + mergeapp = Mergeapp(record_store) mergeapp.run() - args = mergeapp.merge_arguments - merger = 
ModelMerger(model_manager) + merger = get_model_merger(record_store) merger.merge_diffusion_models_and_save(**args) - logger.info(f'Models merged into new model: "{args["merged_model_name"]}".') + merged_model_name = args["merged_model_name"] + logger.info(f'Models merged into new model: "{merged_model_name}".') def run_cli(args: Namespace): @@ -364,20 +379,54 @@ args.merged_model_name = "+".join(args.model_names) logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') - model_manager = ModelManager(config.model_conf_path) + record_store: ModelRecordServiceBase = get_config_store() assert ( - not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber + len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' - merger = ModelMerger(model_manager) - merger.merge_diffusion_models_and_save(**vars(args)) + merger = get_model_merger(record_store) + model_keys = [] + for name in args.model_names: + if len(name) == 32 and re.match(r"^[0-9a-f]{32}$", name): + model_keys.append(name) + else: + models = record_store.search_by_attr( + model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model) + ) + assert len(models) > 0, f"{name}: Unknown model" + assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead." + model_keys.append(models[0].key) + + merger.merge_diffusion_models_and_save( + alpha=args.alpha, + model_keys=model_keys, + merged_model_name=args.merged_model_name, + interp=args.interp, + force=args.force, + ) logger.info(f'Models merged into new model: "{args.merged_model_name}".') +def get_config_store() -> ModelRecordServiceSQL: + output_path = config.output_path + assert output_path is not None + image_files = DiskImageFileStorage(output_path / "images") + db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files) + return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + + +def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger: + installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService()) + installer.start() + return ModelMerger(installer) + + def main(): args = _parse_args() if args.root_dir: config.parse_args(["--root", str(args.root_dir)]) + else: + config.parse_args([]) try: if args.front_end: diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 774f7501dc..f67b5a2ac5 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -1,4 +1,5 @@ import logging +from unittest.mock import Mock import pytest @@ -64,7 +65,7 @@ def mock_services() -> InvocationServices: images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), logger=logging, # type: ignore - model_manager=None, # type: ignore + model_manager=Mock(), # type: ignore download_queue=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 46afe0105b..852e1da979 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -8,6 +8,7 @@ from typing import Any import pytest
from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ( DuplicateModelException, ModelRecordOrderBy, @@ -36,7 +37,7 @@ def store( config = InvokeAIAppConfig(root=datadir) logger = InvokeAILogger.get_logger(config=config) db = create_mock_sqlite_database(config, logger) - return ModelRecordServiceSQL(db) + return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) def example_config() -> TextualInversionConfig: diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager_2/model_manager_2_fixtures.py index d85eab67dd..ebdc9cb5cd 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager_2/model_manager_2_fixtures.py @@ -14,6 +14,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadQueueService from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase +from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( BaseModelType, @@ -21,7 +22,6 @@ from invokeai.backend.model_manager.config import ( ModelType, ) from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache -from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager_2.model_metadata.metadata_examples import ( RepoCivitaiModelMetadata1, @@ -104,7 +104,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordS def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: logger = InvokeAILogger.get_logger(config=mm2_app_config) db = create_mock_sqlite_database(mm2_app_config, logger) - store = ModelRecordServiceSQL(db) + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) # add five simple config records to the database raw1 = { "path": "/tmp/foo1", @@ -163,15 +163,14 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL @pytest.fixture -def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore: - db = mm2_record_store._db # to ensure we are sharing the same database - return ModelMetadataStore(db) +def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: + return mm2_record_store.metadata_store @pytest.fixture def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: """This fixtures defines a series of mock URLs for testing download and installation.""" - sess = TestSession() + sess: Session = TestSession() sess.mount( "https://test.com/missing_model.safetensors", TestAdapter( @@ -258,8 +257,7 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) events = DummyEventService() - store = ModelRecordServiceSQL(db) - metadata_store = ModelMetadataStore(db) + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) download_queue = DownloadQueueService(requests_session=mm2_session) download_queue.start() @@ -268,7 +266,6 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: 
Session) -> Mo app_config=mm2_app_config, record_store=store, download_queue=download_queue, - metadata_store=metadata_store, event_bus=events, session=mm2_session, ) diff --git a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py b/tests/backend/model_manager_2/model_metadata/test_model_metadata.py index 5a2ec93767..f61eab1b5d 100644 --- a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py +++ b/tests/backend/model_manager_2/model_metadata/test_model_metadata.py @@ -8,6 +8,7 @@ import pytest from pydantic.networks import HttpUrl from requests.sessions import Session +from invokeai.app.services.model_metadata import ModelMetadataStoreBase from invokeai.backend.model_manager.config import ModelRepoVariant from invokeai.backend.model_manager.metadata import ( CivitaiMetadata, @@ -15,14 +16,13 @@ from invokeai.backend.model_manager.metadata import ( CommercialUsage, HuggingFaceMetadata, HuggingFaceMetadataFetch, - ModelMetadataStore, UnknownMetadataException, ) from invokeai.backend.model_manager.util import select_hf_files from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 -def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None: +def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None: tags = {"text-to-image", "diffusers"} input_metadata = HuggingFaceMetadata( name="sdxl-vae", @@ -40,7 +40,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None: assert mm2_metadata_store.list_tags() == tags -def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None: +def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None: input_metadata = HuggingFaceMetadata( name="sdxl-vae", author="stabilityai", @@ -57,7 +57,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None: assert input_metadata == output_metadata -def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None: +def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None: metadata1 = HuggingFaceMetadata( name="sdxl-vae", author="stabilityai",
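With this change the metadata store is reached through the record service rather than being constructed ad hoc from the database. A rough usage sketch of the relocated API (illustrative only; `record_store` is assumed to be a `ModelRecordServiceSQL` built as in `initialize_record_store()` above, `model_key` an existing key in the `model_config` table, and `metadata` any `AnyModelRepoMetadata` instance, e.g. one returned by a metadata fetcher):

```python
# Illustrative only -- the calls mirror the methods declared on ModelMetadataStoreBase.
store = record_store.metadata_store  # ModelMetadataStoreBase

store.add_metadata(model_key, metadata)  # raises UnknownMetadataException if the key is absent
fetched = store.get_metadata(model_key)  # AnyModelRepoMetadata
all_tags = store.list_tags()  # every tag in the tags table
keys = store.search_by_tag({"text-to-image", "diffusers"})  # keys carrying all listed tags

for key, meta in store.list_all_metadata():
    print(key, meta.name)
```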