feat(db): refactor migrate callbacks to use dependencies, remote pre/post callbacks

This commit is contained in:
psychedelicious 2023-12-12 12:35:42 +11:00
parent 6063760ce2
commit 0cf7fe43af
14 changed files with 230 additions and 181 deletions

View File

@ -1,11 +1,9 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from functools import partial
from logging import Logger from logging import Logger
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2_post import migrate_embedded_workflows
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__ from invokeai.version.invokeai_version import __version__
@ -77,8 +75,11 @@ class ApiDependencies:
db_path = None if config.use_memory_db else config.db_path db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
# This migration requires an ImageFileStorageBase service and logger
migration_2.provide_dependency("image_files", image_files)
migration_2.provide_dependency("logger", logger)
migrator = SQLiteMigrator(db=db) migrator = SQLiteMigrator(db=db)
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()

View File

@ -3,7 +3,7 @@ import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
def _migrate(cursor: sqlite3.Cursor) -> None: def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
"""Migration callback for database version 1.""" """Migration callback for database version 1."""
_create_board_images(cursor) _create_board_images(cursor)
@ -353,7 +353,7 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
migration_1 = Migration( migration_1 = Migration(
from_version=0, from_version=0,
to_version=1, to_version=1,
migrate=_migrate, migrate_callback=migrate_callback,
) )
""" """
Database version 1 (initial state). Database version 1 (initial state).

View File

@ -1,16 +1,24 @@
import sqlite3 import sqlite3
from logging import Logger
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationDependency
def _migrate(cursor: sqlite3.Cursor) -> None: def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
"""Migration callback for database version 2.""" """Migration callback for database version 2."""
logger = kwargs["logger"]
image_files = kwargs["image_files"]
_add_images_has_workflow(cursor) _add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor) _add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor) _drop_old_workflow_tables(cursor)
_add_workflow_library(cursor) _add_workflow_library(cursor)
_drop_model_manager_metadata(cursor) _drop_model_manager_metadata(cursor)
_migrate_embedded_workflows(cursor, logger, image_files)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None: def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
@ -89,19 +97,64 @@ def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;") cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,
image_files: ImageFileStorageBase,
) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This migrate callback checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
"""
# Get the total number of images and chunk it into pages
cursor.execute("SELECT image_name FROM images")
image_names: list[str] = [image[0] for image in cursor.fetchall()]
total_image_names = len(image_names)
if not total_image_names:
return
logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name)
if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase)
logger_dependency = MigrationDependency(name="logger", dependency_type=Logger)
migration_2 = Migration( migration_2 = Migration(
from_version=1, from_version=1,
to_version=2, to_version=2,
migrate=_migrate, migrate_callback=migrate_callback,
dependencies={"image_files": image_files_dependency, "logger": logger_dependency},
) )
""" """
Database version 2. Database version 2.
Introduced in v3.5.0 for the new workflow library. Introduced in v3.5.0 for the new workflow library.
Dependencies:
- image_files: ImageFileStorageBase
- logger: Logger
Migration: Migration:
- Add `has_workflow` column to `images` table - Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table - Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables - Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table - Add `workflow_library` table
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
""" """

View File

@ -1,41 +0,0 @@
import sqlite3
from logging import Logger
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
def migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,
image_files: ImageFileStorageBase,
) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
"""
# Get the total number of images and chunk it into pages
cursor.execute("SELECT image_name FROM images")
image_names: list[str] = [image[0] for image in cursor.fetchall()]
total_image_names = len(image_names)
if not total_image_names:
return
logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name)
if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)

View File

@ -1,9 +1,16 @@
import sqlite3 import sqlite3
from typing import Callable, Optional, TypeAlias from functools import partial
from typing import Any, Optional, Protocol, runtime_checkable
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
@runtime_checkable
class MigrateCallback(Protocol):
"""A callback that performs a migration."""
def __call__(self, cursor: sqlite3.Cursor, **kwargs: Any) -> None:
...
class MigrationError(RuntimeError): class MigrationError(RuntimeError):
@ -14,28 +21,65 @@ class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid.""" """Raised when a migration version is invalid."""
class MigrationDependency:
"""Represents a dependency for a migration."""
def __init__(
self,
name: str,
dependency_type: Any,
) -> None:
self.name = name
self.dependency_type = dependency_type
self.value = None
def set(self, value: Any) -> None:
"""Sets the value of the dependency."""
if not isinstance(value, self.dependency_type):
raise ValueError(f"Dependency {self.name} must be of type {self.dependency_type}")
self.value = value
class Migration(BaseModel): class Migration(BaseModel):
""" """
Represents a migration for a SQLite database. Represents a migration for a SQLite database.
:param from_version: The database version on which this migration may be run
:param to_version: The database version that results from this migration
:param migrate: The callback to run to perform the migration
:param dependencies: A dict of dependencies that must be provided to the migration
Migration callbacks will be provided an open cursor to the database. They should not commit their Migration callbacks will be provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator. transaction; this is handled by the migrator.
Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or If a migration needs an additional dependency, it must be provided with :meth:`provide_dependency`
:meth:`register_post_callback`. before the migration is run.
If a migration has additional dependencies, it is recommended to use functools.partial to provide Example Usage:
the dependencies and register the partial as the migration callback. ```py
# Define the migrate callback
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
some_dependency = kwargs["some_dependency"]
...
# Instantiate the migration, declaring dependencies
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
dependencies={"some_dependency": MigrationDependency(name="some_dependency", dependency_type=SomeType)},
)
# Register the dependency before running the migration
migration.provide_dependency(name="some_dependency", value=some_value)
```
""" """
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run") from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration") to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
migrate: MigrateCallback = Field(description="The callback to run to perform the migration") migrate_callback: MigrateCallback = Field(description="The callback to run to perform the migration")
pre_migrate: list[MigrateCallback] = Field( dependencies: dict[str, MigrationDependency] = Field(
default=[], description="A list of callbacks to run before the migration" default={}, description="A dict of dependencies that must be provided to the migration"
)
post_migrate: list[MigrateCallback] = Field(
default=[], description="A list of callbacks to run after the migration"
) )
@model_validator(mode="after") @model_validator(mode="after")
@ -48,13 +92,21 @@ class Migration(BaseModel):
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set. # Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version)) return hash((self.from_version, self.to_version))
def register_pre_callback(self, callback: MigrateCallback) -> None: def provide_dependency(self, name: str, value: Any) -> None:
"""Registers a pre-migration callback.""" """Provides a dependency for this migration."""
self.pre_migrate.append(callback) if name not in self.dependencies:
raise ValueError(f"{name} of type {type(value)} is not a dependency of this migration")
self.dependencies[name].set(value)
def register_post_callback(self, callback: MigrateCallback) -> None: def run(self, cursor: sqlite3.Cursor) -> None:
"""Registers a post-migration callback.""" """Runs the migration."""
self.post_migrate.append(callback) missing_dependencies = [d.name for d in self.dependencies.values() if d.value is None]
if missing_dependencies:
raise ValueError(f"Missing migration dependencies: {', '.join(missing_dependencies)}")
self.migrate_callback = partial(self.migrate_callback, **{d.name: d.value for d in self.dependencies.values()})
self.migrate_callback(cursor=cursor)
model_config = ConfigDict(arbitrary_types_allowed=True)
class MigrationSet: class MigrationSet:
@ -78,7 +130,7 @@ class MigrationSet:
def validate_migration_chain(self) -> None: def validate_migration_chain(self) -> None:
""" """
Validates that the migrations form a single chain of migrations from version 0 to the latest version. Validates that the migrations form a single chain of migrations from version 0 to the latest version,
Raises a MigrationError if there is a problem. Raises a MigrationError if there is a problem.
""" """
if self.count == 0: if self.count == 0:

View File

@ -66,21 +66,12 @@ class SQLiteMigrator:
) )
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}") self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
# Run pre-migration callbacks
if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate:
callback(cursor)
# Run the actual migration # Run the actual migration
migration.migrate(cursor) migration.run(cursor)
# Update the version
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,)) cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
# Run post-migration callbacks
if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate:
callback(cursor)
self._logger.debug( self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}" f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
) )

View File

@ -28,11 +28,8 @@ from invokeai.app.services.shared.graph import (
IterateInvocation, IterateInvocation,
LibraryGraph, LibraryGraph,
) )
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from .test_invoker import create_edge from .test_invoker import create_edge
@ -50,15 +47,10 @@ def simple_graph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations. # the test invocations.
@pytest.fixture @pytest.fixture
def mock_services() -> InvocationServices: def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db_path = None if configuration.use_memory_db else configuration.db_path db = create_sqlite_database(configuration, logger)
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices( return InvocationServices(

View File

@ -3,8 +3,8 @@ import logging
import pytest import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
# This import must happen before other invoke imports or test in other files(!!) break # This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split from .test_nodes import ( # isort: split
@ -25,9 +25,6 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
@pytest.fixture @pytest.fixture
@ -54,15 +51,10 @@ def graph_with_subgraph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations. # the test invocations.
@pytest.fixture @pytest.fixture
def mock_services() -> InvocationServices: def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db_path = None if configuration.use_memory_db else configuration.db_path db = create_sqlite_database(configuration, logger)
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")

View File

@ -18,12 +18,9 @@ from invokeai.app.services.model_install import (
ModelInstallServiceBase, ModelInstallServiceBase,
) )
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.model_manager.config import BaseModelType, ModelType from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
@pytest.fixture @pytest.fixture
@ -40,14 +37,11 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
@pytest.fixture @pytest.fixture
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: def store(
app_config: InvokeAIAppConfig, create_sqlite_database: CreateSqliteDatabaseFunction
) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config) logger = InvokeAILogger.get_logger(config=app_config)
db_path = None if app_config.use_memory_db else app_config.db_path db = create_sqlite_database(app_config, logger)
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
store: ModelRecordServiceBase = ModelRecordServiceSQL(db) store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store return store

View File

@ -14,10 +14,6 @@ from invokeai.app.services.model_records import (
ModelRecordServiceSQL, ModelRecordServiceSQL,
UnknownModelException, UnknownModelException,
) )
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
MainCheckpointConfig, MainCheckpointConfig,
@ -27,18 +23,14 @@ from invokeai.backend.model_manager.config import (
VaeDiffusersConfig, VaeDiffusersConfig,
) )
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
@pytest.fixture @pytest.fixture
def store(datadir: Any) -> ModelRecordServiceBase: def store(datadir: Any, create_sqlite_database: CreateSqliteDatabaseFunction) -> ModelRecordServiceBase:
config = InvokeAIAppConfig(root=datadir) config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db_path = None if config.use_memory_db else config.db_path db = create_sqlite_database(config, logger)
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
return ModelRecordServiceSQL(db) return ModelRecordServiceSQL(db)

View File

@ -4,3 +4,5 @@
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401 from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401
pytest_plugins = ["tests.fixtures.sqlite_database"]

0
tests/fixtures/__init__.py vendored Normal file
View File

33
tests/fixtures/sqlite_database.py vendored Normal file
View File

@ -0,0 +1,33 @@
from logging import Logger
from typing import Callable
from unittest import mock
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
CreateSqliteDatabaseFunction = Callable[[InvokeAIAppConfig, Logger], SqliteDatabase]
@pytest.fixture
def create_sqlite_database() -> CreateSqliteDatabaseFunction:
def _create_sqlite_database(config: InvokeAIAppConfig, logger: Logger) -> SqliteDatabase:
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
image_files = mock.Mock(spec=ImageFileStorageBase)
migrator = SQLiteMigrator(db=db)
migration_2.provide_dependency("logger", logger)
migration_2.provide_dependency("image_files", image_files)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
return db
return _create_sqlite_database

View File

@ -31,45 +31,45 @@ def migrator(logger: Logger) -> SQLiteMigrator:
return SQLiteMigrator(db=db) return SQLiteMigrator(db=db)
@pytest.fixture
def migration_no_op() -> Migration:
return Migration(from_version=0, to_version=1, migrate=lambda cursor: None)
@pytest.fixture
def migration_create_test_table() -> Migration:
def migrate(cursor: sqlite3.Cursor) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
return Migration(from_version=0, to_version=1, migrate=migrate)
@pytest.fixture
def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor) -> None:
raise Exception("Bad migration")
return Migration(from_version=0, to_version=1, migrate=failing_migration)
@pytest.fixture @pytest.fixture
def no_op_migrate_callback() -> MigrateCallback: def no_op_migrate_callback() -> MigrateCallback:
def no_op_migrate(cursor: sqlite3.Cursor) -> None: def no_op_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
pass pass
return no_op_migrate return no_op_migrate
@pytest.fixture
def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
@pytest.fixture
def migration_create_test_table() -> Migration:
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
return Migration(from_version=0, to_version=1, migrate_callback=migrate)
@pytest.fixture
def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None:
raise Exception("Bad migration")
return Migration(from_version=0, to_version=1, migrate_callback=failing_migration)
@pytest.fixture @pytest.fixture
def failing_migrate_callback() -> MigrateCallback: def failing_migrate_callback() -> MigrateCallback:
def failing_migrate(cursor: sqlite3.Cursor) -> None: def failing_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
raise Exception("Bad migration") raise Exception("Bad migration")
return failing_migrate return failing_migrate
def create_migrate(i: int) -> MigrateCallback: def create_migrate(i: int) -> MigrateCallback:
def migrate(cursor: sqlite3.Cursor) -> None: def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);") cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
return migrate return migrate
@ -77,30 +77,16 @@ def create_migrate(i: int) -> MigrateCallback:
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"): with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
Migration(from_version=0, to_version=2, migrate=no_op_migrate_callback) Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
# not raising is sufficient # not raising is sufficient
Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback) Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None:
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback) migration = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
assert hash(migration) == hash((0, 1)) assert hash(migration) == hash((0, 1))
def test_migration_registers_pre_and_post_callbacks(no_op_migrate_callback: MigrateCallback) -> None:
def pre_callback(cursor: sqlite3.Cursor) -> None:
pass
def post_callback(cursor: sqlite3.Cursor) -> None:
pass
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
migration.register_pre_callback(pre_callback)
migration.register_post_callback(post_callback)
assert pre_callback in migration.pre_migrate
assert post_callback in migration.post_migrate
def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None: def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
migration = migration_no_op migration = migration_no_op
migrator._migration_set.register(migration) migrator._migration_set.register(migration)
@ -110,13 +96,13 @@ def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op:
def test_migration_set_may_not_register_dupes( def test_migration_set_may_not_register_dupes(
migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback
) -> None: ) -> None:
migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback) migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback) migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_0_to_1_a) migrator._migration_set.register(migrate_0_to_1_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"): with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_0_to_1_b) migrator._migration_set.register(migrate_0_to_1_b)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback) migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback) migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_1_to_2_a) migrator._migration_set.register(migrate_1_to_2_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"): with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_1_to_2_b) migrator._migration_set.register(migrate_1_to_2_b)
@ -131,15 +117,15 @@ def test_migration_set_gets_migration(migration_no_op: Migration) -> None:
def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet() migration_set = MigrationSet()
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)) migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"): with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 0 to 1 # no migration from 0 to 1
migration_set.validate_migration_chain() migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)) migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
migration_set.validate_migration_chain() migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=2, to_version=3, migrate=no_op_migrate_callback)) migration_set.register(Migration(from_version=2, to_version=3, migrate_callback=no_op_migrate_callback))
migration_set.validate_migration_chain() migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=4, to_version=5, migrate=no_op_migrate_callback)) migration_set.register(Migration(from_version=4, to_version=5, migrate_callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"): with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 3 to 4 # no migration from 3 to 4
migration_set.validate_migration_chain() migration_set.validate_migration_chain()
@ -148,18 +134,18 @@ def test_migration_set_validates_migration_chain(no_op_migrate_callback: Migrate
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet() migration_set = MigrationSet()
assert migration_set.count == 0 assert migration_set.count == 0
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)) migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
assert migration_set.count == 1 assert migration_set.count == 1
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)) migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
assert migration_set.count == 2 assert migration_set.count == 2
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet() migration_set = MigrationSet()
assert migration_set.latest_version == 0 assert migration_set.latest_version == 0
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)) migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
assert migration_set.latest_version == 2 assert migration_set.latest_version == 2
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)) migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
assert migration_set.latest_version == 2 assert migration_set.latest_version == 2
@ -206,7 +192,7 @@ def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_crea
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None: def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
cursor = migrator._db.conn.cursor() cursor = migrator._db.conn.cursor()
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)] migrations = [Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations: for migration in migrations:
migrator.register_migration(migration) migrator.register_migration(migration)
migrator.run_migrations() migrator.run_migrations()
@ -219,7 +205,9 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
# The Migrator closes the database when it finishes; we cannot use a context manager. # The Migrator closes the database when it finishes; we cannot use a context manager.
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False) db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
migrator = SQLiteMigrator(db=db) migrator = SQLiteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)] migrations = [
Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)
]
for migration in migrations: for migration in migrations:
migrator.register_migration(migration) migrator.register_migration(migration)
migrator.run_migrations() migrator.run_migrations()
@ -235,7 +223,7 @@ def test_migrator_makes_no_changes_on_failed_migration(
migrator.register_migration(migration_no_op) migrator.register_migration(migration_no_op)
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1 assert migrator._get_current_version(cursor) == 1
migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback)) migrator.register_migration(Migration(from_version=1, to_version=2, migrate_callback=failing_migrate_callback))
with pytest.raises(MigrationError, match="Bad migration"): with pytest.raises(MigrationError, match="Bad migration"):
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1 assert migrator._get_current_version(cursor) == 1