feat(db): back up database before running migrations

Just in case.
This commit is contained in:
psychedelicious 2024-04-02 08:53:28 +11:00
parent 59b4a23479
commit 4049217728
2 changed files with 41 additions and 0 deletions

View File

@ -1,4 +1,6 @@
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Optional
@ -32,6 +34,7 @@ class SqliteMigrator:
self._db = db
self._logger = db.logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
@ -55,6 +58,18 @@ class SqliteMigrator:
return False
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)

View File

@ -250,6 +250,32 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
db.conn.close()
def test_migrator_backs_up_db(logger: Logger) -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
# Write some data to the db to test for successful backup
temp_cursor = db.conn.cursor()
temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.conn.commit()
# Set up the migrator
migrator = SqliteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
# Must manually close else we get an error on Windows
db.conn.close()
assert original_db_path.exists()
# We should have a backup file when we migrated a file db
assert migrator._backup_path
# Check that the test table exists as a proxy for successful backup
with closing(sqlite3.connect(migrator._backup_path)) as backup_db_conn:
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert backup_db_cursor.fetchone() is not None
def test_migrator_makes_no_changes_on_failed_migration(
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None: