mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): add tests for migration dependencies
This commit is contained in:
parent
a69f518c76
commit
50815d36c6
@ -1,4 +1,5 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -11,6 +12,7 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
|||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
|
||||||
MigrateCallback,
|
MigrateCallback,
|
||||||
Migration,
|
Migration,
|
||||||
|
MigrationDependency,
|
||||||
MigrationError,
|
MigrationError,
|
||||||
MigrationSet,
|
MigrationSet,
|
||||||
MigrationVersionError,
|
MigrationVersionError,
|
||||||
@ -25,6 +27,16 @@ def logger() -> Logger:
|
|||||||
return Logger("test_sqlite_migrator")
|
return Logger("test_sqlite_migrator")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def memory_db_conn() -> sqlite3.Connection:
|
||||||
|
return sqlite3.connect(":memory:")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def memory_db_cursor(memory_db_conn: sqlite3.Connection) -> sqlite3.Cursor:
|
||||||
|
return memory_db_conn.cursor()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def migrator(logger: Logger) -> SQLiteMigrator:
|
def migrator(logger: Logger) -> SQLiteMigrator:
|
||||||
db = SqliteDatabase(db_path=None, logger=logger, verbose=False)
|
db = SqliteDatabase(db_path=None, logger=logger, verbose=False)
|
||||||
@ -45,11 +57,25 @@ def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def migration_create_test_table() -> Migration:
|
def migrate_callback_create_table_of_name() -> MigrateCallback:
|
||||||
|
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||||
|
table_name = kwargs["table_name"]
|
||||||
|
cursor.execute(f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY);")
|
||||||
|
|
||||||
|
return migrate
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def migrate_callback_create_test_table() -> MigrateCallback:
|
||||||
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
|
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||||
|
|
||||||
return Migration(from_version=0, to_version=1, migrate_callback=migrate)
|
return migrate
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration:
|
||||||
|
return Migration(from_version=0, to_version=1, migrate_callback=migrate_callback_create_test_table)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -75,6 +101,31 @@ def create_migrate(i: int) -> MigrateCallback:
|
|||||||
return migrate
|
return migrate
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_dependency_sets_value_primitive() -> None:
|
||||||
|
dependency = MigrationDependency(name="test_dependency", dependency_type=str)
|
||||||
|
dependency.set_value("test")
|
||||||
|
assert dependency.value == "test"
|
||||||
|
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*str"):
|
||||||
|
dependency.set_value(1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_dependency_sets_value_complex() -> None:
|
||||||
|
class SomeBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def some_method(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class SomeImpl(SomeBase):
|
||||||
|
def some_method(self) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
dependency = MigrationDependency(name="test_dependency", dependency_type=SomeBase)
|
||||||
|
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*SomeBase"):
|
||||||
|
dependency.set_value(1)
|
||||||
|
# not throwing is sufficient
|
||||||
|
dependency.set_value(SomeImpl())
|
||||||
|
|
||||||
|
|
||||||
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
|
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
|
||||||
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
|
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
|
||||||
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
|
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
|
||||||
@ -149,6 +200,49 @@ def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallba
|
|||||||
assert migration_set.latest_version == 2
|
assert migration_set.latest_version == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_provide_dependency_validates_name(no_op_migrate_callback: MigrateCallback) -> None:
|
||||||
|
dependency = MigrationDependency(name="my_dependency", dependency_type=str)
|
||||||
|
migration = Migration(
|
||||||
|
from_version=0,
|
||||||
|
to_version=1,
|
||||||
|
migrate_callback=no_op_migrate_callback,
|
||||||
|
dependencies={dependency.name: dependency},
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="is not a dependency of this migration"):
|
||||||
|
migration.provide_dependency("unknown_dependency_name", "banana_sushi")
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_runs_without_dependencies(
|
||||||
|
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback
|
||||||
|
) -> None:
|
||||||
|
migration = Migration(
|
||||||
|
from_version=0,
|
||||||
|
to_version=1,
|
||||||
|
migrate_callback=migrate_callback_create_test_table,
|
||||||
|
)
|
||||||
|
migration.run(memory_db_cursor)
|
||||||
|
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
||||||
|
assert memory_db_cursor.fetchone() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_runs_with_dependencies(
|
||||||
|
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_table_of_name: MigrateCallback
|
||||||
|
) -> None:
|
||||||
|
dependency = MigrationDependency(name="table_name", dependency_type=str)
|
||||||
|
migration = Migration(
|
||||||
|
from_version=0,
|
||||||
|
to_version=1,
|
||||||
|
migrate_callback=migrate_callback_create_table_of_name,
|
||||||
|
dependencies={dependency.name: dependency},
|
||||||
|
)
|
||||||
|
with pytest.raises(MigrationError, match="Missing migration dependencies"):
|
||||||
|
migration.run(memory_db_cursor)
|
||||||
|
migration.provide_dependency(dependency.name, "banana_sushi")
|
||||||
|
migration.run(memory_db_cursor)
|
||||||
|
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='banana_sushi';")
|
||||||
|
assert memory_db_cursor.fetchone() is not None
|
||||||
|
|
||||||
|
|
||||||
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
|
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
|
||||||
migration = migration_no_op
|
migration = migration_no_op
|
||||||
migrator.register_migration(migration)
|
migrator.register_migration(migration)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user