feat(db): add tests for migration dependencies

This commit is contained in:
psychedelicious 2023-12-12 13:09:24 +11:00
parent a69f518c76
commit 50815d36c6

View File

@ -1,4 +1,5 @@
import sqlite3 import sqlite3
from abc import ABC, abstractmethod
from contextlib import closing from contextlib import closing
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
@ -11,6 +12,7 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import ( from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
MigrateCallback, MigrateCallback,
Migration, Migration,
MigrationDependency,
MigrationError, MigrationError,
MigrationSet, MigrationSet,
MigrationVersionError, MigrationVersionError,
@ -25,6 +27,16 @@ def logger() -> Logger:
return Logger("test_sqlite_migrator") return Logger("test_sqlite_migrator")
@pytest.fixture
def memory_db_conn() -> sqlite3.Connection:
return sqlite3.connect(":memory:")
@pytest.fixture
def memory_db_cursor(memory_db_conn: sqlite3.Connection) -> sqlite3.Cursor:
return memory_db_conn.cursor()
@pytest.fixture @pytest.fixture
def migrator(logger: Logger) -> SQLiteMigrator: def migrator(logger: Logger) -> SQLiteMigrator:
db = SqliteDatabase(db_path=None, logger=logger, verbose=False) db = SqliteDatabase(db_path=None, logger=logger, verbose=False)
@ -45,11 +57,25 @@ def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
@pytest.fixture @pytest.fixture
def migration_create_test_table() -> Migration: def migrate_callback_create_table_of_name() -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
table_name = kwargs["table_name"]
cursor.execute(f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY);")
return migrate
@pytest.fixture
def migrate_callback_create_test_table() -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None: def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
return Migration(from_version=0, to_version=1, migrate_callback=migrate) return migrate
@pytest.fixture
def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=migrate_callback_create_test_table)
@pytest.fixture @pytest.fixture
@ -75,6 +101,31 @@ def create_migrate(i: int) -> MigrateCallback:
return migrate return migrate
def test_migration_dependency_sets_value_primitive() -> None:
dependency = MigrationDependency(name="test_dependency", dependency_type=str)
dependency.set_value("test")
assert dependency.value == "test"
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*str"):
dependency.set_value(1)
def test_migration_dependency_sets_value_complex() -> None:
class SomeBase(ABC):
@abstractmethod
def some_method(self) -> None:
pass
class SomeImpl(SomeBase):
def some_method(self) -> None:
return
dependency = MigrationDependency(name="test_dependency", dependency_type=SomeBase)
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*SomeBase"):
dependency.set_value(1)
# not throwing is sufficient
dependency.set_value(SomeImpl())
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"): with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback) Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
@ -149,6 +200,49 @@ def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallba
assert migration_set.latest_version == 2 assert migration_set.latest_version == 2
def test_migration_provide_dependency_validates_name(no_op_migrate_callback: MigrateCallback) -> None:
dependency = MigrationDependency(name="my_dependency", dependency_type=str)
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=no_op_migrate_callback,
dependencies={dependency.name: dependency},
)
with pytest.raises(ValueError, match="is not a dependency of this migration"):
migration.provide_dependency("unknown_dependency_name", "banana_sushi")
def test_migration_runs_without_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback
) -> None:
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_test_table,
)
migration.run(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert memory_db_cursor.fetchone() is not None
def test_migration_runs_with_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_table_of_name: MigrateCallback
) -> None:
dependency = MigrationDependency(name="table_name", dependency_type=str)
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_table_of_name,
dependencies={dependency.name: dependency},
)
with pytest.raises(MigrationError, match="Missing migration dependencies"):
migration.run(memory_db_cursor)
migration.provide_dependency(dependency.name, "banana_sushi")
migration.run(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='banana_sushi';")
assert memory_db_cursor.fetchone() is not None
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None: def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
migration = migration_no_op migration = migration_no_op
migrator.register_migration(migration) migrator.register_migration(migration)