mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): refactor migrate callbacks to use dependencies, remote pre/post callbacks
This commit is contained in:
parent
6063760ce2
commit
0cf7fe43af
@ -1,11 +1,9 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from functools import partial
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2_post import migrate_embedded_workflows
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
@ -77,8 +75,11 @@ class ApiDependencies:
|
||||
db_path = None if config.use_memory_db else config.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||
|
||||
# This migration requires an ImageFileStorageBase service and logger
|
||||
migration_2.provide_dependency("image_files", image_files)
|
||||
migration_2.provide_dependency("logger", logger)
|
||||
|
||||
migrator = SQLiteMigrator(db=db)
|
||||
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
|
@ -3,7 +3,7 @@ import sqlite3
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
"""Migration callback for database version 1."""
|
||||
|
||||
_create_board_images(cursor)
|
||||
@ -353,7 +353,7 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
|
||||
migration_1 = Migration(
|
||||
from_version=0,
|
||||
to_version=1,
|
||||
migrate=_migrate,
|
||||
migrate_callback=migrate_callback,
|
||||
)
|
||||
"""
|
||||
Database version 1 (initial state).
|
||||
|
@ -1,16 +1,24 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationDependency
|
||||
|
||||
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
"""Migration callback for database version 2."""
|
||||
|
||||
logger = kwargs["logger"]
|
||||
image_files = kwargs["image_files"]
|
||||
|
||||
_add_images_has_workflow(cursor)
|
||||
_add_session_queue_workflow(cursor)
|
||||
_drop_old_workflow_tables(cursor)
|
||||
_add_workflow_library(cursor)
|
||||
_drop_model_manager_metadata(cursor)
|
||||
_migrate_embedded_workflows(cursor, logger, image_files)
|
||||
|
||||
|
||||
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
|
||||
@ -89,19 +97,64 @@ def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||
|
||||
|
||||
def _migrate_embedded_workflows(
|
||||
cursor: sqlite3.Cursor,
|
||||
logger: Logger,
|
||||
image_files: ImageFileStorageBase,
|
||||
) -> None:
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
|
||||
|
||||
This migrate callback checks each image for the presence of an embedded workflow, then updates its entry
|
||||
in the database accordingly.
|
||||
"""
|
||||
# Get the total number of images and chunk it into pages
|
||||
cursor.execute("SELECT image_name FROM images")
|
||||
image_names: list[str] = [image[0] for image in cursor.fetchall()]
|
||||
total_image_names = len(image_names)
|
||||
|
||||
if not total_image_names:
|
||||
return
|
||||
|
||||
logger.info(f"Migrating workflows for {total_image_names} images")
|
||||
|
||||
# Migrate the images
|
||||
to_migrate: list[tuple[bool, str]] = []
|
||||
pbar = tqdm(image_names)
|
||||
for idx, image_name in enumerate(pbar):
|
||||
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
|
||||
pil_image = image_files.get(image_name)
|
||||
if "invokeai_workflow" in pil_image.info:
|
||||
to_migrate.append((True, image_name))
|
||||
|
||||
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
|
||||
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
|
||||
|
||||
|
||||
image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase)
|
||||
logger_dependency = MigrationDependency(name="logger", dependency_type=Logger)
|
||||
|
||||
|
||||
migration_2 = Migration(
|
||||
from_version=1,
|
||||
to_version=2,
|
||||
migrate=_migrate,
|
||||
migrate_callback=migrate_callback,
|
||||
dependencies={"image_files": image_files_dependency, "logger": logger_dependency},
|
||||
)
|
||||
"""
|
||||
Database version 2.
|
||||
|
||||
Introduced in v3.5.0 for the new workflow library.
|
||||
|
||||
Dependencies:
|
||||
- image_files: ImageFileStorageBase
|
||||
- logger: Logger
|
||||
|
||||
Migration:
|
||||
- Add `has_workflow` column to `images` table
|
||||
- Add `workflow` column to `session_queue` table
|
||||
- Drop `workflows` and `workflow_images` tables
|
||||
- Add `workflow_library` table
|
||||
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
|
||||
"""
|
||||
|
@ -1,41 +0,0 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
|
||||
|
||||
def migrate_embedded_workflows(
|
||||
cursor: sqlite3.Cursor,
|
||||
logger: Logger,
|
||||
image_files: ImageFileStorageBase,
|
||||
) -> None:
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
|
||||
|
||||
This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry
|
||||
in the database accordingly.
|
||||
"""
|
||||
# Get the total number of images and chunk it into pages
|
||||
cursor.execute("SELECT image_name FROM images")
|
||||
image_names: list[str] = [image[0] for image in cursor.fetchall()]
|
||||
total_image_names = len(image_names)
|
||||
|
||||
if not total_image_names:
|
||||
return
|
||||
|
||||
logger.info(f"Migrating workflows for {total_image_names} images")
|
||||
|
||||
# Migrate the images
|
||||
to_migrate: list[tuple[bool, str]] = []
|
||||
pbar = tqdm(image_names)
|
||||
for idx, image_name in enumerate(pbar):
|
||||
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
|
||||
pil_image = image_files.get(image_name)
|
||||
if "invokeai_workflow" in pil_image.info:
|
||||
to_migrate.append((True, image_name))
|
||||
|
||||
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
|
||||
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
|
@ -1,9 +1,16 @@
|
||||
import sqlite3
|
||||
from typing import Callable, Optional, TypeAlias
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
||||
|
||||
@runtime_checkable
|
||||
class MigrateCallback(Protocol):
|
||||
"""A callback that performs a migration."""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor, **kwargs: Any) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
@ -14,28 +21,65 @@ class MigrationVersionError(ValueError):
|
||||
"""Raised when a migration version is invalid."""
|
||||
|
||||
|
||||
class MigrationDependency:
|
||||
"""Represents a dependency for a migration."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
dependency_type: Any,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.dependency_type = dependency_type
|
||||
self.value = None
|
||||
|
||||
def set(self, value: Any) -> None:
|
||||
"""Sets the value of the dependency."""
|
||||
if not isinstance(value, self.dependency_type):
|
||||
raise ValueError(f"Dependency {self.name} must be of type {self.dependency_type}")
|
||||
self.value = value
|
||||
|
||||
|
||||
class Migration(BaseModel):
|
||||
"""
|
||||
Represents a migration for a SQLite database.
|
||||
|
||||
:param from_version: The database version on which this migration may be run
|
||||
:param to_version: The database version that results from this migration
|
||||
:param migrate: The callback to run to perform the migration
|
||||
:param dependencies: A dict of dependencies that must be provided to the migration
|
||||
|
||||
Migration callbacks will be provided an open cursor to the database. They should not commit their
|
||||
transaction; this is handled by the migrator.
|
||||
|
||||
Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or
|
||||
:meth:`register_post_callback`.
|
||||
If a migration needs an additional dependency, it must be provided with :meth:`provide_dependency`
|
||||
before the migration is run.
|
||||
|
||||
If a migration has additional dependencies, it is recommended to use functools.partial to provide
|
||||
the dependencies and register the partial as the migration callback.
|
||||
Example Usage:
|
||||
```py
|
||||
# Define the migrate callback
|
||||
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
some_dependency = kwargs["some_dependency"]
|
||||
...
|
||||
|
||||
# Instantiate the migration, declaring dependencies
|
||||
migration = Migration(
|
||||
from_version=0,
|
||||
to_version=1,
|
||||
migrate_callback=migrate_callback,
|
||||
dependencies={"some_dependency": MigrationDependency(name="some_dependency", dependency_type=SomeType)},
|
||||
)
|
||||
|
||||
# Register the dependency before running the migration
|
||||
migration.provide_dependency(name="some_dependency", value=some_value)
|
||||
```
|
||||
"""
|
||||
|
||||
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
||||
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
|
||||
migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
|
||||
pre_migrate: list[MigrateCallback] = Field(
|
||||
default=[], description="A list of callbacks to run before the migration"
|
||||
)
|
||||
post_migrate: list[MigrateCallback] = Field(
|
||||
default=[], description="A list of callbacks to run after the migration"
|
||||
migrate_callback: MigrateCallback = Field(description="The callback to run to perform the migration")
|
||||
dependencies: dict[str, MigrationDependency] = Field(
|
||||
default={}, description="A dict of dependencies that must be provided to the migration"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
@ -48,13 +92,21 @@ class Migration(BaseModel):
|
||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||
return hash((self.from_version, self.to_version))
|
||||
|
||||
def register_pre_callback(self, callback: MigrateCallback) -> None:
|
||||
"""Registers a pre-migration callback."""
|
||||
self.pre_migrate.append(callback)
|
||||
def provide_dependency(self, name: str, value: Any) -> None:
|
||||
"""Provides a dependency for this migration."""
|
||||
if name not in self.dependencies:
|
||||
raise ValueError(f"{name} of type {type(value)} is not a dependency of this migration")
|
||||
self.dependencies[name].set(value)
|
||||
|
||||
def register_post_callback(self, callback: MigrateCallback) -> None:
|
||||
"""Registers a post-migration callback."""
|
||||
self.post_migrate.append(callback)
|
||||
def run(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Runs the migration."""
|
||||
missing_dependencies = [d.name for d in self.dependencies.values() if d.value is None]
|
||||
if missing_dependencies:
|
||||
raise ValueError(f"Missing migration dependencies: {', '.join(missing_dependencies)}")
|
||||
self.migrate_callback = partial(self.migrate_callback, **{d.name: d.value for d in self.dependencies.values()})
|
||||
self.migrate_callback(cursor=cursor)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class MigrationSet:
|
||||
@ -78,7 +130,7 @@ class MigrationSet:
|
||||
|
||||
def validate_migration_chain(self) -> None:
|
||||
"""
|
||||
Validates that the migrations form a single chain of migrations from version 0 to the latest version.
|
||||
Validates that the migrations form a single chain of migrations from version 0 to the latest version,
|
||||
Raises a MigrationError if there is a problem.
|
||||
"""
|
||||
if self.count == 0:
|
||||
|
@ -66,21 +66,12 @@ class SQLiteMigrator:
|
||||
)
|
||||
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
||||
|
||||
# Run pre-migration callbacks
|
||||
if migration.pre_migrate:
|
||||
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
||||
for callback in migration.pre_migrate:
|
||||
callback(cursor)
|
||||
|
||||
# Run the actual migration
|
||||
migration.migrate(cursor)
|
||||
migration.run(cursor)
|
||||
|
||||
# Update the version
|
||||
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||
|
||||
# Run post-migration callbacks
|
||||
if migration.post_migrate:
|
||||
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
||||
for callback in migration.post_migrate:
|
||||
callback(cursor)
|
||||
self._logger.debug(
|
||||
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
||||
)
|
||||
|
@ -28,11 +28,8 @@ from invokeai.app.services.shared.graph import (
|
||||
IterateInvocation,
|
||||
LibraryGraph,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
|
||||
|
||||
from .test_invoker import create_edge
|
||||
|
||||
@ -50,15 +47,10 @@ def simple_graph():
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db_path = None if configuration.use_memory_db else configuration.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
|
||||
migrator = SQLiteMigrator(db=db)
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
db = create_sqlite_database(configuration, logger)
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
return InvocationServices(
|
||||
|
@ -3,8 +3,8 @@ import logging
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
|
||||
|
||||
# This import must happen before other invoke imports or test in other files(!!) break
|
||||
from .test_nodes import ( # isort: split
|
||||
@ -25,9 +25,6 @@ from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -54,15 +51,10 @@ def graph_with_subgraph():
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db_path = None if configuration.use_memory_db else configuration.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=configuration.log_sql)
|
||||
migrator = SQLiteMigrator(db=db)
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
db = create_sqlite_database(configuration, logger)
|
||||
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
|
@ -18,12 +18,9 @@ from invokeai.app.services.model_install import (
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -40,14 +37,11 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
def store(
|
||||
app_config: InvokeAIAppConfig, create_sqlite_database: CreateSqliteDatabaseFunction
|
||||
) -> ModelRecordServiceBase:
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
db_path = None if app_config.use_memory_db else app_config.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql)
|
||||
migrator = SQLiteMigrator(db=db)
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
db = create_sqlite_database(app_config, logger)
|
||||
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
return store
|
||||
|
||||
|
@ -14,10 +14,6 @@ from invokeai.app.services.model_records import (
|
||||
ModelRecordServiceSQL,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
@ -27,18 +23,14 @@ from invokeai.backend.model_manager.config import (
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(datadir: Any) -> ModelRecordServiceBase:
|
||||
def store(datadir: Any, create_sqlite_database: CreateSqliteDatabaseFunction) -> ModelRecordServiceBase:
|
||||
config = InvokeAIAppConfig(root=datadir)
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db_path = None if config.use_memory_db else config.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||
migrator = SQLiteMigrator(db=db)
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
db = create_sqlite_database(config, logger)
|
||||
return ModelRecordServiceSQL(db)
|
||||
|
||||
|
||||
|
@ -4,3 +4,5 @@
|
||||
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
|
||||
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
|
||||
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401
|
||||
|
||||
pytest_plugins = ["tests.fixtures.sqlite_database"]
|
||||
|
0
tests/fixtures/__init__.py
vendored
Normal file
0
tests/fixtures/__init__.py
vendored
Normal file
33
tests/fixtures/sqlite_database.py
vendored
Normal file
33
tests/fixtures/sqlite_database.py
vendored
Normal file
@ -0,0 +1,33 @@
|
||||
from logging import Logger
|
||||
from typing import Callable
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
|
||||
|
||||
CreateSqliteDatabaseFunction = Callable[[InvokeAIAppConfig, Logger], SqliteDatabase]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_sqlite_database() -> CreateSqliteDatabaseFunction:
|
||||
def _create_sqlite_database(config: InvokeAIAppConfig, logger: Logger) -> SqliteDatabase:
|
||||
db_path = None if config.use_memory_db else config.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||
|
||||
image_files = mock.Mock(spec=ImageFileStorageBase)
|
||||
|
||||
migrator = SQLiteMigrator(db=db)
|
||||
migration_2.provide_dependency("logger", logger)
|
||||
migration_2.provide_dependency("image_files", image_files)
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
return db
|
||||
|
||||
return _create_sqlite_database
|
@ -31,45 +31,45 @@ def migrator(logger: Logger) -> SQLiteMigrator:
|
||||
return SQLiteMigrator(db=db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_no_op() -> Migration:
|
||||
return Migration(from_version=0, to_version=1, migrate=lambda cursor: None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_create_test_table() -> Migration:
|
||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
|
||||
return Migration(from_version=0, to_version=1, migrate=migrate)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def failing_migration() -> Migration:
|
||||
def failing_migration(cursor: sqlite3.Cursor) -> None:
|
||||
raise Exception("Bad migration")
|
||||
|
||||
return Migration(from_version=0, to_version=1, migrate=failing_migration)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_op_migrate_callback() -> MigrateCallback:
|
||||
def no_op_migrate(cursor: sqlite3.Cursor) -> None:
|
||||
def no_op_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
return no_op_migrate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
|
||||
return Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_create_test_table() -> Migration:
|
||||
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
|
||||
return Migration(from_version=0, to_version=1, migrate_callback=migrate)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def failing_migration() -> Migration:
|
||||
def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
raise Exception("Bad migration")
|
||||
|
||||
return Migration(from_version=0, to_version=1, migrate_callback=failing_migration)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def failing_migrate_callback() -> MigrateCallback:
|
||||
def failing_migrate(cursor: sqlite3.Cursor) -> None:
|
||||
def failing_migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
raise Exception("Bad migration")
|
||||
|
||||
return failing_migrate
|
||||
|
||||
|
||||
def create_migrate(i: int) -> MigrateCallback:
|
||||
def migrate(cursor: sqlite3.Cursor) -> None:
|
||||
def migrate(cursor: sqlite3.Cursor, **kwargs) -> None:
|
||||
cursor.execute(f"CREATE TABLE test{i} (id INTEGER PRIMARY KEY);")
|
||||
|
||||
return migrate
|
||||
@ -77,30 +77,16 @@ def create_migrate(i: int) -> MigrateCallback:
|
||||
|
||||
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
|
||||
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
|
||||
Migration(from_version=0, to_version=2, migrate=no_op_migrate_callback)
|
||||
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
|
||||
# not raising is sufficient
|
||||
Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
|
||||
Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
|
||||
|
||||
|
||||
def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None:
|
||||
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
|
||||
migration = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
|
||||
assert hash(migration) == hash((0, 1))
|
||||
|
||||
|
||||
def test_migration_registers_pre_and_post_callbacks(no_op_migrate_callback: MigrateCallback) -> None:
|
||||
def pre_callback(cursor: sqlite3.Cursor) -> None:
|
||||
pass
|
||||
|
||||
def post_callback(cursor: sqlite3.Cursor) -> None:
|
||||
pass
|
||||
|
||||
migration = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
|
||||
migration.register_pre_callback(pre_callback)
|
||||
migration.register_post_callback(post_callback)
|
||||
assert pre_callback in migration.pre_migrate
|
||||
assert post_callback in migration.post_migrate
|
||||
|
||||
|
||||
def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
|
||||
migration = migration_no_op
|
||||
migrator._migration_set.register(migration)
|
||||
@ -110,13 +96,13 @@ def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op:
|
||||
def test_migration_set_may_not_register_dupes(
|
||||
migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback
|
||||
) -> None:
|
||||
migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
|
||||
migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback)
|
||||
migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
|
||||
migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
|
||||
migrator._migration_set.register(migrate_0_to_1_a)
|
||||
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
|
||||
migrator._migration_set.register(migrate_0_to_1_b)
|
||||
migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
|
||||
migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback)
|
||||
migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
|
||||
migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
|
||||
migrator._migration_set.register(migrate_1_to_2_a)
|
||||
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
|
||||
migrator._migration_set.register(migrate_1_to_2_b)
|
||||
@ -131,15 +117,15 @@ def test_migration_set_gets_migration(migration_no_op: Migration) -> None:
|
||||
|
||||
def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None:
|
||||
migration_set = MigrationSet()
|
||||
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
|
||||
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
|
||||
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
|
||||
# no migration from 0 to 1
|
||||
migration_set.validate_migration_chain()
|
||||
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
|
||||
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
|
||||
migration_set.validate_migration_chain()
|
||||
migration_set.register(Migration(from_version=2, to_version=3, migrate=no_op_migrate_callback))
|
||||
migration_set.register(Migration(from_version=2, to_version=3, migrate_callback=no_op_migrate_callback))
|
||||
migration_set.validate_migration_chain()
|
||||
migration_set.register(Migration(from_version=4, to_version=5, migrate=no_op_migrate_callback))
|
||||
migration_set.register(Migration(from_version=4, to_version=5, migrate_callback=no_op_migrate_callback))
|
||||
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
|
||||
# no migration from 3 to 4
|
||||
migration_set.validate_migration_chain()
|
||||
@ -148,18 +134,18 @@ def test_migration_set_validates_migration_chain(no_op_migrate_callback: Migrate
|
||||
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None:
|
||||
migration_set = MigrationSet()
|
||||
assert migration_set.count == 0
|
||||
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
|
||||
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
|
||||
assert migration_set.count == 1
|
||||
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
|
||||
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
|
||||
assert migration_set.count == 2
|
||||
|
||||
|
||||
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None:
|
||||
migration_set = MigrationSet()
|
||||
assert migration_set.latest_version == 0
|
||||
migration_set.register(Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback))
|
||||
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
|
||||
assert migration_set.latest_version == 2
|
||||
migration_set.register(Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback))
|
||||
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
|
||||
assert migration_set.latest_version == 2
|
||||
|
||||
|
||||
@ -206,7 +192,7 @@ def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_crea
|
||||
|
||||
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
|
||||
cursor = migrator._db.conn.cursor()
|
||||
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
|
||||
migrations = [Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)]
|
||||
for migration in migrations:
|
||||
migrator.register_migration(migration)
|
||||
migrator.run_migrations()
|
||||
@ -219,7 +205,9 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
|
||||
# The Migrator closes the database when it finishes; we cannot use a context manager.
|
||||
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
|
||||
migrator = SQLiteMigrator(db=db)
|
||||
migrations = [Migration(from_version=i, to_version=i + 1, migrate=create_migrate(i)) for i in range(0, 3)]
|
||||
migrations = [
|
||||
Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)
|
||||
]
|
||||
for migration in migrations:
|
||||
migrator.register_migration(migration)
|
||||
migrator.run_migrations()
|
||||
@ -235,7 +223,7 @@ def test_migrator_makes_no_changes_on_failed_migration(
|
||||
migrator.register_migration(migration_no_op)
|
||||
migrator.run_migrations()
|
||||
assert migrator._get_current_version(cursor) == 1
|
||||
migrator.register_migration(Migration(from_version=1, to_version=2, migrate=failing_migrate_callback))
|
||||
migrator.register_migration(Migration(from_version=1, to_version=2, migrate_callback=failing_migrate_callback))
|
||||
with pytest.raises(MigrationError, match="Bad migration"):
|
||||
migrator.run_migrations()
|
||||
assert migrator._get_current_version(cursor) == 1
|
||||
|
Loading…
Reference in New Issue
Block a user