feat(db): update sqlite migrator tests

This commit is contained in:
psychedelicious 2023-12-11 12:43:40 +11:00
parent 3227b30430
commit c823f5667b

View File

@ -1,33 +1,50 @@
import sqlite3 import sqlite3
import threading import threading
from copy import deepcopy
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Callable
import pytest import pytest
from pydantic import ValidationError
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
from invokeai.app.services.shared.sqlite.sqlite_migrator import ( from invokeai.app.services.shared.sqlite.sqlite_migrator import (
MigrateCallback,
Migration, Migration,
MigrationError, MigrationError,
MigrationSet,
MigrationVersionError, MigrationVersionError,
SQLiteMigrator, SQLiteMigrator,
) )
@pytest.fixture @pytest.fixture
def migrator() -> SQLiteMigrator: def logger() -> Logger:
conn = sqlite3.connect(sqlite_memory, check_same_thread=False) return Logger("test_sqlite_migrator")
return SQLiteMigrator(
conn=conn, db_path=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
)
@pytest.fixture @pytest.fixture
def good_migration() -> Migration: def lock() -> threading.RLock:
return Migration(db_version=1, app_version="1.0.0", migrate=lambda cursor: None) return threading.RLock()
@pytest.fixture
def migrator(logger: Logger, lock: threading.RLock) -> SQLiteMigrator:
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
return SQLiteMigrator(conn=conn, db_path=None, lock=lock, logger=logger)
@pytest.fixture
def migration_no_op() -> Migration:
return Migration(from_version=0, to_version=1, migrate=lambda cursor: None)
@pytest.fixture
def migration_create_test_table() -> Migration:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
return Migration(from_version=0, to_version=1, migrate=migrate)
@pytest.fixture @pytest.fixture
@ -35,156 +52,224 @@ def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor) -> None: def failing_migration(cursor: sqlite3.Cursor) -> None:
raise Exception("Bad migration") raise Exception("Bad migration")
return Migration(db_version=1, app_version="1.0.0", migrate=failing_migration) return Migration(from_version=0, to_version=1, migrate=failing_migration)
def test_register_migration(migrator: SQLiteMigrator, good_migration: Migration): @pytest.fixture
migration = good_migration def no_op_migrate_callback() -> MigrateCallback:
def no_op_migrate(cursor: sqlite3.Cursor) -> None:
pass
return no_op_migrate
@pytest.fixture
def failing_migrate_callback() -> MigrateCallback:
def failing_migrate(cursor: sqlite3.Cursor) -> None:
raise Exception("Bad migration")
return failing_migrate
def create_migrate(i: int) -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
return migrate
def test_migration_to_version_gt_from_version(no_op_migrate_callback: MigrateCallback):
with pytest.raises(ValidationError, match="greater_than_equal"):
Migration(from_version=1, to_version=0, migrate=no_op_migrate_callback)
def test_migration_hash(no_op_migrate_callback: MigrateCallback):
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
assert hash(migration) == hash((0, 1))
def test_migration_registers_pre_and_post_callbacks(no_op_migrate_callback: MigrateCallback):
def pre_callback(cursor: sqlite3.Cursor) -> None:
pass
def post_callback(cursor: sqlite3.Cursor) -> None:
pass
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
migration.register_pre_callback(pre_callback)
migration.register_post_callback(post_callback)
assert pre_callback in migration.pre_migrate
assert post_callback in migration.post_migrate
def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: Migration):
migration = migration_no_op
migrator._migration_set.register(migration)
assert migration in migrator._migration_set._migrations
def test_migration_set_may_not_register_dupes(migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback):
migrate_1_to_2 = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
migrate_0_to_2 = Migration(from_version=0, to_version=2, migrate=no_op_migrate_callback)
migrate_1_to_3 = Migration(from_version=1, to_version=3, migrate=no_op_migrate_callback)
migrator._migration_set.register(migrate_1_to_2)
with pytest.raises(MigrationVersionError, match=r"Migration to 2 already registered"):
migrator._migration_set.register(migrate_0_to_2)
with pytest.raises(MigrationVersionError, match=r"Migration from 1 already registered"):
migrator._migration_set.register(migrate_1_to_3)
def test_migration_set_gets_migration(migration_no_op: Migration):
migration_set = MigrationSet()
migration_set.register(migration_no_op)
assert migration_set.get(0) == migration_no_op
assert migration_set.get(1) is None
def test_migration_set_validates_migration_path(no_op_migrate_callback: MigrateCallback):
migration_set = MigrationSet()
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
migration_set.register(Migration(from_version=2, to_version=3, migrate=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=4, to_version=5, migrate=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
migration_set.validate_migration_chain()
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback):
migration_set = MigrationSet()
assert migration_set.count == 0
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
assert migration_set.count == 1
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
assert migration_set.count == 2
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback):
migration_set = MigrationSet()
assert migration_set.latest_version == 0
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
assert migration_set.latest_version == 2
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
assert migration_set.latest_version == 2
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration):
migration = migration_no_op
migrator.register_migration(migration) migrator.register_migration(migration)
assert migration in migrator._migrations assert migration in migrator._migration_set._migrations
def test_register_invalid_migration_version(migrator: SQLiteMigrator): def test_migrator_creates_migrations_table(migrator: SQLiteMigrator):
with pytest.raises(MigrationError, match="Invalid migration version"): migrator._create_migrations_table(migrator._cursor)
migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None)) migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
def test_create_version_table(migrator: SQLiteMigrator):
migrator._create_migrations_table()
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
assert migrator._cursor.fetchone() is not None assert migrator._cursor.fetchone() is not None
def test_get_current_version(migrator: SQLiteMigrator): def test_migrator_migration_sets_version(migrator: SQLiteMigrator, migration_no_op: Migration):
migrator._create_migrations_table() migrator._create_migrations_table(migrator._cursor)
migrator._conn.commit() migrator.register_migration(migration_no_op)
assert migrator._get_current_version() == 0 # initial version migrator.run_migrations()
migrator._cursor.execute("SELECT MAX(version) FROM migrations;")
def test_set_version(migrator: SQLiteMigrator):
migrator._create_migrations_table()
migrator._set_version(db_version=1, app_version="1.0.0")
migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
assert migrator._cursor.fetchone()[0] == 1 assert migrator._cursor.fetchone()[0] == 1
migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
assert migrator._cursor.fetchone()[0] == "1.0.0"
def test_run_migration(migrator: SQLiteMigrator): def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op: Migration):
migrator._create_migrations_table() assert migrator._get_current_version(migrator._cursor) == 0
migrator._create_migrations_table(migrator._cursor)
def migration_callback(cursor: sqlite3.Cursor) -> None: assert migrator._get_current_version(migrator._cursor) == 0
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") migrator.register_migration(migration_no_op)
migrator.run_migrations()
migration = Migration(db_version=1, app_version="1.0.0", migrate=migration_callback) assert migrator._get_current_version(migrator._cursor) == 1
migrator._run_migration(migration)
assert migrator._get_current_version() == 1
migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
assert migrator._cursor.fetchone()[0] == "1.0.0"
def test_run_migrations(migrator: SQLiteMigrator): def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration):
migrator._create_migrations_table() migrator._create_migrations_table(migrator._cursor)
migrator._run_migration(migration_create_test_table, migrator._cursor)
assert migrator._get_current_version(migrator._cursor) == 1
migrator._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert migrator._cursor.fetchone() is not None
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
return migrate def test_migrator_runs_all_migrations_in_memory(
migrator: SQLiteMigrator,
migrations = [Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in range(1, 4)] ):
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
for migration in migrations: for migration in migrations:
migrator.register_migration(migration) migrator.register_migration(migration)
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version() == 3 assert migrator._get_current_version(migrator._cursor) == 3
def test_backup_and_restore_db(): def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock):
# must do this with a file database - we don't backup/restore for memory
with TemporaryDirectory() as tempdir: with TemporaryDirectory() as tempdir:
# create test DB w/ some data original_db_path = Path(tempdir) / "invokeai.db"
database = Path(tempdir) / "test.db" # The Migrator closes the database when it finishes; we cannot use a context manager.
conn = sqlite3.connect(database, check_same_thread=False) original_db_conn = sqlite3.connect(original_db_path)
cursor = conn.cursor() migrator = SQLiteMigrator(conn=original_db_conn, db_path=original_db_path, lock=lock, logger=logger)
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
conn.commit() for migration in migrations:
migrator.register_migration(migration)
migrator = SQLiteMigrator(conn=conn, db_path=database, lock=threading.RLock(), logger=Logger("test"))
backup_path = migrator._backup_db(migrator._db_path)
# mangle the db
migrator._cursor.execute("DROP TABLE test;")
migrator._conn.commit()
migrator._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert migrator._cursor.fetchone() is None
# restore (closes the connection - must create a new one)
migrator._restore_db(backup_path)
restored_conn = sqlite3.connect(database)
restored_cursor = restored_conn.cursor()
restored_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert restored_cursor.fetchone() is not None
# must manually close else tempfile throws on cleanup on windows
restored_conn.close()
def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
with pytest.raises(MigrationError, match="Cannot back up memory database"):
migrator._backup_db(sqlite_memory)
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
migrator._create_migrations_table()
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
migrator._run_migration(failing_migration)
assert migrator._get_current_version() == 0
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_migrations_table()
migrator.register_migration(good_migration)
with pytest.raises(MigrationVersionError, match="already registered"):
migrator.register_migration(deepcopy(good_migration))
def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
migrator._create_migrations_table()
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
return migrate
migrations = [
Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in reversed(range(1, 4))
]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
assert migrator._get_current_version() == 3
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_migrations_table()
migrator.register_migration(good_migration)
migrator._set_version(db_version=2, app_version="2.0.0")
with pytest.raises(MigrationError, match="greater than the latest migration version"):
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version() == 2 with sqlite3.connect(original_db_path) as original_db_conn:
original_db_cursor = original_db_conn.cursor()
assert SQLiteMigrator._get_current_version(original_db_cursor) == 3
def test_idempotent_migrations(migrator: SQLiteMigrator): def test_migrator_creates_temp_db():
migrator._create_migrations_table() with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
with sqlite3.connect(original_db_path):
# create the db file so _create_temp_db has something to copy
pass
temp_db_path = SQLiteMigrator._create_temp_db(original_db_path)
assert temp_db_path.is_file()
assert temp_db_path == SQLiteMigrator._get_temp_db_path(original_db_path)
def create_test_table(cursor: sqlite3.Cursor) -> None:
# This SQL throws if run twice
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
migration = Migration(db_version=1, app_version="1.0.0", migrate=create_test_table) def test_migrator_finalizes():
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path)
backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path)
with sqlite3.connect(original_db_path) as original_db_conn, sqlite3.connect(temp_db_path) as temp_db_conn:
original_db_cursor = original_db_conn.cursor()
original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);")
original_db_conn.commit()
temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);")
temp_db_conn.commit()
SQLiteMigrator._finalize_migration(
original_db_path=original_db_path,
temp_db_path=temp_db_path,
)
with sqlite3.connect(backup_db_path) as backup_db_conn, sqlite3.connect(temp_db_path) as temp_db_conn:
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';")
assert backup_db_cursor.fetchone() is not None
temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='temp_db_test';")
assert temp_db_cursor.fetchone() is None
migrator.register_migration(migration)
def test_migrator_makes_no_changes_on_failed_migration(
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
):
migrator.register_migration(migration_no_op)
migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1
migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback))
with pytest.raises(MigrationError, match="Bad migration"):
migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1
def test_idempotent_migrations(migrator: SQLiteMigrator, migration_create_test_table: Migration):
migrator.register_migration(migration_create_test_table)
migrator.run_migrations() migrator.run_migrations()
# not throwing is sufficient # not throwing is sufficient
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1