feat(db): add tests for migration dependencies

This commit is contained in:
psychedelicious 2023-12-12 13:09:24 +11:00
parent a69f518c76
commit 50815d36c6

View File

@ -1,4 +1,5 @@
import sqlite3
from abc import ABC, abstractmethod
from contextlib import closing
from logging import Logger
from pathlib import Path
@ -11,6 +12,7 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
MigrateCallback,
Migration,
MigrationDependency,
MigrationError,
MigrationSet,
MigrationVersionError,
@ -25,6 +27,16 @@ def logger() -> Logger:
return Logger("test_sqlite_migrator")
@pytest.fixture
def memory_db_conn() -> sqlite3.Connection:
return sqlite3.connect(":memory:")
@pytest.fixture
def memory_db_cursor(memory_db_conn: sqlite3.Connection) -> sqlite3.Cursor:
return memory_db_conn.cursor()
@pytest.fixture
def migrator(logger: Logger) -> SQLiteMigrator:
db = SqliteDatabase(db_path=None, logger=logger, verbose=False)
@ -45,11 +57,25 @@ def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
@pytest.fixture
def migration_create_test_table() -> Migration:
def migrate_callback_create_table_of_name() -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
table_name = kwargs["table_name"]
cursor.execute(f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY);")
return migrate
@pytest.fixture
def migrate_callback_create_test_table() -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
return Migration(from_version=0, to_version=1, migrate_callback=migrate)
return migrate
@pytest.fixture
def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=migrate_callback_create_test_table)
@pytest.fixture
@ -75,6 +101,31 @@ def create_migrate(i: int) -> MigrateCallback:
return migrate
def test_migration_dependency_sets_value_primitive() -> None:
dependency = MigrationDependency(name="test_dependency", dependency_type=str)
dependency.set_value("test")
assert dependency.value == "test"
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*str"):
dependency.set_value(1)
def test_migration_dependency_sets_value_complex() -> None:
class SomeBase(ABC):
@abstractmethod
def some_method(self) -> None:
pass
class SomeImpl(SomeBase):
def some_method(self) -> None:
return
dependency = MigrationDependency(name="test_dependency", dependency_type=SomeBase)
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*SomeBase"):
dependency.set_value(1)
# not throwing is sufficient
dependency.set_value(SomeImpl())
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
@ -149,6 +200,49 @@ def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallba
assert migration_set.latest_version == 2
def test_migration_provide_dependency_validates_name(no_op_migrate_callback: MigrateCallback) -> None:
dependency = MigrationDependency(name="my_dependency", dependency_type=str)
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=no_op_migrate_callback,
dependencies={dependency.name: dependency},
)
with pytest.raises(ValueError, match="is not a dependency of this migration"):
migration.provide_dependency("unknown_dependency_name", "banana_sushi")
def test_migration_runs_without_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback
) -> None:
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_test_table,
)
migration.run(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert memory_db_cursor.fetchone() is not None
def test_migration_runs_with_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_table_of_name: MigrateCallback
) -> None:
dependency = MigrationDependency(name="table_name", dependency_type=str)
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_table_of_name,
dependencies={dependency.name: dependency},
)
with pytest.raises(MigrationError, match="Missing migration dependencies"):
migration.run(memory_db_cursor)
migration.provide_dependency(dependency.name, "banana_sushi")
migration.run(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='banana_sushi';")
assert memory_db_cursor.fetchone() is not None
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
migration = migration_no_op
migrator.register_migration(migration)