feat(db): refactor migrate callbacks to use dependencies, remote pre/post callbacks

This commit is contained in:
psychedelicious
2023-12-12 12:35:42 +11:00
parent 6063760ce2
commit 0cf7fe43af
14 changed files with 230 additions and 181 deletions

View File

@ -1,11 +1,9 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from functools import partial
from logging import Logger
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2_post import migrate_embedded_workflows
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -77,8 +75,11 @@ class ApiDependencies:
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
# This migration requires an ImageFileStorageBase service and logger
migration_2.provide_dependency("image_files", image_files)
migration_2.provide_dependency("logger", logger)
migrator = SQLiteMigrator(db=db)
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()

View File

@ -3,7 +3,7 @@ import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
"""Migration callback for database version 1."""
_create_board_images(cursor)
@ -353,7 +353,7 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
migration_1 = Migration(
from_version=0,
to_version=1,
migrate=_migrate,
migrate_callback=migrate_callback,
)
"""
Database version 1 (initial state).

View File

@ -1,16 +1,24 @@
import sqlite3
from logging import Logger
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationDependency
def _migrate(cursor: sqlite3.Cursor) -> None:
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
"""Migration callback for database version 2."""
logger = kwargs["logger"]
image_files = kwargs["image_files"]
_add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor)
_add_workflow_library(cursor)
_drop_model_manager_metadata(cursor)
_migrate_embedded_workflows(cursor, logger, image_files)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
@ -89,19 +97,64 @@ def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,
image_files: ImageFileStorageBase,
) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This migrate callback checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
"""
# Get the total number of images and chunk it into pages
cursor.execute("SELECT image_name FROM images")
image_names: list[str] = [image[0] for image in cursor.fetchall()]
total_image_names = len(image_names)
if not total_image_names:
return
logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name)
if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase)
logger_dependency = MigrationDependency(name="logger", dependency_type=Logger)
migration_2 = Migration(
from_version=1,
to_version=2,
migrate=_migrate,
migrate_callback=migrate_callback,
dependencies={"image_files": image_files_dependency, "logger": logger_dependency},
)
"""
Database version 2.
Introduced in v3.5.0 for the new workflow library.
Dependencies:
- image_files: ImageFileStorageBase
- logger: Logger
Migration:
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
"""

View File

@ -1,41 +0,0 @@
import sqlite3
from logging import Logger
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
def migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,
image_files: ImageFileStorageBase,
) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
"""
# Get the total number of images and chunk it into pages
cursor.execute("SELECT image_name FROM images")
image_names: list[str] = [image[0] for image in cursor.fetchall()]
total_image_names = len(image_names)
if not total_image_names:
return
logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name)
if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)

View File

@ -1,9 +1,16 @@
import sqlite3
from typing import Callable, Optional, TypeAlias
from functools import partial
from typing import Any, Optional, Protocol, runtime_checkable
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
@runtime_checkable
class MigrateCallback(Protocol):
"""A callback that performs a migration."""
def __call__(self, cursor: sqlite3.Cursor, **kwargs: Any) -> None:
...
class MigrationError(RuntimeError):
@ -14,28 +21,65 @@ class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid."""
class MigrationDependency:
"""Represents a dependency for a migration."""
def __init__(
self,
name: str,
dependency_type: Any,
) -> None:
self.name = name
self.dependency_type = dependency_type
self.value = None
def set(self, value: Any) -> None:
"""Sets the value of the dependency."""
if not isinstance(value, self.dependency_type):
raise ValueError(f"Dependency {self.name} must be of type {self.dependency_type}")
self.value = value
class Migration(BaseModel):
"""
Represents a migration for a SQLite database.
:param from_version: The database version on which this migration may be run
:param to_version: The database version that results from this migration
:param migrate: The callback to run to perform the migration
:param dependencies: A dict of dependencies that must be provided to the migration
Migration callbacks will be provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator.
Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or
:meth:`register_post_callback`.
If a migration needs an additional dependency, it must be provided with :meth:`provide_dependency`
before the migration is run.
If a migration has additional dependencies, it is recommended to use functools.partial to provide
the dependencies and register the partial as the migration callback.
Example Usage:
```py
# Define the migrate callback
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
some_dependency = kwargs["some_dependency"]
...
# Instantiate the migration, declaring dependencies
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
dependencies={"some_dependency": MigrationDependency(name="some_dependency", dependency_type=SomeType)},
)
# Register the dependency before running the migration
migration.provide_dependency(name="some_dependency", value=some_value)
```
"""
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
pre_migrate: list[MigrateCallback] = Field(
default=[], description="A list of callbacks to run before the migration"
)
post_migrate: list[MigrateCallback] = Field(
default=[], description="A list of callbacks to run after the migration"
migrate_callback: MigrateCallback = Field(description="The callback to run to perform the migration")
dependencies: dict[str, MigrationDependency] = Field(
default={}, description="A dict of dependencies that must be provided to the migration"
)
@model_validator(mode="after")
@ -48,13 +92,21 @@ class Migration(BaseModel):
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))
def register_pre_callback(self, callback: MigrateCallback) -> None:
"""Registers a pre-migration callback."""
self.pre_migrate.append(callback)
def provide_dependency(self, name: str, value: Any) -> None:
"""Provides a dependency for this migration."""
if name not in self.dependencies:
raise ValueError(f"{name} of type {type(value)} is not a dependency of this migration")
self.dependencies[name].set(value)
def register_post_callback(self, callback: MigrateCallback) -> None:
"""Registers a post-migration callback."""
self.post_migrate.append(callback)
def run(self, cursor: sqlite3.Cursor) -> None:
"""Runs the migration."""
missing_dependencies = [d.name for d in self.dependencies.values() if d.value is None]
if missing_dependencies:
raise ValueError(f"Missing migration dependencies: {', '.join(missing_dependencies)}")
self.migrate_callback = partial(self.migrate_callback, **{d.name: d.value for d in self.dependencies.values()})
self.migrate_callback(cursor=cursor)
model_config = ConfigDict(arbitrary_types_allowed=True)
class MigrationSet:
@ -78,7 +130,7 @@ class MigrationSet:
def validate_migration_chain(self) -> None:
"""
Validates that the migrations form a single chain of migrations from version 0 to the latest version.
Validates that the migrations form a single chain of migrations from version 0 to the latest version,
Raises a MigrationError if there is a problem.
"""
if self.count == 0:

View File

@ -66,21 +66,12 @@ class SQLiteMigrator:
)
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
# Run pre-migration callbacks
if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate:
callback(cursor)
# Run the actual migration
migration.migrate(cursor)
migration.run(cursor)
# Update the version
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
# Run post-migration callbacks
if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate:
callback(cursor)
self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
)