From 3414437eea0f29848ec521cbe83f80e922471052 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 12 Dec 2023 10:46:08 +1100 Subject: [PATCH] feat(db): instantiate SqliteMigrator with a SqliteDatabase Simplifies a couple things: - Init is more straightforward - It's clear in the migrator that the connection we are working with is related to the SqliteDatabase --- invokeai/app/api/dependencies.py | 8 +-- .../sqlite_migrator/sqlite_migrator_impl.py | 54 ++++++--------- tests/aa_nodes/test_graph_execution_state.py | 8 +-- tests/aa_nodes/test_invoker.py | 8 +-- .../model_install/test_model_install.py | 8 +-- .../model_records/test_model_records_sql.py | 8 +-- tests/test_sqlite_migrator.py | 65 ++++++++++--------- 7 files changed, 58 insertions(+), 101 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 32b3e40715..151c8edf7f 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -77,13 +77,7 @@ class ApiDependencies: db_path = None if config.use_memory_db else config.db_path db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) - migrator = SQLiteMigrator( - db_path=db.db_path, - conn=db.conn, - lock=db.lock, - logger=logger, - log_sql=config.log_sql, - ) + migrator = SQLiteMigrator(db=db) migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files)) migrator.register_migration(migration_1) migrator.register_migration(migration_2) diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py index 129d68451c..33035b58c2 100644 --- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py +++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py @@ -1,11 +1,10 @@ import shutil import sqlite3 -import threading from datetime import datetime -from logging import Logger from pathlib import Path from typing import Optional +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet @@ -30,26 +29,15 @@ class SQLiteMigrator: backup_path: Optional[Path] = None - def __init__( - self, - db_path: Path | None, - conn: sqlite3.Connection, - lock: threading.RLock, - logger: Logger, - log_sql: bool = False, - ) -> None: - self._lock = lock - self._db_path = db_path - self._logger = logger - self._conn = conn - self._log_sql = log_sql - self._cursor = self._conn.cursor() + def __init__(self, db: SqliteDatabase) -> None: + self._db = db + self._logger = db.logger self._migration_set = MigrationSet() # The presence of an temp database file indicates a catastrophic failure of a previous migration. - if self._db_path and self._get_temp_db_path(self._db_path).is_file(): + if self._db.db_path and self._get_temp_db_path(self._db.db_path).is_file(): self._logger.warning("Previous migration failed! Trying again...") - self._get_temp_db_path(self._db_path).unlink() + self._get_temp_db_path(self._db.db_path).unlink() def register_migration(self, migration: Migration) -> None: """Registers a migration.""" @@ -58,43 +46,41 @@ class SQLiteMigrator: def run_migrations(self) -> bool: """Migrates the database to the latest version.""" - with self._lock: + with self._db.lock: # This throws if there is a problem. self._migration_set.validate_migration_chain() - self._create_migrations_table(cursor=self._cursor) + cursor = self._db.conn.cursor() + self._create_migrations_table(cursor=cursor) if self._migration_set.count == 0: self._logger.debug("No migrations registered") return False - if self._get_current_version(self._cursor) == self._migration_set.latest_version: + if self._get_current_version(cursor=cursor) == self._migration_set.latest_version: self._logger.debug("Database is up to date, no migrations to run") return False self._logger.info("Database update needed") - if self._db_path: + if self._db.db_path: # We are using a file database. Create a copy of the database to run the migrations on. - temp_db_path = self._create_temp_db(self._db_path) + temp_db_path = self._create_temp_db(self._db.db_path) self._logger.info(f"Copied database to {temp_db_path} for migration") - temp_db_conn = sqlite3.connect(temp_db_path) - # We have to re-set this because we just created a new connection. - if self._log_sql: - temp_db_conn.set_trace_callback(self._logger.debug) - temp_db_cursor = temp_db_conn.cursor() + temp_db = SqliteDatabase(db_path=temp_db_path, logger=self._logger, verbose=self._db.verbose) + temp_db_cursor = temp_db.conn.cursor() self._run_migrations(temp_db_cursor) # Close the connections, copy the original database as a backup, and move the temp database to the # original database's path. - temp_db_conn.close() - self._conn.close() + temp_db.close() + self._db.close() backup_db_path = self._finalize_migration( temp_db_path=temp_db_path, - original_db_path=self._db_path, + original_db_path=self._db.db_path, ) self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}") else: # We are using a memory database. No special backup or special handling needed. - self._run_migrations(self._cursor) + self._run_migrations(cursor) self._logger.info("Database updated successfully") return True @@ -108,7 +94,7 @@ class SQLiteMigrator: def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None: """Runs a single migration.""" - with self._lock: + with self._db.lock: try: if self._get_current_version(temp_db_cursor) != migration.from_version: raise MigrationError( @@ -145,7 +131,7 @@ class SQLiteMigrator: def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None: """Creates the migrations table for the database, if one does not already exist.""" - with self._lock: + with self._db.lock: try: cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';") if cursor.fetchone() is not None: diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index b2596ea181..609b0c3736 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -55,13 +55,7 @@ def mock_services() -> InvocationServices: logger = InvokeAILogger.get_logger() db_path = None if configuration.use_memory_db else configuration.db_path db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql) - migrator = SQLiteMigrator( - db_path=db.db_path, - conn=db.conn, - lock=db.lock, - logger=logger, - log_sql=configuration.log_sql, - ) + migrator = SQLiteMigrator(db=db) migrator.register_migration(migration_1) migrator.register_migration(migration_2) migrator.run_migrations() diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index e509703f78..866287c461 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -59,13 +59,7 @@ def mock_services() -> InvocationServices: logger = InvokeAILogger.get_logger() db_path = None if configuration.use_memory_db else configuration.db_path db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql) - migrator = SQLiteMigrator( - db_path=db.db_path, - conn=db.conn, - lock=db.lock, - logger=logger, - log_sql=configuration.log_sql, - ) + migrator = SQLiteMigrator(db=db) migrator.register_migration(migration_1) migrator.register_migration(migration_2) migrator.run_migrations() diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 9cb1f67f0a..58515cc273 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -44,13 +44,7 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: logger = InvokeAILogger.get_logger(config=app_config) db_path = None if app_config.use_memory_db else app_config.db_path db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql) - migrator = SQLiteMigrator( - db_path=db.db_path, - conn=db.conn, - lock=db.lock, - logger=logger, - log_sql=app_config.log_sql, - ) + migrator = SQLiteMigrator(db=db) migrator.register_migration(migration_1) migrator.register_migration(migration_2) migrator.run_migrations() diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 1be0498836..235c8f3cff 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -35,13 +35,7 @@ def store(datadir: Any) -> ModelRecordServiceBase: logger = InvokeAILogger.get_logger(config=config) db_path = None if config.use_memory_db else config.db_path db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) - migrator = SQLiteMigrator( - db_path=db.db_path, - conn=db.conn, - lock=db.lock, - logger=logger, - log_sql=config.log_sql, - ) + migrator = SQLiteMigrator(db=db) migrator.register_migration(migration_1) migrator.register_migration(migration_2) migrator.run_migrations() diff --git a/tests/test_sqlite_migrator.py b/tests/test_sqlite_migrator.py index 630fb5dd3b..109cc88472 100644 --- a/tests/test_sqlite_migrator.py +++ b/tests/test_sqlite_migrator.py @@ -1,5 +1,4 @@ import sqlite3 -import threading from contextlib import closing from logging import Logger from pathlib import Path @@ -8,7 +7,7 @@ from tempfile import TemporaryDirectory import pytest from pydantic import ValidationError -from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import ( MigrateCallback, Migration, @@ -27,14 +26,9 @@ def logger() -> Logger: @pytest.fixture -def lock() -> threading.RLock: - return threading.RLock() - - -@pytest.fixture -def migrator(logger: Logger, lock: threading.RLock) -> SQLiteMigrator: - conn = sqlite3.connect(sqlite_memory, check_same_thread=False) - return SQLiteMigrator(conn=conn, db_path=None, lock=lock, logger=logger) +def migrator(logger: Logger) -> SQLiteMigrator: + db = SqliteDatabase(db_path=None, logger=logger, verbose=False) + return SQLiteMigrator(db=db) @pytest.fixture @@ -176,50 +170,55 @@ def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: def test_migrator_creates_migrations_table(migrator: SQLiteMigrator) -> None: - migrator._create_migrations_table(migrator._cursor) - migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';") - assert migrator._cursor.fetchone() is not None + cursor = migrator._db.conn.cursor() + migrator._create_migrations_table(cursor) + cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';") + assert cursor.fetchone() is not None def test_migrator_migration_sets_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None: - migrator._create_migrations_table(migrator._cursor) + cursor = migrator._db.conn.cursor() + migrator._create_migrations_table(cursor) migrator.register_migration(migration_no_op) migrator.run_migrations() - migrator._cursor.execute("SELECT MAX(version) FROM migrations;") - assert migrator._cursor.fetchone()[0] == 1 + cursor.execute("SELECT MAX(version) FROM migrations;") + assert cursor.fetchone()[0] == 1 def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None: - assert migrator._get_current_version(migrator._cursor) == 0 - migrator._create_migrations_table(migrator._cursor) - assert migrator._get_current_version(migrator._cursor) == 0 + cursor = migrator._db.conn.cursor() + assert migrator._get_current_version(cursor) == 0 + migrator._create_migrations_table(cursor) + assert migrator._get_current_version(cursor) == 0 migrator.register_migration(migration_no_op) migrator.run_migrations() - assert migrator._get_current_version(migrator._cursor) == 1 + assert migrator._get_current_version(cursor) == 1 def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None: - migrator._create_migrations_table(migrator._cursor) - migrator._run_migration(migration_create_test_table, migrator._cursor) - assert migrator._get_current_version(migrator._cursor) == 1 - migrator._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';") - assert migrator._cursor.fetchone() is not None + cursor = migrator._db.conn.cursor() + migrator._create_migrations_table(cursor) + migrator._run_migration(migration_create_test_table, cursor) + assert migrator._get_current_version(cursor) == 1 + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';") + assert cursor.fetchone() is not None def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None: + cursor = migrator._db.conn.cursor() migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)] for migration in migrations: migrator.register_migration(migration) migrator.run_migrations() - assert migrator._get_current_version(migrator._cursor) == 3 + assert migrator._get_current_version(cursor) == 3 -def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock) -> None: +def test_migrator_runs_all_migrations_file(logger: Logger) -> None: with TemporaryDirectory() as tempdir: original_db_path = Path(tempdir) / "invokeai.db" # The Migrator closes the database when it finishes; we cannot use a context manager. - original_db_conn = sqlite3.connect(original_db_path) - migrator = SQLiteMigrator(conn=original_db_conn, db_path=original_db_path, lock=lock, logger=logger) + db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False) + migrator = SQLiteMigrator(db=db) migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)] for migration in migrations: migrator.register_migration(migration) @@ -272,18 +271,20 @@ def test_migrator_finalizes() -> None: def test_migrator_makes_no_changes_on_failed_migration( migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback ) -> None: + cursor = migrator._db.conn.cursor() migrator.register_migration(migration_no_op) migrator.run_migrations() - assert migrator._get_current_version(migrator._cursor) == 1 + assert migrator._get_current_version(cursor) == 1 migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback)) with pytest.raises(MigrationError, match="Bad migration"): migrator.run_migrations() - assert migrator._get_current_version(migrator._cursor) == 1 + assert migrator._get_current_version(cursor) == 1 def test_idempotent_migrations(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None: + cursor = migrator._db.conn.cursor() migrator.register_migration(migration_create_test_table) migrator.run_migrations() # not throwing is sufficient migrator.run_migrations() - assert migrator._get_current_version(migrator._cursor) == 1 + assert migrator._get_current_version(cursor) == 1