InvokeAI/tests/test_sqlite_migrator.py

290 lines
12 KiB
Python
Raw Normal View History

import sqlite3
import threading
from contextlib import closing
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
2023-12-11 01:43:40 +00:00
from pydantic import ValidationError
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
2023-12-11 05:41:47 +00:00
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
2023-12-11 01:43:40 +00:00
MigrateCallback,
Migration,
MigrationError,
2023-12-11 01:43:40 +00:00
MigrationSet,
MigrationVersionError,
2023-12-11 05:41:47 +00:00
)
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import (
SQLiteMigrator,
)
@pytest.fixture
2023-12-11 01:43:40 +00:00
def logger() -> Logger:
return Logger("test_sqlite_migrator")
@pytest.fixture
def lock() -> threading.RLock:
return threading.RLock()
@pytest.fixture
def migrator(logger: Logger, lock: threading.RLock) -> SQLiteMigrator:
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
2023-12-11 01:43:40 +00:00
return SQLiteMigrator(conn=conn, db_path=None, lock=lock, logger=logger)
@pytest.fixture
2023-12-11 01:43:40 +00:00
def migration_no_op() -> Migration:
return Migration(from_version=0, to_version=1, migrate=lambda cursor: None)
@pytest.fixture
def migration_create_test_table() -> Migration:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
return Migration(from_version=0, to_version=1, migrate=migrate)
@pytest.fixture
def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor) -> None:
raise Exception("Bad migration")
2023-12-11 01:43:40 +00:00
return Migration(from_version=0, to_version=1, migrate=failing_migration)
2023-12-11 01:43:40 +00:00
@pytest.fixture
def no_op_migrate_callback() -> MigrateCallback:
def no_op_migrate(cursor: sqlite3.Cursor) -> None:
pass
2023-12-11 01:43:40 +00:00
return no_op_migrate
2023-12-11 01:43:40 +00:00
@pytest.fixture
def failing_migrate_callback() -> MigrateCallback:
def failing_migrate(cursor: sqlite3.Cursor) -> None:
raise Exception("Bad migration")
2023-12-11 01:43:40 +00:00
return failing_migrate
2023-12-11 01:43:40 +00:00
def create_migrate(i: int) -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
2023-12-11 01:43:40 +00:00
return migrate
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
Migration(from_version=0, to_version=2, migrate=no_op_migrate_callback)
# not raising is sufficient
Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
2023-12-11 05:22:29 +00:00
def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None:
2023-12-11 01:43:40 +00:00
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
assert hash(migration) == hash((0, 1))
2023-12-11 05:22:29 +00:00
def test_migration_registers_pre_and_post_callbacks(no_op_migrate_callback: MigrateCallback) -> None:
2023-12-11 01:43:40 +00:00
def pre_callback(cursor: sqlite3.Cursor) -> None:
pass
2023-12-11 01:43:40 +00:00
def post_callback(cursor: sqlite3.Cursor) -> None:
pass
2023-12-11 01:43:40 +00:00
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
migration.register_pre_callback(pre_callback)
migration.register_post_callback(post_callback)
assert pre_callback in migration.pre_migrate
assert post_callback in migration.post_migrate
2023-12-11 05:22:29 +00:00
def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
2023-12-11 01:43:40 +00:00
migration = migration_no_op
migrator._migration_set.register(migration)
assert migration in migrator._migration_set._migrations
2023-12-11 05:22:29 +00:00
def test_migration_set_may_not_register_dupes(
migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback
) -> None:
migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
migrator._migration_set.register(migrate_0_to_1_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_0_to_1_b)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
migrator._migration_set.register(migrate_1_to_2_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_1_to_2_b)
2023-12-11 05:22:29 +00:00
def test_migration_set_gets_migration(migration_no_op: Migration) -> None:
2023-12-11 01:43:40 +00:00
migration_set = MigrationSet()
migration_set.register(migration_no_op)
assert migration_set.get(0) == migration_no_op
assert migration_set.get(1) is None
2023-12-11 05:22:29 +00:00
def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None:
2023-12-11 01:43:40 +00:00
migration_set = MigrationSet()
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 0 to 1
migration_set.validate_migration_chain()
2023-12-11 01:43:40 +00:00
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=2, to_version=3, migrate=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=4, to_version=5, migrate=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 3 to 4
2023-12-11 01:43:40 +00:00
migration_set.validate_migration_chain()
2023-12-11 05:22:29 +00:00
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None:
2023-12-11 01:43:40 +00:00
migration_set = MigrationSet()
assert migration_set.count == 0
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
assert migration_set.count == 1
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
assert migration_set.count == 2
2023-12-11 05:22:29 +00:00
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None:
2023-12-11 01:43:40 +00:00
migration_set = MigrationSet()
assert migration_set.latest_version == 0
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
assert migration_set.latest_version == 2
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
assert migration_set.latest_version == 2
2023-12-11 05:22:29 +00:00
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
2023-12-11 01:43:40 +00:00
migration = migration_no_op
migrator.register_migration(migration)
assert migration in migrator._migration_set._migrations
2023-12-11 05:22:29 +00:00
def test_migrator_creates_migrations_table(migrator: SQLiteMigrator) -> None:
2023-12-11 01:43:40 +00:00
migrator._create_migrations_table(migrator._cursor)
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
assert migrator._cursor.fetchone() is not None
2023-12-11 05:22:29 +00:00
def test_migrator_migration_sets_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
2023-12-11 01:43:40 +00:00
migrator._create_migrations_table(migrator._cursor)
migrator.register_migration(migration_no_op)
migrator.run_migrations()
migrator._cursor.execute("SELECT MAX(version) FROM migrations;")
assert migrator._cursor.fetchone()[0] == 1
2023-12-11 01:43:40 +00:00
2023-12-11 05:22:29 +00:00
def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
2023-12-11 01:43:40 +00:00
assert migrator._get_current_version(migrator._cursor) == 0
migrator._create_migrations_table(migrator._cursor)
assert migrator._get_current_version(migrator._cursor) == 0
migrator.register_migration(migration_no_op)
migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1
2023-12-11 05:22:29 +00:00
def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
2023-12-11 01:43:40 +00:00
migrator._create_migrations_table(migrator._cursor)
migrator._run_migration(migration_create_test_table, migrator._cursor)
assert migrator._get_current_version(migrator._cursor) == 1
migrator._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert migrator._cursor.fetchone() is not None
2023-12-11 05:22:29 +00:00
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
2023-12-11 01:43:40 +00:00
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
2023-12-11 01:43:40 +00:00
assert migrator._get_current_version(migrator._cursor) == 3
2023-12-11 05:22:29 +00:00
def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock) -> None:
2023-12-11 01:43:40 +00:00
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
# The Migrator closes the database when it finishes; we cannot use a context manager.
original_db_conn = sqlite3.connect(original_db_path)
migrator = SQLiteMigrator(conn=original_db_conn, db_path=original_db_path, lock=lock, logger=logger)
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
with closing(sqlite3.connect(original_db_path)) as original_db_conn:
2023-12-11 01:43:40 +00:00
original_db_cursor = original_db_conn.cursor()
assert SQLiteMigrator._get_current_version(original_db_cursor) == 3
2023-12-11 05:22:29 +00:00
def test_migrator_creates_temp_db() -> None:
2023-12-11 01:43:40 +00:00
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
with closing(sqlite3.connect(original_db_path)):
2023-12-11 01:43:40 +00:00
# create the db file so _create_temp_db has something to copy
pass
temp_db_path = SQLiteMigrator._create_temp_db(original_db_path)
assert temp_db_path.is_file()
assert temp_db_path == SQLiteMigrator._get_temp_db_path(original_db_path)
2023-12-11 05:22:29 +00:00
def test_migrator_finalizes() -> None:
2023-12-11 01:43:40 +00:00
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path)
backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path)
with closing(sqlite3.connect(original_db_path)) as original_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
2023-12-11 01:43:40 +00:00
original_db_cursor = original_db_conn.cursor()
original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);")
original_db_conn.commit()
temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);")
temp_db_conn.commit()
SQLiteMigrator._finalize_migration(
original_db_path=original_db_path,
temp_db_path=temp_db_path,
)
with closing(sqlite3.connect(backup_db_path)) as backup_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
2023-12-11 01:43:40 +00:00
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';")
assert backup_db_cursor.fetchone() is not None
temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='temp_db_test';")
assert temp_db_cursor.fetchone() is None
def test_migrator_makes_no_changes_on_failed_migration(
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
2023-12-11 05:22:29 +00:00
) -> None:
2023-12-11 01:43:40 +00:00
migrator.register_migration(migration_no_op)
migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1
migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback))
with pytest.raises(MigrationError, match="Bad migration"):
migrator.run_migrations()
assert migrator._get_current_version(migrator._cursor) == 1
2023-12-11 01:43:40 +00:00
2023-12-11 05:22:29 +00:00
def test_idempotent_migrations(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
2023-12-11 01:43:40 +00:00
migrator.register_migration(migration_create_test_table)
migrator.run_migrations()
# not throwing is sufficient
migrator.run_migrations()
2023-12-11 01:43:40 +00:00
assert migrator._get_current_version(migrator._cursor) == 1