diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 151c8edf7f..fe42872bcd 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -81,13 +81,7 @@ class ApiDependencies: migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files)) migrator.register_migration(migration_1) migrator.register_migration(migration_2) - did_migrate = migrator.run_migrations() - - # We need to reinitialize the database if we migrated, but only if we are using a file database. - # This closes the SqliteDatabase's connection and re-runs its `__init__` logic. - # If we do this with a memory database, we wipe the db. - if not db.db_path and did_migrate: - db.reinitialize() + migrator.run_migrations() configuration = config logger = logger diff --git a/invokeai/app/services/shared/sqlite/sqlite_database.py b/invokeai/app/services/shared/sqlite/sqlite_database.py index be8b6284bf..e860160044 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_database.py +++ b/invokeai/app/services/shared/sqlite/sqlite_database.py @@ -10,25 +10,22 @@ class SqliteDatabase: """ Manages a connection to an SQLite database. + :param db_path: Path to the database file. If None, an in-memory database is used. + :param logger: Logger to use for logging. + :param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback. + This is a light wrapper around the `sqlite3` module, providing a few conveniences: - The database file is written to disk if it does not exist. - Foreign key constraints are enabled by default. - The connection is configured to use the `sqlite3.Row` row factory. - - A `conn` attribute is provided to access the connection. - - A `lock` attribute is provided to lock the database connection. - - A `clean` method to run the VACUUM command and report on the freed space. - - A `reinitialize` method to close the connection and re-run the init. - - A `close` method to close the connection. - :param db_path: Path to the database file. If None, an in-memory database is used. - :param logger: Logger to use for logging. - :param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback. + In addition to the constructor args, the instance provides the following attributes and methods: + - `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory. + - `lock`: A shared re-entrant lock, used to approximate thread safety. + - `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space. """ def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None: - self.initialize(db_path=db_path, logger=logger, verbose=verbose) - - def initialize(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None: """Initializes the database. This is used internally by the class constructor.""" self.logger = logger self.db_path = db_path @@ -49,21 +46,6 @@ class SqliteDatabase: self.conn.execute("PRAGMA foreign_keys = ON;") - def reinitialize(self) -> None: - """ - Re-initializes the database by closing the connection and re-running the init. - Warning: This will wipe the database if it is an in-memory database. - """ - self.close() - self.initialize(db_path=self.db_path, logger=self.logger, verbose=self.verbose) - - def close(self) -> None: - """ - Closes the connection to the database. - Warning: This will wipe the database if it is an in-memory database. - """ - self.conn.close() - def clean(self) -> None: """ Cleans the database by running the VACUUM command, reporting on the freed space. diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py index 33035b58c2..b6bc8eed58 100644 --- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py +++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py @@ -1,6 +1,4 @@ -import shutil import sqlite3 -from datetime import datetime from pathlib import Path from typing import Optional @@ -12,19 +10,11 @@ class SQLiteMigrator: """ Manages migrations for a SQLite database. - :param db_path: The path to the database to migrate, or None if using an in-memory database. - :param conn: The connection to the database. - :param lock: A lock to use when running migrations. - :param logger: A logger to use for logging. - :param log_sql: Whether to log SQL statements. Only used when the log level is set to debug. + :param db: The instanceof :class:`SqliteDatabase` to migrate. Migrations should be registered with :meth:`register_migration`. - During migration, a copy of the current database is made and the migrations are run on the copy. If the migration - is successful, the original database is backed up and the migrated database is moved to the original database's - path. If the migration fails, the original database is left untouched and the migrated database is deleted. - - If the database is in-memory, no backup is made; the migration is run in-place. + Each migration is run in a transaction. If a migration fails, the transaction is rolled back. """ backup_path: Optional[Path] = None @@ -34,11 +24,6 @@ class SQLiteMigrator: self._logger = db.logger self._migration_set = MigrationSet() - # The presence of an temp database file indicates a catastrophic failure of a previous migration. - if self._db.db_path and self._get_temp_db_path(self._db.db_path).is_file(): - self._logger.warning("Previous migration failed! Trying again...") - self._get_temp_db_path(self._db.db_path).unlink() - def register_migration(self, migration: Migration) -> None: """Registers a migration.""" self._migration_set.register(migration) @@ -61,44 +46,24 @@ class SQLiteMigrator: return False self._logger.info("Database update needed") - - if self._db.db_path: - # We are using a file database. Create a copy of the database to run the migrations on. - temp_db_path = self._create_temp_db(self._db.db_path) - self._logger.info(f"Copied database to {temp_db_path} for migration") - temp_db = SqliteDatabase(db_path=temp_db_path, logger=self._logger, verbose=self._db.verbose) - temp_db_cursor = temp_db.conn.cursor() - self._run_migrations(temp_db_cursor) - # Close the connections, copy the original database as a backup, and move the temp database to the - # original database's path. - temp_db.close() - self._db.close() - backup_db_path = self._finalize_migration( - temp_db_path=temp_db_path, - original_db_path=self._db.db_path, - ) - self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}") - else: - # We are using a memory database. No special backup or special handling needed. - self._run_migrations(cursor) - + next_migration = self._migration_set.get(from_version=self._get_current_version(cursor)) + while next_migration is not None: + self._run_migration(next_migration) + next_migration = self._migration_set.get(self._get_current_version(cursor)) self._logger.info("Database updated successfully") return True - def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None: - """Runs all migrations in a loop.""" - next_migration = self._migration_set.get(from_version=self._get_current_version(temp_db_cursor)) - while next_migration is not None: - self._run_migration(next_migration, temp_db_cursor) - next_migration = self._migration_set.get(self._get_current_version(temp_db_cursor)) - - def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None: + def _run_migration(self, migration: Migration) -> None: """Runs a single migration.""" - with self._db.lock: - try: - if self._get_current_version(temp_db_cursor) != migration.from_version: + # Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an + # exception is raised. We want to commit the transaction if the migration is successful, or roll it back if + # there is an error. + try: + with self._db.lock, self._db.conn as conn: + cursor = conn.cursor() + if self._get_current_version(cursor) != migration.from_version: raise MigrationError( - f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}" + f"Database is at version {self._get_current_version(cursor)}, expected {migration.from_version}" ) self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}") @@ -106,28 +71,26 @@ class SQLiteMigrator: if migration.pre_migrate: self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks") for callback in migration.pre_migrate: - callback(temp_db_cursor) + callback(cursor) # Run the actual migration - migration.migrate(temp_db_cursor) - temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,)) + migration.migrate(cursor) + cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,)) # Run post-migration callbacks if migration.post_migrate: self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks") for callback in migration.post_migrate: - callback(temp_db_cursor) - - # Migration callbacks only get a cursor. Commit this migration. - temp_db_cursor.connection.commit() + callback(cursor) self._logger.debug( f"Successfully migrated database from {migration.from_version} to {migration.to_version}" ) - except Exception as e: - msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}" - temp_db_cursor.connection.rollback() - self._logger.error(msg) - raise MigrationError(msg) from e + # We want to catch *any* error, mirroring the behaviour of the sqlite3 module. + except Exception as e: + # The connection context manager has already rolled back the migration, so we don't need to do anything. + msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}" + self._logger.error(msg) + raise MigrationError(msg) from e def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None: """Creates the migrations table for the database, if one does not already exist.""" @@ -166,33 +129,3 @@ class SQLiteMigrator: if "no such table" in str(e): return 0 raise - - @classmethod - def _create_temp_db(cls, original_db_path: Path) -> Path: - """Copies the current database to a new file for migration.""" - temp_db_path = cls._get_temp_db_path(original_db_path) - shutil.copy2(original_db_path, temp_db_path) - return temp_db_path - - @classmethod - def _finalize_migration( - cls, - temp_db_path: Path, - original_db_path: Path, - ) -> Path: - """Renames the original database as a backup and renames the migrated database to the original name.""" - backup_db_path = cls._get_backup_db_path(original_db_path) - original_db_path.rename(backup_db_path) - temp_db_path.rename(original_db_path) - return backup_db_path - - @classmethod - def _get_temp_db_path(cls, original_db_path: Path) -> Path: - """Gets the path to the temp database.""" - return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp") - - @classmethod - def _get_backup_db_path(cls, original_db_path: Path) -> Path: - """Gets the path to the final backup database.""" - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - return original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db" diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 58515cc273..2b245cce6d 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -48,8 +48,6 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: migrator.register_migration(migration_1) migrator.register_migration(migration_2) migrator.run_migrations() - # this test uses a file database, so we need to reinitialize it after migrations - db.reinitialize() store: ModelRecordServiceBase = ModelRecordServiceSQL(db) return store diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 235c8f3cff..e3589d6ec0 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -39,8 +39,6 @@ def store(datadir: Any) -> ModelRecordServiceBase: migrator.register_migration(migration_1) migrator.register_migration(migration_2) migrator.run_migrations() - # this test uses a file database, so we need to reinitialize it after migrations - db.reinitialize() return ModelRecordServiceSQL(db) diff --git a/tests/test_sqlite_migrator.py b/tests/test_sqlite_migrator.py index 109cc88472..3759859b4b 100644 --- a/tests/test_sqlite_migrator.py +++ b/tests/test_sqlite_migrator.py @@ -198,7 +198,7 @@ def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None: cursor = migrator._db.conn.cursor() migrator._create_migrations_table(cursor) - migrator._run_migration(migration_create_test_table, cursor) + migrator._run_migration(migration_create_test_table) assert migrator._get_current_version(cursor) == 1 cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';") assert cursor.fetchone() is not None @@ -228,46 +228,6 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None: assert SQLiteMigrator._get_current_version(original_db_cursor) == 3 -def test_migrator_creates_temp_db() -> None: - with TemporaryDirectory() as tempdir: - original_db_path = Path(tempdir) / "invokeai.db" - with closing(sqlite3.connect(original_db_path)): - # create the db file so _create_temp_db has something to copy - pass - temp_db_path = SQLiteMigrator._create_temp_db(original_db_path) - assert temp_db_path.is_file() - assert temp_db_path == SQLiteMigrator._get_temp_db_path(original_db_path) - - -def test_migrator_finalizes() -> None: - with TemporaryDirectory() as tempdir: - original_db_path = Path(tempdir) / "invokeai.db" - temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path) - backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path) - with closing(sqlite3.connect(original_db_path)) as original_db_conn, closing( - sqlite3.connect(temp_db_path) - ) as temp_db_conn: - original_db_cursor = original_db_conn.cursor() - original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);") - original_db_conn.commit() - temp_db_cursor = temp_db_conn.cursor() - temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);") - temp_db_conn.commit() - SQLiteMigrator._finalize_migration( - original_db_path=original_db_path, - temp_db_path=temp_db_path, - ) - with closing(sqlite3.connect(backup_db_path)) as backup_db_conn, closing( - sqlite3.connect(temp_db_path) - ) as temp_db_conn: - backup_db_cursor = backup_db_conn.cursor() - backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';") - assert backup_db_cursor.fetchone() is not None - temp_db_cursor = temp_db_conn.cursor() - temp_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='temp_db_test';") - assert temp_db_cursor.fetchone() is None - - def test_migrator_makes_no_changes_on_failed_migration( migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback ) -> None: