From 0cf7fe43afb44345b55962884b241dbf95f932f6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 12 Dec 2023 12:35:42 +1100 Subject: [PATCH] feat(db): refactor migrate callbacks to use dependencies, remote pre/post callbacks --- invokeai/app/api/dependencies.py | 7 +- .../sqlite_migrator/migrations/migration_1.py | 4 +- .../sqlite_migrator/migrations/migration_2.py | 59 ++++++++++- .../migrations/migration_2_post.py | 41 ------- .../sqlite_migrator/sqlite_migrator_common.py | 92 ++++++++++++---- .../sqlite_migrator/sqlite_migrator_impl.py | 15 +-- tests/aa_nodes/test_graph_execution_state.py | 14 +-- tests/aa_nodes/test_invoker.py | 14 +-- .../model_install/test_model_install.py | 16 +-- .../model_records/test_model_records_sql.py | 14 +-- tests/conftest.py | 2 + tests/fixtures/__init__.py | 0 tests/fixtures/sqlite_database.py | 33 ++++++ tests/test_sqlite_migrator.py | 100 ++++++++---------- 14 files changed, 230 insertions(+), 181 deletions(-) delete mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_2_post.py create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/sqlite_database.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index fe42872bcd..212d470d2c 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -1,11 +1,9 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from functools import partial from logging import Logger from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2 -from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2_post import migrate_embedded_workflows from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -77,8 +75,11 @@ class ApiDependencies: db_path = None if config.use_memory_db else config.db_path db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) + # This migration requires an ImageFileStorageBase service and logger + migration_2.provide_dependency("image_files", image_files) + migration_2.provide_dependency("logger", logger) + migrator = SQLiteMigrator(db=db) - migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files)) migrator.register_migration(migration_1) migrator.register_migration(migration_2) migrator.run_migrations() diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_1.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_1.py index fce9acf54c..e456cd94fc 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_1.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_1.py @@ -3,7 +3,7 @@ import sqlite3 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration -def _migrate(cursor: sqlite3.Cursor) -> None: +def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None: """Migration callback for database version 1.""" _create_board_images(cursor) @@ -353,7 +353,7 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None: migration_1 = Migration( from_version=0, to_version=1, - migrate=_migrate, + migrate_callback=migrate_callback, ) """ Database version 1 (initial state). diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_2.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_2.py index a4c950b85c..e99d14b779 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_2.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_2.py @@ -1,16 +1,24 @@ import sqlite3 +from logging import Logger -from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration +from tqdm import tqdm + +from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationDependency -def _migrate(cursor: sqlite3.Cursor) -> None: +def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None: """Migration callback for database version 2.""" + logger = kwargs["logger"] + image_files = kwargs["image_files"] + _add_images_has_workflow(cursor) _add_session_queue_workflow(cursor) _drop_old_workflow_tables(cursor) _add_workflow_library(cursor) _drop_model_manager_metadata(cursor) + _migrate_embedded_workflows(cursor, logger, image_files) def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None: @@ -89,19 +97,64 @@ def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None: cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;") +def _migrate_embedded_workflows( + cursor: sqlite3.Cursor, + logger: Logger, + image_files: ImageFileStorageBase, +) -> None: + """ + In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in + the database now has a `has_workflow` column, indicating if an image has a workflow embedded. + + This migrate callback checks each image for the presence of an embedded workflow, then updates its entry + in the database accordingly. + """ + # Get the total number of images and chunk it into pages + cursor.execute("SELECT image_name FROM images") + image_names: list[str] = [image[0] for image in cursor.fetchall()] + total_image_names = len(image_names) + + if not total_image_names: + return + + logger.info(f"Migrating workflows for {total_image_names} images") + + # Migrate the images + to_migrate: list[tuple[bool, str]] = [] + pbar = tqdm(image_names) + for idx, image_name in enumerate(pbar): + pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow") + pil_image = image_files.get(image_name) + if "invokeai_workflow" in pil_image.info: + to_migrate.append((True, image_name)) + + logger.info(f"Adding {len(to_migrate)} embedded workflows to database") + cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate) + + +image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase) +logger_dependency = MigrationDependency(name="logger", dependency_type=Logger) + + migration_2 = Migration( from_version=1, to_version=2, - migrate=_migrate, + migrate_callback=migrate_callback, + dependencies={"image_files": image_files_dependency, "logger": logger_dependency}, ) """ Database version 2. Introduced in v3.5.0 for the new workflow library. +Dependencies: +- image_files: ImageFileStorageBase +- logger: Logger + Migration: - Add `has_workflow` column to `images` table - Add `workflow` column to `session_queue` table - Drop `workflows` and `workflow_images` tables - Add `workflow_library` table +- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies) """ diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_2_post.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_2_post.py deleted file mode 100644 index fa0d874a5c..0000000000 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_2_post.py +++ /dev/null @@ -1,41 +0,0 @@ -import sqlite3 -from logging import Logger - -from tqdm import tqdm - -from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase - - -def migrate_embedded_workflows( - cursor: sqlite3.Cursor, - logger: Logger, - image_files: ImageFileStorageBase, -) -> None: - """ - In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in - the database now has a `has_workflow` column, indicating if an image has a workflow embedded. - - This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry - in the database accordingly. - """ - # Get the total number of images and chunk it into pages - cursor.execute("SELECT image_name FROM images") - image_names: list[str] = [image[0] for image in cursor.fetchall()] - total_image_names = len(image_names) - - if not total_image_names: - return - - logger.info(f"Migrating workflows for {total_image_names} images") - - # Migrate the images - to_migrate: list[tuple[bool, str]] = [] - pbar = tqdm(image_names) - for idx, image_name in enumerate(pbar): - pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow") - pil_image = image_files.get(image_name) - if "invokeai_workflow" in pil_image.info: - to_migrate.append((True, image_name)) - - logger.info(f"Adding {len(to_migrate)} embedded workflows to database") - cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate) diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py index 0c395f54d6..638fab6eb7 100644 --- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py +++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py @@ -1,9 +1,16 @@ import sqlite3 -from typing import Callable, Optional, TypeAlias +from functools import partial +from typing import Any, Optional, Protocol, runtime_checkable -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator -MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None] + +@runtime_checkable +class MigrateCallback(Protocol): + """A callback that performs a migration.""" + + def __call__(self, cursor: sqlite3.Cursor, **kwargs: Any) -> None: + ... class MigrationError(RuntimeError): @@ -14,28 +21,65 @@ class MigrationVersionError(ValueError): """Raised when a migration version is invalid.""" +class MigrationDependency: + """Represents a dependency for a migration.""" + + def __init__( + self, + name: str, + dependency_type: Any, + ) -> None: + self.name = name + self.dependency_type = dependency_type + self.value = None + + def set(self, value: Any) -> None: + """Sets the value of the dependency.""" + if not isinstance(value, self.dependency_type): + raise ValueError(f"Dependency {self.name} must be of type {self.dependency_type}") + self.value = value + + class Migration(BaseModel): """ Represents a migration for a SQLite database. + :param from_version: The database version on which this migration may be run + :param to_version: The database version that results from this migration + :param migrate: The callback to run to perform the migration + :param dependencies: A dict of dependencies that must be provided to the migration + Migration callbacks will be provided an open cursor to the database. They should not commit their transaction; this is handled by the migrator. - Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or - :meth:`register_post_callback`. + If a migration needs an additional dependency, it must be provided with :meth:`provide_dependency` + before the migration is run. - If a migration has additional dependencies, it is recommended to use functools.partial to provide - the dependencies and register the partial as the migration callback. + Example Usage: + ```py + # Define the migrate callback + def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None: + some_dependency = kwargs["some_dependency"] + ... + + # Instantiate the migration, declaring dependencies + migration = Migration( + from_version=0, + to_version=1, + migrate_callback=migrate_callback, + dependencies={"some_dependency": MigrationDependency(name="some_dependency", dependency_type=SomeType)}, + ) + + # Register the dependency before running the migration + migration.provide_dependency(name="some_dependency", value=some_value) + ``` """ from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run") to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration") - migrate: MigrateCallback = Field(description="The callback to run to perform the migration") - pre_migrate: list[MigrateCallback] = Field( - default=[], description="A list of callbacks to run before the migration" - ) - post_migrate: list[MigrateCallback] = Field( - default=[], description="A list of callbacks to run after the migration" + migrate_callback: MigrateCallback = Field(description="The callback to run to perform the migration") + dependencies: dict[str, MigrationDependency] = Field( + default={}, description="A dict of dependencies that must be provided to the migration" ) @model_validator(mode="after") @@ -48,13 +92,21 @@ class Migration(BaseModel): # Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set. return hash((self.from_version, self.to_version)) - def register_pre_callback(self, callback: MigrateCallback) -> None: - """Registers a pre-migration callback.""" - self.pre_migrate.append(callback) + def provide_dependency(self, name: str, value: Any) -> None: + """Provides a dependency for this migration.""" + if name not in self.dependencies: + raise ValueError(f"{name} of type {type(value)} is not a dependency of this migration") + self.dependencies[name].set(value) - def register_post_callback(self, callback: MigrateCallback) -> None: - """Registers a post-migration callback.""" - self.post_migrate.append(callback) + def run(self, cursor: sqlite3.Cursor) -> None: + """Runs the migration.""" + missing_dependencies = [d.name for d in self.dependencies.values() if d.value is None] + if missing_dependencies: + raise ValueError(f"Missing migration dependencies: {', '.join(missing_dependencies)}") + self.migrate_callback = partial(self.migrate_callback, **{d.name: d.value for d in self.dependencies.values()}) + self.migrate_callback(cursor=cursor) + + model_config = ConfigDict(arbitrary_types_allowed=True) class MigrationSet: @@ -78,7 +130,7 @@ class MigrationSet: def validate_migration_chain(self) -> None: """ - Validates that the migrations form a single chain of migrations from version 0 to the latest version. + Validates that the migrations form a single chain of migrations from version 0 to the latest version, Raises a MigrationError if there is a problem. """ if self.count == 0: diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py index 07cae1bb3a..ec32818e31 100644 --- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py +++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py @@ -66,21 +66,12 @@ class SQLiteMigrator: ) self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}") - # Run pre-migration callbacks - if migration.pre_migrate: - self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks") - for callback in migration.pre_migrate: - callback(cursor) - # Run the actual migration - migration.migrate(cursor) + migration.run(cursor) + + # Update the version cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,)) - # Run post-migration callbacks - if migration.post_migrate: - self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks") - for callback in migration.post_migrate: - callback(cursor) self._logger.debug( f"Successfully migrated database from {migration.from_version} to {migration.to_version}" ) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 609b0c3736..99a8382001 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -28,11 +28,8 @@ from invokeai.app.services.shared.graph import ( IterateInvocation, LibraryGraph, ) -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase -from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 -from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2 -from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction from .test_invoker import create_edge @@ -50,15 +47,10 @@ def simple_graph(): # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # the test invocations. @pytest.fixture -def mock_services() -> InvocationServices: +def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices: configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) logger = InvokeAILogger.get_logger() - db_path = None if configuration.use_memory_db else configuration.db_path - db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql) - migrator = SQLiteMigrator(db=db) - migrator.register_migration(migration_1) - migrator.register_migration(migration_2) - migrator.run_migrations() + db = create_sqlite_database(configuration, logger) # NOTE: none of these are actually called by the test invocations graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") return InvocationServices( diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 866287c461..6703c2768f 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -3,8 +3,8 @@ import logging import pytest from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction # This import must happen before other invoke imports or test in other files(!!) break from .test_nodes import ( # isort: split @@ -25,9 +25,6 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase -from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 -from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2 @pytest.fixture @@ -54,15 +51,10 @@ def graph_with_subgraph(): # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # the test invocations. @pytest.fixture -def mock_services() -> InvocationServices: +def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices: configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) logger = InvokeAILogger.get_logger() - db_path = None if configuration.use_memory_db else configuration.db_path - db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql) - migrator = SQLiteMigrator(db=db) - migrator.register_migration(migration_1) - migrator.register_migration(migration_2) - migrator.run_migrations() + db = create_sqlite_database(configuration, logger) # NOTE: none of these are actually called by the test invocations graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 2b245cce6d..8983b914b0 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -18,12 +18,9 @@ from invokeai.app.services.model_install import ( ModelInstallServiceBase, ) from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase -from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 -from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2 -from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator from invokeai.backend.model_manager.config import BaseModelType, ModelType from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction @pytest.fixture @@ -40,14 +37,11 @@ def app_config(datadir: Path) -> InvokeAIAppConfig: @pytest.fixture -def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: +def store( + app_config: InvokeAIAppConfig, create_sqlite_database: CreateSqliteDatabaseFunction +) -> ModelRecordServiceBase: logger = InvokeAILogger.get_logger(config=app_config) - db_path = None if app_config.use_memory_db else app_config.db_path - db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql) - migrator = SQLiteMigrator(db=db) - migrator.register_migration(migration_1) - migrator.register_migration(migration_2) - migrator.run_migrations() + db = create_sqlite_database(app_config, logger) store: ModelRecordServiceBase = ModelRecordServiceSQL(db) return store diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index e3589d6ec0..f1cd5d4d88 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -14,10 +14,6 @@ from invokeai.app.services.model_records import ( ModelRecordServiceSQL, UnknownModelException, ) -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase -from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 -from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2 -from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator from invokeai.backend.model_manager.config import ( BaseModelType, MainCheckpointConfig, @@ -27,18 +23,14 @@ from invokeai.backend.model_manager.config import ( VaeDiffusersConfig, ) from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction @pytest.fixture -def store(datadir: Any) -> ModelRecordServiceBase: +def store(datadir: Any, create_sqlite_database: CreateSqliteDatabaseFunction) -> ModelRecordServiceBase: config = InvokeAIAppConfig(root=datadir) logger = InvokeAILogger.get_logger(config=config) - db_path = None if config.use_memory_db else config.db_path - db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) - migrator = SQLiteMigrator(db=db) - migrator.register_migration(migration_1) - migrator.register_migration(migration_2) - migrator.run_migrations() + db = create_sqlite_database(config, logger) return ModelRecordServiceSQL(db) diff --git a/tests/conftest.py b/tests/conftest.py index 8618f5e102..d37904d80e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,3 +4,5 @@ # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401 + +pytest_plugins = ["tests.fixtures.sqlite_database"] diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fixtures/sqlite_database.py b/tests/fixtures/sqlite_database.py new file mode 100644 index 0000000000..8609f81768 --- /dev/null +++ b/tests/fixtures/sqlite_database.py @@ -0,0 +1,33 @@ +from logging import Logger +from typing import Callable +from unittest import mock + +import pytest + +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2 +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator + +CreateSqliteDatabaseFunction = Callable[[InvokeAIAppConfig, Logger], SqliteDatabase] + + +@pytest.fixture +def create_sqlite_database() -> CreateSqliteDatabaseFunction: + def _create_sqlite_database(config: InvokeAIAppConfig, logger: Logger) -> SqliteDatabase: + db_path = None if config.use_memory_db else config.db_path + db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) + + image_files = mock.Mock(spec=ImageFileStorageBase) + + migrator = SQLiteMigrator(db=db) + migration_2.provide_dependency("logger", logger) + migration_2.provide_dependency("image_files", image_files) + migrator.register_migration(migration_1) + migrator.register_migration(migration_2) + migrator.run_migrations() + return db + + return _create_sqlite_database diff --git a/tests/test_sqlite_migrator.py b/tests/test_sqlite_migrator.py index 3759859b4b..1e6d0548b6 100644 --- a/tests/test_sqlite_migrator.py +++ b/tests/test_sqlite_migrator.py @@ -31,45 +31,45 @@ def migrator(logger: Logger) -> SQLiteMigrator: return SQLiteMigrator(db=db) -@pytest.fixture -def migration_no_op() -> Migration: - return Migration(from_version=0, to_version=1, migrate=lambda cursor: None) - - -@pytest.fixture -def migration_create_test_table() -> Migration: - def migrate(cursor: sqlite3.Cursor) -> None: - cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") - - return Migration(from_version=0, to_version=1, migrate=migrate) - - -@pytest.fixture -def failing_migration() -> Migration: - def failing_migration(cursor: sqlite3.Cursor) -> None: - raise Exception("Bad migration") - - return Migration(from_version=0, to_version=1, migrate=failing_migration) - - @pytest.fixture def no_op_migrate_callback() -> MigrateCallback: - def no_op_migrate(cursor: sqlite3.Cursor) -> None: + def no_op_migrate(cursor: sqlite3.Cursor, **kwargs) -> None: pass return no_op_migrate +@pytest.fixture +def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration: + return Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback) + + +@pytest.fixture +def migration_create_test_table() -> Migration: + def migrate(cursor: sqlite3.Cursor, **kwargs) -> None: + cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);") + + return Migration(from_version=0, to_version=1, migrate_callback=migrate) + + +@pytest.fixture +def failing_migration() -> Migration: + def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None: + raise Exception("Bad migration") + + return Migration(from_version=0, to_version=1, migrate_callback=failing_migration) + + @pytest.fixture def failing_migrate_callback() -> MigrateCallback: - def failing_migrate(cursor: sqlite3.Cursor) -> None: + def failing_migrate(cursor: sqlite3.Cursor, **kwargs) -> None: raise Exception("Bad migration") return failing_migrate def create_migrate(i: int) -> MigrateCallback: - def migrate(cursor: sqlite3.Cursor) -> None: + def migrate(cursor: sqlite3.Cursor, **kwargs) -> None: cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);") return migrate @@ -77,30 +77,16 @@ def create_migrate(i: int) -> MigrateCallback: def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None: with pytest.raises(ValidationError, match="to_version must be one greater than from_version"): - Migration(from_version=0, to_version=2, migrate=no_op_migrate_callback) + Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback) # not raising is sufficient - Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback) + Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback) def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None: - migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback) + migration = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback) assert hash(migration) == hash((0, 1)) -def test_migration_registers_pre_and_post_callbacks(no_op_migrate_callback: MigrateCallback) -> None: - def pre_callback(cursor: sqlite3.Cursor) -> None: - pass - - def post_callback(cursor: sqlite3.Cursor) -> None: - pass - - migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback) - migration.register_pre_callback(pre_callback) - migration.register_post_callback(post_callback) - assert pre_callback in migration.pre_migrate - assert post_callback in migration.post_migrate - - def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None: migration = migration_no_op migrator._migration_set.register(migration) @@ -110,13 +96,13 @@ def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: def test_migration_set_may_not_register_dupes( migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback ) -> None: - migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback) - migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback) + migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback) + migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback) migrator._migration_set.register(migrate_0_to_1_a) with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"): migrator._migration_set.register(migrate_0_to_1_b) - migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback) - migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback) + migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback) + migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback) migrator._migration_set.register(migrate_1_to_2_a) with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"): migrator._migration_set.register(migrate_1_to_2_b) @@ -131,15 +117,15 @@ def test_migration_set_gets_migration(migration_no_op: Migration) -> None: def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None: migration_set = MigrationSet() - migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)) + migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)) with pytest.raises(MigrationError, match="Migration chain is fragmented"): # no migration from 0 to 1 migration_set.validate_migration_chain() - migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)) + migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)) migration_set.validate_migration_chain() - migration_set.register(Migration(from_version=2, to_version=3, migrate=no_op_migrate_callback)) + migration_set.register(Migration(from_version=2, to_version=3, migrate_callback=no_op_migrate_callback)) migration_set.validate_migration_chain() - migration_set.register(Migration(from_version=4, to_version=5, migrate=no_op_migrate_callback)) + migration_set.register(Migration(from_version=4, to_version=5, migrate_callback=no_op_migrate_callback)) with pytest.raises(MigrationError, match="Migration chain is fragmented"): # no migration from 3 to 4 migration_set.validate_migration_chain() @@ -148,18 +134,18 @@ def test_migration_set_validates_migration_chain(no_op_migrate_callback: Migrate def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None: migration_set = MigrationSet() assert migration_set.count == 0 - migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)) + migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)) assert migration_set.count == 1 - migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)) + migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)) assert migration_set.count == 2 def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None: migration_set = MigrationSet() assert migration_set.latest_version == 0 - migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)) + migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)) assert migration_set.latest_version == 2 - migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)) + migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)) assert migration_set.latest_version == 2 @@ -206,7 +192,7 @@ def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_crea def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None: cursor = migrator._db.conn.cursor() - migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)] + migrations = [Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)] for migration in migrations: migrator.register_migration(migration) migrator.run_migrations() @@ -219,7 +205,9 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None: # The Migrator closes the database when it finishes; we cannot use a context manager. db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False) migrator = SQLiteMigrator(db=db) - migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)] + migrations = [ + Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3) + ] for migration in migrations: migrator.register_migration(migration) migrator.run_migrations() @@ -235,7 +223,7 @@ def test_migrator_makes_no_changes_on_failed_migration( migrator.register_migration(migration_no_op) migrator.run_migrations() assert migrator._get_current_version(cursor) == 1 - migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback)) + migrator.register_migration(Migration(from_version=1, to_version=2, migrate_callback=failing_migrate_callback)) with pytest.raises(MigrationError, match="Bad migration"): migrator.run_migrations() assert migrator._get_current_version(cursor) == 1