feat(db): refactor migrate callbacks to use dependencies, remote pre/post callbacks

This commit is contained in:
psychedelicious
2023-12-12 12:35:42 +11:00
parent 6063760ce2
commit 0cf7fe43af
14 changed files with 230 additions and 181 deletions

View File

@ -28,11 +28,8 @@ from invokeai.app.services.shared.graph import (
IterateInvocation,
LibraryGraph,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from .test_invoker import create_edge
@ -50,15 +47,10 @@ def simple_graph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db_path = None if configuration.use_memory_db else configuration.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
db = create_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices(

View File

@ -3,8 +3,8 @@ import logging
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
@ -25,9 +25,6 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
@pytest.fixture
@ -54,15 +51,10 @@ def graph_with_subgraph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db_path = None if configuration.use_memory_db else configuration.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
db = create_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")

View File

@ -18,12 +18,9 @@ from invokeai.app.services.model_install import (
ModelInstallServiceBase,
)
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
@pytest.fixture
@ -40,14 +37,11 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
@pytest.fixture
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
def store(
app_config: InvokeAIAppConfig, create_sqlite_database: CreateSqliteDatabaseFunction
) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config)
db_path = None if app_config.use_memory_db else app_config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
db = create_sqlite_database(app_config, logger)
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store

View File

@ -14,10 +14,6 @@ from invokeai.app.services.model_records import (
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.model_manager.config import (
BaseModelType,
MainCheckpointConfig,
@ -27,18 +23,14 @@ from invokeai.backend.model_manager.config import (
VaeDiffusersConfig,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
@pytest.fixture
def store(datadir: Any) -> ModelRecordServiceBase:
def store(datadir: Any, create_sqlite_database: CreateSqliteDatabaseFunction) -> ModelRecordServiceBase:
config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config)
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
db = create_sqlite_database(config, logger)
return ModelRecordServiceSQL(db)

View File

@ -4,3 +4,5 @@
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401
pytest_plugins = ["tests.fixtures.sqlite_database"]

0
tests/fixtures/__init__.py vendored Normal file
View File

33
tests/fixtures/sqlite_database.py vendored Normal file
View File

@ -0,0 +1,33 @@
from logging import Logger
from typing import Callable
from unittest import mock
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
CreateSqliteDatabaseFunction = Callable[[InvokeAIAppConfig, Logger], SqliteDatabase]
@pytest.fixture
def create_sqlite_database() -> CreateSqliteDatabaseFunction:
def _create_sqlite_database(config: InvokeAIAppConfig, logger: Logger) -> SqliteDatabase:
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
image_files = mock.Mock(spec=ImageFileStorageBase)
migrator = SQLiteMigrator(db=db)
migration_2.provide_dependency("logger", logger)
migration_2.provide_dependency("image_files", image_files)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
return db
return _create_sqlite_database

View File

@ -31,45 +31,45 @@ def migrator(logger: Logger) -> SQLiteMigrator:
return SQLiteMigrator(db=db)
@pytest.fixture
def migration_no_op() -> Migration:
return Migration(from_version=0, to_version=1, migrate=lambda cursor: None)
@pytest.fixture
def migration_create_test_table() -> Migration:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
return Migration(from_version=0, to_version=1, migrate=migrate)
@pytest.fixture
def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor) -> None:
raise Exception("Bad migration")
return Migration(from_version=0, to_version=1, migrate=failing_migration)
@pytest.fixture
def no_op_migrate_callback() -> MigrateCallback:
def no_op_migrate(cursor: sqlite3.Cursor) -> None:
def no_op_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
pass
return no_op_migrate
@pytest.fixture
def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
@pytest.fixture
def migration_create_test_table() -> Migration:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
return Migration(from_version=0, to_version=1, migrate_callback=migrate)
@pytest.fixture
def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None:
raise Exception("Bad migration")
return Migration(from_version=0, to_version=1, migrate_callback=failing_migration)
@pytest.fixture
def failing_migrate_callback() -> MigrateCallback:
def failing_migrate(cursor: sqlite3.Cursor) -> None:
def failing_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
raise Exception("Bad migration")
return failing_migrate
def create_migrate(i: int) -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor) -> None:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
return migrate
@ -77,30 +77,16 @@ def create_migrate(i: int) -> MigrateCallback:
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
Migration(from_version=0, to_version=2, migrate=no_op_migrate_callback)
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
# not raising is sufficient
Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None:
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
migration = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
assert hash(migration) == hash((0, 1))
def test_migration_registers_pre_and_post_callbacks(no_op_migrate_callback: MigrateCallback) -> None:
def pre_callback(cursor: sqlite3.Cursor) -> None:
pass
def post_callback(cursor: sqlite3.Cursor) -> None:
pass
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
migration.register_pre_callback(pre_callback)
migration.register_post_callback(post_callback)
assert pre_callback in migration.pre_migrate
assert post_callback in migration.post_migrate
def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
migration = migration_no_op
migrator._migration_set.register(migration)
@ -110,13 +96,13 @@ def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op:
def test_migration_set_may_not_register_dupes(
migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback
) -> None:
migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_0_to_1_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_0_to_1_b)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_1_to_2_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_1_to_2_b)
@ -131,15 +117,15 @@ def test_migration_set_gets_migration(migration_no_op: Migration) -> None:
def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 0 to 1
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=2, to_version=3, migrate=no_op_migrate_callback))
migration_set.register(Migration(from_version=2, to_version=3, migrate_callback=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=4, to_version=5, migrate=no_op_migrate_callback))
migration_set.register(Migration(from_version=4, to_version=5, migrate_callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 3 to 4
migration_set.validate_migration_chain()
@ -148,18 +134,18 @@ def test_migration_set_validates_migration_chain(no_op_migrate_callback: Migrate
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
assert migration_set.count == 0
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
assert migration_set.count == 1
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
assert migration_set.count == 2
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
assert migration_set.latest_version == 0
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
assert migration_set.latest_version == 2
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
assert migration_set.latest_version == 2
@ -206,7 +192,7 @@ def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_crea
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
cursor = migrator._db.conn.cursor()
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
migrations = [Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
@ -219,7 +205,9 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
# The Migrator closes the database when it finishes; we cannot use a context manager.
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
migrator = SQLiteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
migrations = [
Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)
]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
@ -235,7 +223,7 @@ def test_migrator_makes_no_changes_on_failed_migration(
migrator.register_migration(migration_no_op)
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1
migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback))
migrator.register_migration(Migration(from_version=1, to_version=2, migrate_callback=failing_migrate_callback))
with pytest.raises(MigrationError, match="Bad migration"):
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1