diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py index eef9c07ed4..5d78d55818 100644 --- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py +++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py @@ -1,4 +1,6 @@ import sqlite3 +from contextlib import closing +from datetime import datetime from pathlib import Path from typing import Optional @@ -32,6 +34,7 @@ class SqliteMigrator: self._db = db self._logger = db.logger self._migration_set = MigrationSet() + self._backup_path: Optional[Path] = None def register_migration(self, migration: Migration) -> None: """Registers a migration.""" @@ -55,6 +58,18 @@ class SqliteMigrator: return False self._logger.info("Database update needed") + + # Make a backup of the db if it needs to be updated and is a file db + if self._db.db_path is not None: + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db" + self._logger.info(f"Backing up database to {str(self._backup_path)}") + # Use SQLite to do the backup + with closing(sqlite3.connect(self._backup_path)) as backup_conn: + self._db.conn.backup(backup_conn) + else: + self._logger.info("Using in-memory database, no backup needed") + next_migration = self._migration_set.get(from_version=self._get_current_version(cursor)) while next_migration is not None: self._run_migration(next_migration) diff --git a/tests/test_sqlite_migrator.py b/tests/test_sqlite_migrator.py index 816b8b6a10..7f72d0bd13 100644 --- a/tests/test_sqlite_migrator.py +++ b/tests/test_sqlite_migrator.py @@ -250,6 +250,32 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None: db.conn.close() +def test_migrator_backs_up_db(logger: Logger) -> None: + with TemporaryDirectory() as tempdir: + original_db_path = Path(tempdir) / "invokeai.db" + db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False) + # Write some data to the db to test for successful backup + temp_cursor = db.conn.cursor() + temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") + db.conn.commit() + # Set up the migrator + migrator = SqliteMigrator(db=db) + migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)] + for migration in migrations: + migrator.register_migration(migration) + migrator.run_migrations() + # Must manually close else we get an error on Windows + db.conn.close() + assert original_db_path.exists() + # We should have a backup file when we migrated a file db + assert migrator._backup_path + # Check that the test table exists as a proxy for successful backup + with closing(sqlite3.connect(migrator._backup_path)) as backup_db_conn: + backup_db_cursor = backup_db_conn.cursor() + backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';") + assert backup_db_cursor.fetchone() is not None + + def test_migrator_makes_no_changes_on_failed_migration( migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback ) -> None: