diff --git a/tests/test_sqlite_migrator.py b/tests/test_sqlite_migrator.py index 1e6d0548b6..97419bcefa 100644 --- a/tests/test_sqlite_migrator.py +++ b/tests/test_sqlite_migrator.py @@ -1,4 +1,5 @@ import sqlite3 +from abc import ABC, abstractmethod from contextlib import closing from logging import Logger from pathlib import Path @@ -11,6 +12,7 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import ( MigrateCallback, Migration, + MigrationDependency, MigrationError, MigrationSet, MigrationVersionError, @@ -25,6 +27,16 @@ def logger() -> Logger: return Logger("test_sqlite_migrator") +@pytest.fixture +def memory_db_conn() -> sqlite3.Connection: + return sqlite3.connect(":memory:") + + +@pytest.fixture +def memory_db_cursor(memory_db_conn: sqlite3.Connection) -> sqlite3.Cursor: + return memory_db_conn.cursor() + + @pytest.fixture def migrator(logger: Logger) -> SQLiteMigrator: db = SqliteDatabase(db_path=None, logger=logger, verbose=False) @@ -45,11 +57,25 @@ def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration: @pytest.fixture -def migration_create_test_table() -> Migration: +def migrate_callback_create_table_of_name() -> MigrateCallback: + def migrate(cursor: sqlite3.Cursor, **kwargs) -> None: + table_name = kwargs["table_name"] + cursor.execute(f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY);") + + return migrate + + +@pytest.fixture +def migrate_callback_create_test_table() -> MigrateCallback: def migrate(cursor: sqlite3.Cursor, **kwargs) -> None: cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") - return Migration(from_version=0, to_version=1, migrate_callback=migrate) + return migrate + + +@pytest.fixture +def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration: + return Migration(from_version=0, to_version=1, migrate_callback=migrate_callback_create_test_table) @pytest.fixture @@ -75,6 +101,31 @@ def create_migrate(i: int) -> MigrateCallback: return migrate +def test_migration_dependency_sets_value_primitive() -> None: + dependency = MigrationDependency(name="test_dependency", dependency_type=str) + dependency.set_value("test") + assert dependency.value == "test" + with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*str"): + dependency.set_value(1) + + +def test_migration_dependency_sets_value_complex() -> None: + class SomeBase(ABC): + @abstractmethod + def some_method(self) -> None: + pass + + class SomeImpl(SomeBase): + def some_method(self) -> None: + return + + dependency = MigrationDependency(name="test_dependency", dependency_type=SomeBase) + with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*SomeBase"): + dependency.set_value(1) + # not throwing is sufficient + dependency.set_value(SomeImpl()) + + def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None: with pytest.raises(ValidationError, match="to_version must be one greater than from_version"): Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback) @@ -149,6 +200,49 @@ def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallba assert migration_set.latest_version == 2 +def test_migration_provide_dependency_validates_name(no_op_migrate_callback: MigrateCallback) -> None: + dependency = MigrationDependency(name="my_dependency", dependency_type=str) + migration = Migration( + from_version=0, + to_version=1, + migrate_callback=no_op_migrate_callback, + dependencies={dependency.name: dependency}, + ) + with pytest.raises(ValueError, match="is not a dependency of this migration"): + migration.provide_dependency("unknown_dependency_name", "banana_sushi") + + +def test_migration_runs_without_dependencies( + memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback +) -> None: + migration = Migration( + from_version=0, + to_version=1, + migrate_callback=migrate_callback_create_test_table, + ) + migration.run(memory_db_cursor) + memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';") + assert memory_db_cursor.fetchone() is not None + + +def test_migration_runs_with_dependencies( + memory_db_cursor: sqlite3.Cursor, migrate_callback_create_table_of_name: MigrateCallback +) -> None: + dependency = MigrationDependency(name="table_name", dependency_type=str) + migration = Migration( + from_version=0, + to_version=1, + migrate_callback=migrate_callback_create_table_of_name, + dependencies={dependency.name: dependency}, + ) + with pytest.raises(MigrationError, match="Missing migration dependencies"): + migration.run(memory_db_cursor) + migration.provide_dependency(dependency.name, "banana_sushi") + migration.run(memory_db_cursor) + memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='banana_sushi';") + assert memory_db_cursor.fetchone() is not None + + def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None: migration = migration_no_op migrator.register_migration(migration)