diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 9c85f52776..03fbd5e004 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -10,7 +10,7 @@ from typing import List, Optional, Union from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType # should match the InvokeAI version when this is first released. -CONFIG_FILE_VERSION = "3.2" +CONFIG_FILE_VERSION = "3.2.0" class DuplicateModelException(Exception): diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 2bf3e3acf4..2d2b87c5a4 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -66,9 +66,8 @@ from .model_records_base import ( class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - _conn: sqlite3.Connection + _db: SqliteDatabase _cursor: sqlite3.Cursor - _lock: threading.Lock def __init__(self, db: SqliteDatabase): """ @@ -78,16 +77,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): :param lock: threading Lock object """ super().__init__() - self._conn = db.conn - self._lock = db.lock - self._conn.row_factory = sqlite3.Row - self._cursor = self._conn.cursor() + self._db = db + self._db.conn.row_factory = sqlite3.Row + self._cursor = self._db.conn.cursor() - with self._lock: + with self._db.lock: # Enable foreign keys - self._conn.execute("PRAGMA foreign_keys = ON;") + self._db.conn.execute("PRAGMA foreign_keys = ON;") self._create_tables() - self._conn.commit() + self._db.conn.commit() assert ( str(self.version) == CONFIG_FILE_VERSION ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}" @@ -172,7 +170,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """ record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect. json_serialized = json.dumps(record.model_dump()) # and turn it into a json string. - with self._lock: + with self._db.lock: try: self._cursor.execute( """--sql @@ -197,10 +195,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): json_serialized, ), ) - self._conn.commit() + self._db.conn.commit() except sqlite3.IntegrityError as e: - self._conn.rollback() + self._db.conn.rollback() if "UNIQUE constraint failed" in str(e): if "model_config.path" in str(e): msg = f"A model with path '{record.path}' is already installed" @@ -210,7 +208,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): else: raise e except sqlite3.Error as e: - self._conn.rollback() + self._db.conn.rollback() raise e return self.get_model(key) @@ -218,7 +216,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): @property def version(self) -> str: """Return the version of the database schema.""" - with self._lock: + with self._db.lock: self._cursor.execute( """--sql SELECT metadata_value FROM model_manager_metadata @@ -239,7 +237,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): Can raise an UnknownModelException """ - with self._lock: + with self._db.lock: try: self._cursor.execute( """--sql @@ -250,9 +248,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ) if self._cursor.rowcount == 0: raise UnknownModelException("model not found") - self._conn.commit() + self._db.conn.commit() except sqlite3.Error as e: - self._conn.rollback() + self._db.conn.rollback() raise e def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: @@ -265,7 +263,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """ record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect json_serialized = json.dumps(record.model_dump()) # and turn it into a json string. - with self._lock: + with self._db.lock: try: self._cursor.execute( """--sql @@ -281,9 +279,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ) if self._cursor.rowcount == 0: raise UnknownModelException("model not found") - self._conn.commit() + self._db.conn.commit() except sqlite3.Error as e: - self._conn.rollback() + self._db.conn.rollback() raise e return self.get_model(key) @@ -296,7 +294,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): Exceptions: UnknownModelException """ - with self._lock: + with self._db.lock: self._cursor.execute( """--sql SELECT config FROM model_config @@ -317,7 +315,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): :param key: Unique key for the model to be deleted """ count = 0 - with self._lock: + with self._db.lock: try: self._cursor.execute( """--sql @@ -360,7 +358,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): where_clause.append("type=?") bindings.append(model_type) where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" - with self._lock: + with self._db.lock: try: self._cursor.execute( f"""--sql @@ -377,7 +375,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]: """Return models with the indicated path.""" results = [] - with self._lock: + with self._db.lock: try: self._cursor.execute( """--sql @@ -394,7 +392,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): def search_by_hash(self, hash: str) -> List[ModelConfigBase]: """Return models with the indicated original_hash.""" results = [] - with self._lock: + with self._db.lock: try: self._cursor.execute( """--sql diff --git a/invokeai/backend/model_manager/migrate_to_db.py b/invokeai/backend/model_manager/migrate_to_db.py index 6cb7f478d1..7d4f460cd0 100644 --- a/invokeai/backend/model_manager/migrate_to_db.py +++ b/invokeai/backend/model_manager/migrate_to_db.py @@ -16,8 +16,19 @@ from invokeai.backend.util.logging import InvokeAILogger ModelsValidator = TypeAdapter(AnyModelConfig) -class Migrate: - """Migration class.""" +class MigrateModelYamlToDb: + """ + Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0) + + The class has one externally useful method, migrate(), which scans the + currently models.yaml file and imports all its entries into invokeai.db. + + Use this way: + + from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb + MigrateModelYamlToDb().migrate() + + """ config: InvokeAIAppConfig logger: InvokeAILogger @@ -28,14 +39,17 @@ class Migrate: self.logger = InvokeAILogger.get_logger() def get_db(self) -> ModelRecordServiceSQL: + """Fetch the sqlite3 database for this installation.""" db = SqliteDatabase(self.config, self.logger) return ModelRecordServiceSQL(db) def get_yaml(self) -> DictConfig: + """Fetch the models.yaml DictConfig for this installation.""" yaml_path = self.config.model_conf_path return OmegaConf.load(yaml_path) def migrate(self): + """Do the migration from models.yaml to invokeai.db.""" db = self.get_db() yaml = self.get_yaml() @@ -65,7 +79,7 @@ class Migrate: def main(): - Migrate().migrate() + MigrateModelYamlToDb().migrate() if __name__ == "__main__": diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 5d64a51c3a..d10d5a0a27 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -386,13 +386,10 @@ class Chdir(object): class SilenceWarnings(object): """Context manager to temporarily lower verbosity of diffusers & transformers warning messages.""" - def __init__(self): - """Set up context, save current transformers and diffusers verbosity settings.""" - self.transformers_verbosity = transformers_logging.get_verbosity() - self.diffusers_verbosity = diffusers_logging.get_verbosity() - def __enter__(self): """Set verbosity to error.""" + self.transformers_verbosity = transformers_logging.get_verbosity() + self.diffusers_verbosity = diffusers_logging.get_verbosity() transformers_logging.set_verbosity_error() diffusers_logging.set_verbosity_error() warnings.simplefilter("ignore") diff --git a/tests/backend/model_manager_2/test_model_storage_sql.py b/tests/app/services/model_records/test_model_records_sql.py similarity index 68% rename from tests/backend/model_manager_2/test_model_storage_sql.py rename to tests/app/services/model_records/test_model_records_sql.py index 308ec7a66f..5dd41485e1 100644 --- a/tests/backend/model_manager_2/test_model_storage_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -67,16 +67,10 @@ def test_add(store: ModelRecordServiceBase): def test_dup(store: ModelRecordServiceBase): config = example_config() store.add_model("key1", example_config()) - try: + with pytest.raises(DuplicateModelException): store.add_model("key1", config) - assert False, "Duplicate model key should have been caught" - except DuplicateModelException: - assert True - try: + with pytest.raises(DuplicateModelException): store.add_model("key2", config) - assert False, "Duplicate model path should have been caught" - except DuplicateModelException: - assert True def test_update(store: ModelRecordServiceBase): @@ -90,11 +84,12 @@ def test_update(store: ModelRecordServiceBase): new_config = store.get_model("key1") assert new_config.name == "new name" - try: + +def test_unknown_key(store: ModelRecordServiceBase): + config = example_config() + store.add_model("key1", config) + with pytest.raises(UnknownModelException): store.update_model("unknown_key", config) - assert False, "expected UnknownModelException" - except UnknownModelException: - assert True def test_delete(store: ModelRecordServiceBase): @@ -102,20 +97,8 @@ def test_delete(store: ModelRecordServiceBase): store.add_model("key1", config) config = store.get_model("key1") store.del_model("key1") - try: + with pytest.raises(UnknownModelException): config = store.get_model("key1") - assert False, "expected fetch of deleted model to raise exception" - except UnknownModelException: - assert True - - # a bug in sqlite3 in python 3.9 prevents DEL from returning number of - # deleted rows! - if sys.version_info.major == 3 and sys.version_info.minor > 9: - try: - store.del_model("unknown") - assert False, "expected delete of unknown model to raise exception" - except UnknownModelException: - assert True def test_exists(store: ModelRecordServiceBase): @@ -167,3 +150,62 @@ def test_filter(store: ModelRecordServiceBase): matches = store.all_models() assert len(matches) == 3 + + +def test_filter_2(store: ModelRecordServiceBase): + config1 = DiffusersConfig( + path="/tmp/config1", + name="config1", + base=BaseModelType("sd-1"), + type=ModelType("main"), + original_hash="CONFIG1HASH", + ) + config2 = DiffusersConfig( + path="/tmp/config2", + name="config2", + base=BaseModelType("sd-1"), + type=ModelType("main"), + original_hash="CONFIG2HASH", + ) + config3 = DiffusersConfig( + path="/tmp/config3", + name="dup_name1", + base=BaseModelType("sd-2"), + type=ModelType("main"), + original_hash="CONFIG3HASH", + ) + config4 = DiffusersConfig( + path="/tmp/config4", + name="dup_name1", + base=BaseModelType("sd-2"), + type=ModelType("main"), + original_hash="CONFIG3HASH", + ) + config5 = VaeDiffusersConfig( + path="/tmp/config5", + name="dup_name1", + base=BaseModelType("sd-1"), + type=ModelType("vae"), + original_hash="CONFIG3HASH", + ) + for c in config1, config2, config3, config4, config5: + store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) + + matches = store.search_by_name( + model_type=ModelType("main"), + model_name="dup_name1", + ) + assert len(matches) == 2 + + matches = store.search_by_name( + base_model=BaseModelType("sd-1"), + model_type=ModelType("main"), + ) + assert len(matches) == 2 + + matches = store.search_by_name( + base_model=BaseModelType("sd-1"), + model_type=ModelType("vae"), + model_name="dup_name1", + ) + assert len(matches) == 1