feat(db): invert backup/restore logic

Do the migration on a temp copy of the db, then back up the original and move the temp into its file.
This commit is contained in:
psychedelicious
2023-12-11 00:47:53 +11:00
parent abeb1bd3b3
commit e461f9925e
6 changed files with 167 additions and 126 deletions

View File

@ -21,7 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import (
def migrator() -> SQLiteMigrator:
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
return SQLiteMigrator(
conn=conn, database=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
conn=conn, db_path=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
)
@ -50,19 +50,19 @@ def test_register_invalid_migration_version(migrator: SQLiteMigrator):
def test_create_version_table(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
assert migrator._cursor.fetchone() is not None
def test_get_current_version(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
migrator._conn.commit()
assert migrator._get_current_version() == 0 # initial version
def test_set_version(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
migrator._set_version(db_version=1, app_version="1.0.0")
migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
assert migrator._cursor.fetchone()[0] == 1
@ -71,7 +71,7 @@ def test_set_version(migrator: SQLiteMigrator):
def test_run_migration(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
def migration_callback(cursor: sqlite3.Cursor) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
@ -84,7 +84,7 @@ def test_run_migration(migrator: SQLiteMigrator):
def test_run_migrations(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None:
@ -109,8 +109,8 @@ def test_backup_and_restore_db():
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
conn.commit()
migrator = SQLiteMigrator(conn=conn, database=database, lock=threading.RLock(), logger=Logger("test"))
backup_path = migrator._backup_db(migrator._database)
migrator = SQLiteMigrator(conn=conn, db_path=database, lock=threading.RLock(), logger=Logger("test"))
backup_path = migrator._backup_db(migrator._db_path)
# mangle the db
migrator._cursor.execute("DROP TABLE test;")
@ -135,21 +135,21 @@ def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
migrator._create_version_table()
migrator._create_migrations_table()
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
migrator._run_migration(failing_migration)
assert migrator._get_current_version() == 0
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_version_table()
migrator._create_migrations_table()
migrator.register_migration(good_migration)
with pytest.raises(MigrationVersionError, match="already registered"):
migrator.register_migration(deepcopy(good_migration))
def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None:
@ -167,7 +167,7 @@ def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_version_table()
migrator._create_migrations_table()
migrator.register_migration(good_migration)
migrator._set_version(db_version=2, app_version="2.0.0")
with pytest.raises(MigrationError, match="greater than the latest migration version"):
@ -176,7 +176,7 @@ def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration:
def test_idempotent_migrations(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
def create_test_table(cursor: sqlite3.Cursor) -> None:
# This SQL throws if run twice