mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): update sqlite migrator tests
This commit is contained in:
parent
3227b30430
commit
c823f5667b
@ -1,33 +1,50 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from copy import deepcopy
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import (
|
||||
MigrateCallback,
|
||||
Migration,
|
||||
MigrationError,
|
||||
MigrationSet,
|
||||
MigrationVersionError,
|
||||
SQLiteMigrator,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migrator() -> SQLiteMigrator:
|
||||
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
return SQLiteMigrator(
|
||||
conn=conn, db_path=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
|
||||
)
|
||||
def logger() -> Logger:
|
||||
return Logger("test_sqlite_migrator")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def good_migration() -> Migration:
|
||||
return Migration(db_version=1, app_version="1.0.0", migrate=lambda cursor: None)
|
||||
def lock() -> threading.RLock:
|
||||
return threading.RLock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migrator(logger: Logger, lock: threading.RLock) -> SQLiteMigrator:
|
||||
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
return SQLiteMigrator(conn=conn, db_path=None, lock=lock, logger=logger)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_no_op() -> Migration:
|
||||
return Migration(from_version=0, to_version=1, migrate=lambda cursor: None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_create_test_table() -> Migration:
|
||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
|
||||
return Migration(from_version=0, to_version=1, migrate=migrate)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -35,156 +52,224 @@ def failing_migration() -> Migration:
|
||||
def failing_migration(cursor: sqlite3.Cursor) -> None:
|
||||
raise Exception("Bad migration")
|
||||
|
||||
return Migration(db_version=1, app_version="1.0.0", migrate=failing_migration)
|
||||
return Migration(from_version=0, to_version=1, migrate=failing_migration)
|
||||
|
||||
|
||||
def test_register_migration(migrator: SQLiteMigrator, good_migration: Migration):
|
||||
migration = good_migration
|
||||
@pytest.fixture
|
||||
def no_op_migrate_callback() -> MigrateCallback:
|
||||
def no_op_migrate(cursor: sqlite3.Cursor) -> None:
|
||||
pass
|
||||
|
||||
return no_op_migrate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def failing_migrate_callback() -> MigrateCallback:
|
||||
def failing_migrate(cursor: sqlite3.Cursor) -> None:
|
||||
raise Exception("Bad migration")
|
||||
|
||||
return failing_migrate
|
||||
|
||||
|
||||
def create_migrate(i: int) -> MigrateCallback:
|
||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
|
||||
|
||||
return migrate
|
||||
|
||||
|
||||
def test_migration_to_version_gt_from_version(no_op_migrate_callback: MigrateCallback):
|
||||
with pytest.raises(ValidationError, match="greater_than_equal"):
|
||||
Migration(from_version=1, to_version=0, migrate=no_op_migrate_callback)
|
||||
|
||||
|
||||
def test_migration_hash(no_op_migrate_callback: MigrateCallback):
|
||||
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
|
||||
assert hash(migration) == hash((0, 1))
|
||||
|
||||
|
||||
def test_migration_registers_pre_and_post_callbacks(no_op_migrate_callback: MigrateCallback):
|
||||
def pre_callback(cursor: sqlite3.Cursor) -> None:
|
||||
pass
|
||||
|
||||
def post_callback(cursor: sqlite3.Cursor) -> None:
|
||||
pass
|
||||
|
||||
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
|
||||
migration.register_pre_callback(pre_callback)
|
||||
migration.register_post_callback(post_callback)
|
||||
assert pre_callback in migration.pre_migrate
|
||||
assert post_callback in migration.post_migrate
|
||||
|
||||
|
||||
def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: Migration):
|
||||
migration = migration_no_op
|
||||
migrator._migration_set.register(migration)
|
||||
assert migration in migrator._migration_set._migrations
|
||||
|
||||
|
||||
def test_migration_set_may_not_register_dupes(migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback):
|
||||
migrate_1_to_2 = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
|
||||
migrate_0_to_2 = Migration(from_version=0, to_version=2, migrate=no_op_migrate_callback)
|
||||
migrate_1_to_3 = Migration(from_version=1, to_version=3, migrate=no_op_migrate_callback)
|
||||
migrator._migration_set.register(migrate_1_to_2)
|
||||
with pytest.raises(MigrationVersionError, match=r"Migration to 2 already registered"):
|
||||
migrator._migration_set.register(migrate_0_to_2)
|
||||
with pytest.raises(MigrationVersionError, match=r"Migration from 1 already registered"):
|
||||
migrator._migration_set.register(migrate_1_to_3)
|
||||
|
||||
|
||||
def test_migration_set_gets_migration(migration_no_op: Migration):
|
||||
migration_set = MigrationSet()
|
||||
migration_set.register(migration_no_op)
|
||||
assert migration_set.get(0) == migration_no_op
|
||||
assert migration_set.get(1) is None
|
||||
|
||||
|
||||
def test_migration_set_validates_migration_path(no_op_migrate_callback: MigrateCallback):
|
||||
migration_set = MigrationSet()
|
||||
migration_set.validate_migration_chain()
|
||||
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
|
||||
migration_set.validate_migration_chain()
|
||||
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
|
||||
migration_set.register(Migration(from_version=2, to_version=3, migrate=no_op_migrate_callback))
|
||||
migration_set.validate_migration_chain()
|
||||
migration_set.register(Migration(from_version=4, to_version=5, migrate=no_op_migrate_callback))
|
||||
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
|
||||
migration_set.validate_migration_chain()
|
||||
|
||||
|
||||
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback):
|
||||
migration_set = MigrationSet()
|
||||
assert migration_set.count == 0
|
||||
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
|
||||
assert migration_set.count == 1
|
||||
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
|
||||
assert migration_set.count == 2
|
||||
|
||||
|
||||
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback):
|
||||
migration_set = MigrationSet()
|
||||
assert migration_set.latest_version == 0
|
||||
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
|
||||
assert migration_set.latest_version == 2
|
||||
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
|
||||
assert migration_set.latest_version == 2
|
||||
|
||||
|
||||
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration):
|
||||
migration = migration_no_op
|
||||
migrator.register_migration(migration)
|
||||
assert migration in migrator._migrations
|
||||
assert migration in migrator._migration_set._migrations
|
||||
|
||||
|
||||
def test_register_invalid_migration_version(migrator: SQLiteMigrator):
|
||||
with pytest.raises(MigrationError, match="Invalid migration version"):
|
||||
migrator.register_migration(Migration(db_version=0, app_version="0.0.0", migrate=lambda cursor: None))
|
||||
|
||||
|
||||
def test_create_version_table(migrator: SQLiteMigrator):
|
||||
migrator._create_migrations_table()
|
||||
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
|
||||
def test_migrator_creates_migrations_table(migrator: SQLiteMigrator):
|
||||
migrator._create_migrations_table(migrator._cursor)
|
||||
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
assert migrator._cursor.fetchone() is not None
|
||||
|
||||
|
||||
def test_get_current_version(migrator: SQLiteMigrator):
|
||||
migrator._create_migrations_table()
|
||||
migrator._conn.commit()
|
||||
assert migrator._get_current_version() == 0 # initial version
|
||||
|
||||
|
||||
def test_set_version(migrator: SQLiteMigrator):
|
||||
migrator._create_migrations_table()
|
||||
migrator._set_version(db_version=1, app_version="1.0.0")
|
||||
migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
|
||||
def test_migrator_migration_sets_version(migrator: SQLiteMigrator, migration_no_op: Migration):
|
||||
migrator._create_migrations_table(migrator._cursor)
|
||||
migrator.register_migration(migration_no_op)
|
||||
migrator.run_migrations()
|
||||
migrator._cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||
assert migrator._cursor.fetchone()[0] == 1
|
||||
migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
|
||||
assert migrator._cursor.fetchone()[0] == "1.0.0"
|
||||
|
||||
|
||||
def test_run_migration(migrator: SQLiteMigrator):
|
||||
migrator._create_migrations_table()
|
||||
|
||||
def migration_callback(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
|
||||
migration = Migration(db_version=1, app_version="1.0.0", migrate=migration_callback)
|
||||
migrator._run_migration(migration)
|
||||
assert migrator._get_current_version() == 1
|
||||
migrator._cursor.execute("SELECT app_version from version WHERE db_version = 1;")
|
||||
assert migrator._cursor.fetchone()[0] == "1.0.0"
|
||||
|
||||
|
||||
def test_run_migrations(migrator: SQLiteMigrator):
|
||||
migrator._create_migrations_table()
|
||||
|
||||
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
|
||||
|
||||
return migrate
|
||||
|
||||
migrations = [Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in range(1, 4)]
|
||||
for migration in migrations:
|
||||
migrator.register_migration(migration)
|
||||
def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op: Migration):
|
||||
assert migrator._get_current_version(migrator._cursor) == 0
|
||||
migrator._create_migrations_table(migrator._cursor)
|
||||
assert migrator._get_current_version(migrator._cursor) == 0
|
||||
migrator.register_migration(migration_no_op)
|
||||
migrator.run_migrations()
|
||||
assert migrator._get_current_version() == 3
|
||||
assert migrator._get_current_version(migrator._cursor) == 1
|
||||
|
||||
|
||||
def test_backup_and_restore_db():
|
||||
# must do this with a file database - we don't backup/restore for memory
|
||||
with TemporaryDirectory() as tempdir:
|
||||
# create test DB w/ some data
|
||||
database = Path(tempdir) / "test.db"
|
||||
conn = sqlite3.connect(database, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
conn.commit()
|
||||
|
||||
migrator = SQLiteMigrator(conn=conn, db_path=database, lock=threading.RLock(), logger=Logger("test"))
|
||||
backup_path = migrator._backup_db(migrator._db_path)
|
||||
|
||||
# mangle the db
|
||||
migrator._cursor.execute("DROP TABLE test;")
|
||||
migrator._conn.commit()
|
||||
def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration):
|
||||
migrator._create_migrations_table(migrator._cursor)
|
||||
migrator._run_migration(migration_create_test_table, migrator._cursor)
|
||||
assert migrator._get_current_version(migrator._cursor) == 1
|
||||
migrator._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
||||
assert migrator._cursor.fetchone() is None
|
||||
|
||||
# restore (closes the connection - must create a new one)
|
||||
migrator._restore_db(backup_path)
|
||||
restored_conn = sqlite3.connect(database)
|
||||
restored_cursor = restored_conn.cursor()
|
||||
restored_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
||||
assert restored_cursor.fetchone() is not None
|
||||
|
||||
# must manually close else tempfile throws on cleanup on windows
|
||||
restored_conn.close()
|
||||
assert migrator._cursor.fetchone() is not None
|
||||
|
||||
|
||||
def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
|
||||
with pytest.raises(MigrationError, match="Cannot back up memory database"):
|
||||
migrator._backup_db(sqlite_memory)
|
||||
|
||||
|
||||
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
|
||||
migrator._create_migrations_table()
|
||||
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
|
||||
migrator._run_migration(failing_migration)
|
||||
assert migrator._get_current_version() == 0
|
||||
|
||||
|
||||
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
|
||||
migrator._create_migrations_table()
|
||||
migrator.register_migration(good_migration)
|
||||
with pytest.raises(MigrationVersionError, match="already registered"):
|
||||
migrator.register_migration(deepcopy(good_migration))
|
||||
|
||||
|
||||
def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
|
||||
migrator._create_migrations_table()
|
||||
|
||||
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
|
||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
|
||||
|
||||
return migrate
|
||||
|
||||
migrations = [
|
||||
Migration(db_version=i, app_version=f"{i}.0.0", migrate=create_migrate(i)) for i in reversed(range(1, 4))
|
||||
]
|
||||
def test_migrator_runs_all_migrations_in_memory(
|
||||
migrator: SQLiteMigrator,
|
||||
):
|
||||
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
|
||||
for migration in migrations:
|
||||
migrator.register_migration(migration)
|
||||
migrator.run_migrations()
|
||||
assert migrator._get_current_version() == 3
|
||||
assert migrator._get_current_version(migrator._cursor) == 3
|
||||
|
||||
|
||||
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
|
||||
migrator._create_migrations_table()
|
||||
migrator.register_migration(good_migration)
|
||||
migrator._set_version(db_version=2, app_version="2.0.0")
|
||||
with pytest.raises(MigrationError, match="greater than the latest migration version"):
|
||||
migrator.run_migrations()
|
||||
assert migrator._get_current_version() == 2
|
||||
|
||||
|
||||
def test_idempotent_migrations(migrator: SQLiteMigrator):
|
||||
migrator._create_migrations_table()
|
||||
|
||||
def create_test_table(cursor: sqlite3.Cursor) -> None:
|
||||
# This SQL throws if run twice
|
||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
|
||||
migration = Migration(db_version=1, app_version="1.0.0", migrate=create_test_table)
|
||||
|
||||
def test_migrator_runs_all_migrations_file(logger: Logger, lock: threading.RLock):
|
||||
with TemporaryDirectory() as tempdir:
|
||||
original_db_path = Path(tempdir) / "invokeai.db"
|
||||
# The Migrator closes the database when it finishes; we cannot use a context manager.
|
||||
original_db_conn = sqlite3.connect(original_db_path)
|
||||
migrator = SQLiteMigrator(conn=original_db_conn, db_path=original_db_path, lock=lock, logger=logger)
|
||||
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
|
||||
for migration in migrations:
|
||||
migrator.register_migration(migration)
|
||||
migrator.run_migrations()
|
||||
with sqlite3.connect(original_db_path) as original_db_conn:
|
||||
original_db_cursor = original_db_conn.cursor()
|
||||
assert SQLiteMigrator._get_current_version(original_db_cursor) == 3
|
||||
|
||||
|
||||
def test_migrator_creates_temp_db():
|
||||
with TemporaryDirectory() as tempdir:
|
||||
original_db_path = Path(tempdir) / "invokeai.db"
|
||||
with sqlite3.connect(original_db_path):
|
||||
# create the db file so _create_temp_db has something to copy
|
||||
pass
|
||||
temp_db_path = SQLiteMigrator._create_temp_db(original_db_path)
|
||||
assert temp_db_path.is_file()
|
||||
assert temp_db_path == SQLiteMigrator._get_temp_db_path(original_db_path)
|
||||
|
||||
|
||||
def test_migrator_finalizes():
|
||||
with TemporaryDirectory() as tempdir:
|
||||
original_db_path = Path(tempdir) / "invokeai.db"
|
||||
temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path)
|
||||
backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path)
|
||||
with sqlite3.connect(original_db_path) as original_db_conn, sqlite3.connect(temp_db_path) as temp_db_conn:
|
||||
original_db_cursor = original_db_conn.cursor()
|
||||
original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);")
|
||||
original_db_conn.commit()
|
||||
temp_db_cursor = temp_db_conn.cursor()
|
||||
temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);")
|
||||
temp_db_conn.commit()
|
||||
SQLiteMigrator._finalize_migration(
|
||||
original_db_path=original_db_path,
|
||||
temp_db_path=temp_db_path,
|
||||
)
|
||||
with sqlite3.connect(backup_db_path) as backup_db_conn, sqlite3.connect(temp_db_path) as temp_db_conn:
|
||||
backup_db_cursor = backup_db_conn.cursor()
|
||||
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';")
|
||||
assert backup_db_cursor.fetchone() is not None
|
||||
temp_db_cursor = temp_db_conn.cursor()
|
||||
temp_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='temp_db_test';")
|
||||
assert temp_db_cursor.fetchone() is None
|
||||
|
||||
|
||||
def test_migrator_makes_no_changes_on_failed_migration(
|
||||
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
|
||||
):
|
||||
migrator.register_migration(migration_no_op)
|
||||
migrator.run_migrations()
|
||||
assert migrator._get_current_version(migrator._cursor) == 1
|
||||
migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback))
|
||||
with pytest.raises(MigrationError, match="Bad migration"):
|
||||
migrator.run_migrations()
|
||||
assert migrator._get_current_version(migrator._cursor) == 1
|
||||
|
||||
|
||||
def test_idempotent_migrations(migrator: SQLiteMigrator, migration_create_test_table: Migration):
|
||||
migrator.register_migration(migration_create_test_table)
|
||||
migrator.run_migrations()
|
||||
# not throwing is sufficient
|
||||
migrator.run_migrations()
|
||||
assert migrator._get_current_version(migrator._cursor) == 1
|
||||
|
Loading…
Reference in New Issue
Block a user