mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): add tests for migration dependencies
This commit is contained in:
parent
a69f518c76
commit
50815d36c6
@ -1,4 +1,5 @@
|
||||
import sqlite3
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import closing
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
@ -11,6 +12,7 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
|
||||
MigrateCallback,
|
||||
Migration,
|
||||
MigrationDependency,
|
||||
MigrationError,
|
||||
MigrationSet,
|
||||
MigrationVersionError,
|
||||
@ -25,6 +27,16 @@ def logger() -> Logger:
|
||||
return Logger("test_sqlite_migrator")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_db_conn() -> sqlite3.Connection:
|
||||
return sqlite3.connect(":memory:")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_db_cursor(memory_db_conn: sqlite3.Connection) -> sqlite3.Cursor:
|
||||
return memory_db_conn.cursor()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migrator(logger: Logger) -> SQLiteMigrator:
|
||||
db = SqliteDatabase(db_path=None, logger=logger, verbose=False)
|
||||
@ -45,11 +57,25 @@ def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_create_test_table() -> Migration:
|
||||
def migrate_callback_create_table_of_name() -> MigrateCallback:
|
||||
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
table_name = kwargs["table_name"]
|
||||
cursor.execute(f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY);")
|
||||
|
||||
return migrate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migrate_callback_create_test_table() -> MigrateCallback:
|
||||
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
|
||||
return Migration(from_version=0, to_version=1, migrate_callback=migrate)
|
||||
return migrate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration:
|
||||
return Migration(from_version=0, to_version=1, migrate_callback=migrate_callback_create_test_table)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -75,6 +101,31 @@ def create_migrate(i: int) -> MigrateCallback:
|
||||
return migrate
|
||||
|
||||
|
||||
def test_migration_dependency_sets_value_primitive() -> None:
|
||||
dependency = MigrationDependency(name="test_dependency", dependency_type=str)
|
||||
dependency.set_value("test")
|
||||
assert dependency.value == "test"
|
||||
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*str"):
|
||||
dependency.set_value(1)
|
||||
|
||||
|
||||
def test_migration_dependency_sets_value_complex() -> None:
|
||||
class SomeBase(ABC):
|
||||
@abstractmethod
|
||||
def some_method(self) -> None:
|
||||
pass
|
||||
|
||||
class SomeImpl(SomeBase):
|
||||
def some_method(self) -> None:
|
||||
return
|
||||
|
||||
dependency = MigrationDependency(name="test_dependency", dependency_type=SomeBase)
|
||||
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*SomeBase"):
|
||||
dependency.set_value(1)
|
||||
# not throwing is sufficient
|
||||
dependency.set_value(SomeImpl())
|
||||
|
||||
|
||||
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
|
||||
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
|
||||
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
|
||||
@ -149,6 +200,49 @@ def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallba
|
||||
assert migration_set.latest_version == 2
|
||||
|
||||
|
||||
def test_migration_provide_dependency_validates_name(no_op_migrate_callback: MigrateCallback) -> None:
|
||||
dependency = MigrationDependency(name="my_dependency", dependency_type=str)
|
||||
migration = Migration(
|
||||
from_version=0,
|
||||
to_version=1,
|
||||
migrate_callback=no_op_migrate_callback,
|
||||
dependencies={dependency.name: dependency},
|
||||
)
|
||||
with pytest.raises(ValueError, match="is not a dependency of this migration"):
|
||||
migration.provide_dependency("unknown_dependency_name", "banana_sushi")
|
||||
|
||||
|
||||
def test_migration_runs_without_dependencies(
|
||||
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback
|
||||
) -> None:
|
||||
migration = Migration(
|
||||
from_version=0,
|
||||
to_version=1,
|
||||
migrate_callback=migrate_callback_create_test_table,
|
||||
)
|
||||
migration.run(memory_db_cursor)
|
||||
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
||||
assert memory_db_cursor.fetchone() is not None
|
||||
|
||||
|
||||
def test_migration_runs_with_dependencies(
|
||||
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_table_of_name: MigrateCallback
|
||||
) -> None:
|
||||
dependency = MigrationDependency(name="table_name", dependency_type=str)
|
||||
migration = Migration(
|
||||
from_version=0,
|
||||
to_version=1,
|
||||
migrate_callback=migrate_callback_create_table_of_name,
|
||||
dependencies={dependency.name: dependency},
|
||||
)
|
||||
with pytest.raises(MigrationError, match="Missing migration dependencies"):
|
||||
migration.run(memory_db_cursor)
|
||||
migration.provide_dependency(dependency.name, "banana_sushi")
|
||||
migration.run(memory_db_cursor)
|
||||
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='banana_sushi';")
|
||||
assert memory_db_cursor.fetchone() is not None
|
||||
|
||||
|
||||
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
|
||||
migration = migration_no_op
|
||||
migrator.register_migration(migration)
|
||||
|
Loading…
Reference in New Issue
Block a user