feat(db): address feedback, cleanup

- use simpler pattern for migration dependencies
- move SqliteDatabase & migration to utility method `init_db`, use this in both the app and tests, ensuring the same db schema is used in both
This commit is contained in:
psychedelicious 2023-12-13 11:19:59 +11:00
parent 386b656530
commit ebf5f5d418
12 changed files with 615 additions and 722 deletions

View File

@ -2,9 +2,7 @@
from logging import Logger from logging import Logger
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__ from invokeai.version.invokeai_version import __version__
@ -33,7 +31,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.default_graphs import create_system_graphs from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
from ..services.urls.urls_default import LocalUrlService from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService from .events import FastAPIEventService
@ -72,17 +69,7 @@ class ApiDependencies:
output_folder = config.output_path output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images") image_files = DiskImageFileStorage(f"{output_folder}/images")
db_path = None if config.use_memory_db else config.db_path db = init_db(config=config, logger=logger, image_files=image_files)
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
# This migration requires an ImageFileStorageBase service and logger
migration_2.provide_dependency("image_files", image_files)
migration_2.provide_dependency("logger", logger)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
configuration = config configuration = config
logger = logger logger = logger

View File

@ -0,0 +1,32 @@
from logging import Logger
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileStorageBase) -> SqliteDatabase:
"""
Initializes the SQLite database.
:param config: The app config
:param logger: The logger
:param image_files: The image files service (used by migration 2)
This function:
- Instantiates a :class:`SqliteDatabase`
- Instantiates a :class:`SQLiteMigrator` and registers all migrations
- Runs all migrations
"""
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.run_migrations()
return db

View File

@ -3,19 +3,19 @@ import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None: class Migration1Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 1.""" """Migration callback for database version 1."""
_create_board_images(cursor) self._create_board_images(cursor)
_create_boards(cursor) self._create_boards(cursor)
_create_images(cursor) self._create_images(cursor)
_create_model_config(cursor) self._create_model_config(cursor)
_create_session_queue(cursor) self._create_session_queue(cursor)
_create_workflow_images(cursor) self._create_workflow_images(cursor)
_create_workflows(cursor) self._create_workflows(cursor)
def _create_board_images(self, cursor: sqlite3.Cursor) -> None:
def _create_board_images(cursor: sqlite3.Cursor) -> None:
"""Creates the `board_images` table, indices and triggers.""" """Creates the `board_images` table, indices and triggers."""
tables = [ tables = [
"""--sql """--sql
@ -56,8 +56,7 @@ def _create_board_images(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers: for stmt in tables + indices + triggers:
cursor.execute(stmt) cursor.execute(stmt)
def _create_boards(self, cursor: sqlite3.Cursor) -> None:
def _create_boards(cursor: sqlite3.Cursor) -> None:
"""Creates the `boards` table, indices and triggers.""" """Creates the `boards` table, indices and triggers."""
tables = [ tables = [
"""--sql """--sql
@ -92,8 +91,7 @@ def _create_boards(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers: for stmt in tables + indices + triggers:
cursor.execute(stmt) cursor.execute(stmt)
def _create_images(self, cursor: sqlite3.Cursor) -> None:
def _create_images(cursor: sqlite3.Cursor) -> None:
"""Creates the `images` table, indices and triggers. Adds the `starred` column.""" """Creates the `images` table, indices and triggers. Adds the `starred` column."""
tables = [ tables = [
@ -149,8 +147,7 @@ def _create_images(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers: for stmt in tables + indices + triggers:
cursor.execute(stmt) cursor.execute(stmt)
def _create_model_config(self, cursor: sqlite3.Cursor) -> None:
def _create_model_config(cursor: sqlite3.Cursor) -> None:
"""Creates the `model_config` table, `model_manager_metadata` table, indices and triggers.""" """Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""
tables = [ tables = [
@ -205,8 +202,7 @@ def _create_model_config(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers: for stmt in tables + indices + triggers:
cursor.execute(stmt) cursor.execute(stmt)
def _create_session_queue(self, cursor: sqlite3.Cursor) -> None:
def _create_session_queue(cursor: sqlite3.Cursor) -> None:
tables = [ tables = [
"""--sql """--sql
CREATE TABLE IF NOT EXISTS session_queue ( CREATE TABLE IF NOT EXISTS session_queue (
@ -279,8 +275,7 @@ def _create_session_queue(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers: for stmt in tables + indices + triggers:
cursor.execute(stmt) cursor.execute(stmt)
def _create_workflow_images(self, cursor: sqlite3.Cursor) -> None:
def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
tables = [ tables = [
"""--sql """--sql
CREATE TABLE IF NOT EXISTS workflow_images ( CREATE TABLE IF NOT EXISTS workflow_images (
@ -320,8 +315,7 @@ def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers: for stmt in tables + indices + triggers:
cursor.execute(stmt) cursor.execute(stmt)
def _create_workflows(self, cursor: sqlite3.Cursor) -> None:
def _create_workflows(cursor: sqlite3.Cursor) -> None:
tables = [ tables = [
"""--sql """--sql
CREATE TABLE IF NOT EXISTS workflows ( CREATE TABLE IF NOT EXISTS workflows (
@ -350,13 +344,9 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
cursor.execute(stmt) cursor.execute(stmt)
migration_1 = Migration( def build_migration_1() -> Migration:
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
)
""" """
Database version 1 (initial state). Builds the migration from database version 0 (init) to 1.
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
version to not use migrations to manage the database. version to not use migrations to manage the database.
@ -372,3 +362,11 @@ to be idempotent.
- Create `workflow_images` junction table - Create `workflow_images` junction table
- Create `workflows` table - Create `workflows` table
""" """
migration_1 = Migration(
from_version=0,
to_version=1,
callback=Migration1Callback(),
)
return migration_1

View File

@ -4,29 +4,24 @@ from logging import Logger
from tqdm import tqdm from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationDependency from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
# This migration requires an ImageFileStorageBase service and logger
image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase)
logger_dependency = MigrationDependency(name="logger", dependency_type=Logger)
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None: class Migration2Callback:
"""Migration callback for database version 2.""" def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
self._image_files = image_files
self._logger = logger
logger = kwargs[logger_dependency.name] def __call__(self, cursor: sqlite3.Cursor):
image_files = kwargs[image_files_dependency.name] self._add_images_has_workflow(cursor)
self._add_session_queue_workflow(cursor)
self._drop_old_workflow_tables(cursor)
self._add_workflow_library(cursor)
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_embedded_workflows(cursor)
_add_images_has_workflow(cursor) def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
_add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor)
_add_workflow_library(cursor)
_drop_model_manager_metadata(cursor)
_recreate_model_config(cursor)
_migrate_embedded_workflows(cursor, logger, image_files)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table.""" """Add the `has_workflow` column to `images` table."""
cursor.execute("PRAGMA table_info(images)") cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()] columns = [column[1] for column in cursor.fetchall()]
@ -34,8 +29,7 @@ def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
if "has_workflow" not in columns: if "has_workflow" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;") cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(self, cursor: sqlite3.Cursor) -> None:
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table.""" """Add the `workflow` column to `session_queue` table."""
cursor.execute("PRAGMA table_info(session_queue)") cursor.execute("PRAGMA table_info(session_queue)")
@ -44,14 +38,12 @@ def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
if "workflow" not in columns: if "workflow" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;") cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(self, cursor: sqlite3.Cursor) -> None:
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables.""" """Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE IF EXISTS workflow_images;") cursor.execute("DROP TABLE IF EXISTS workflow_images;")
cursor.execute("DROP TABLE IF EXISTS workflows;") cursor.execute("DROP TABLE IF EXISTS workflows;")
def _add_workflow_library(self, cursor: sqlite3.Cursor) -> None:
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables.""" """Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [ tables = [
"""--sql """--sql
@ -96,13 +88,11 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers: for stmt in tables + indices + triggers:
cursor.execute(stmt) cursor.execute(stmt)
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table.""" """Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;") cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
def _recreate_model_config(cursor: sqlite3.Cursor) -> None:
""" """
Drops the `model_config` table, recreating it. Drops the `model_config` table, recreating it.
@ -136,12 +126,7 @@ def _recreate_model_config(cursor: sqlite3.Cursor) -> None:
""" """
) )
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
def _migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,
image_files: ImageFileStorageBase,
) -> None:
""" """
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded. the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
@ -157,40 +142,43 @@ def _migrate_embedded_workflows(
if not total_image_names: if not total_image_names:
return return
logger.info(f"Migrating workflows for {total_image_names} images") self._logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images # Migrate the images
to_migrate: list[tuple[bool, str]] = [] to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names) pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar): for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow") pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name) pil_image = self._image_files.get(image_name)
if "invokeai_workflow" in pil_image.info: if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name)) to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database") self._logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate) cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
migration_2 = Migration( def build_migration_2(image_files: ImageFileStorageBase, logger: Logger) -> Migration:
from_version=1,
to_version=2,
migrate_callback=migrate_callback,
dependencies={image_files_dependency.name: image_files_dependency, logger_dependency.name: logger_dependency},
)
""" """
Database version 2. Builds the migration from database version 1 to 2.
Introduced in v3.5.0 for the new workflow library. Introduced in v3.5.0 for the new workflow library.
Dependencies: :param image_files: The image files service, used to check for embedded workflows
- image_files: ImageFileStorageBase :param logger: The logger, used to log progress during embedded workflows handling
- logger: Logger
Migration: This migration does the following:
- Add `has_workflow` column to `images` table - Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table - Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables - Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table - Add `workflow_library` table
- Drops the `model_manager_metadata` table
- Drops the `model_config` table, recreating it (at this point, there is no user data in this table)
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies) - Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
""" """
migration_2 = Migration(
from_version=1,
to_version=2,
callback=Migration2Callback(image_files=image_files, logger=logger),
)
return migration_2

View File

@ -1,6 +1,5 @@
import sqlite3 import sqlite3
from functools import partial from typing import Optional, Protocol, runtime_checkable
from typing import Any, Optional, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -18,7 +17,7 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example. See :class:`Migration` for an example.
""" """
def __call__(self, cursor: sqlite3.Cursor, **kwargs: Any) -> None: def __call__(self, cursor: sqlite3.Cursor) -> None:
... ...
@ -30,96 +29,69 @@ class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid.""" """Raised when a migration version is invalid."""
class MigrationDependency:
"""
Represents a dependency for a migration.
:param name: The name of the dependency
:param dependency_type: The type of the dependency (e.g. `str`, `int`, `SomeClass`, etc.)
"""
def __init__(
self,
name: str,
dependency_type: Any,
) -> None:
self.name = name
self.dependency_type = dependency_type
self.value = None
def set_value(self, value: Any) -> None:
"""
Sets the value of the dependency.
If the value is not of the correct type, a TypeError is raised.
"""
if not isinstance(value, self.dependency_type):
raise TypeError(f"Dependency {self.name} must be of type {self.dependency_type}")
self.value = value
class Migration(BaseModel): class Migration(BaseModel):
""" """
Represents a migration for a SQLite database. Represents a migration for a SQLite database.
:param from_version: The database version on which this migration may be run :param from_version: The database version on which this migration may be run
:param to_version: The database version that results from this migration :param to_version: The database version that results from this migration
:param migrate: The callback to run to perform the migration :param migrate_callback: The callback to run to perform the migration
:param dependencies: A dict of dependencies that must be provided to the migration
Migration callbacks will be provided an open cursor to the database. They should not commit their Migration callbacks will be provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator. transaction; this is handled by the migrator.
Example Usage: It is suggested to use a class to define the migration callback and a builder function to create
the :class:`Migration`. This allows the callback to be provided with additional dependencies and
keeps things tidy, as all migration logic is self-contained.
Example:
```py ```py
# Define the migrate callback. This migration adds a column to the sushi table. # Define the migration callback class
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None: class Migration1Callback:
# This migration needs a logger, so we define a class that accepts a logger in its constructor.
def __init__(self, image_files: ImageFileStorageBase) -> None:
self._image_files = ImageFileStorageBase
# This dunder method allows the instance of the class to be called like a function.
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_with_banana_column(cursor)
self._do_something_with_images(cursor)
def _add_with_banana_column(self, cursor: sqlite3.Cursor) -> None:
\"""Adds the with_banana column to the sushi table.\"""
# Execute SQL using the cursor, taking care to *not commit* a transaction # Execute SQL using the cursor, taking care to *not commit* a transaction
cursor.execute('ALTER TABLE sushi ADD COLUMN with_banana BOOLEAN DEFAULT TRUE;') cursor.execute('ALTER TABLE sushi ADD COLUMN with_banana BOOLEAN DEFAULT TRUE;')
...
# Instantiate the migration def _do_something_with_images(self, cursor: sqlite3.Cursor) -> None:
migration = Migration( \"""Does something with the image files service.\"""
self._image_files.get(...)
# Define the migration builder function. This function creates an instance of the migration callback
# class and returns a Migration.
def build_migration_1(image_files: ImageFileStorageBase) -> Migration:
\"""Builds the migration from database version 0 to 1.
Requires the image files service to...
\"""
migration_1 = Migration(
from_version=0, from_version=0,
to_version=1, to_version=1,
migrate_callback=migrate_callback, migrate_callback=Migration1Callback(image_files=image_files),
)
```
If a migration needs an additional dependency, it must be provided with :meth:`provide_dependency`
before the migration is run. The migrator provides dependencies to the migrate callback,
raising an error if a dependency is missing or was provided the wrong type.
Example Usage:
```py
# Create a migration dependency. This migration needs access the image files service, so we set the type to the ABC of that service.
image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase)
# Define the migrate callback. The dependency may be accessed by name in the kwargs. The migrator will ensure that the dependency is of the required type.
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
image_files = kwargs[image_files_dependency.name]
# Do something with image_files
...
# Instantiate the migration, including the dependency.
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
dependencies={image_files_dependency.name: image_files_dependency},
) )
# Provide the dependency before registering the migration. return migration_1
# (DiskImageFileStorage is an implementation of ImageFileStorageBase)
migration.provide_dependency(name="image_files", value=DiskImageFileStorage()) # Register the migration after all dependencies have been initialized
db = SqliteDatabase(db_path, logger)
migrator = SqliteMigrator(db)
migrator.register_migration(build_migration_1(image_files))
migrator.run_migrations()
``` ```
""" """
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run") from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration") to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
migrate_callback: MigrateCallback = Field(description="The callback to run to perform the migration") callback: MigrateCallback = Field(description="The callback to run to perform the migration")
dependencies: dict[str, MigrationDependency] = Field(
default={}, description="A dict of dependencies that must be provided to the migration"
)
@model_validator(mode="after") @model_validator(mode="after")
def validate_to_version(self) -> "Migration": def validate_to_version(self) -> "Migration":
@ -132,23 +104,6 @@ class Migration(BaseModel):
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set. # Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version)) return hash((self.from_version, self.to_version))
def provide_dependency(self, name: str, value: Any) -> None:
"""Provides a dependency for this migration."""
if name not in self.dependencies:
raise ValueError(f"{name} of type {type(value)} is not a dependency of this migration")
self.dependencies[name].set_value(value)
def run(self, cursor: sqlite3.Cursor) -> None:
"""
Runs the migration.
If any dependencies are missing, a MigrationError is raised.
"""
missing_dependencies = [d.name for d in self.dependencies.values() if d.value is None]
if missing_dependencies:
raise MigrationError(f"Missing migration dependencies: {', '.join(missing_dependencies)}")
self.migrate_callback = partial(self.migrate_callback, **{d.name: d.value for d in self.dependencies.values()})
self.migrate_callback(cursor=cursor)
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@ -20,8 +20,8 @@ class SQLiteMigrator:
```py ```py
db = SqliteDatabase(db_path="my_db.db", logger=logger) db = SqliteDatabase(db_path="my_db.db", logger=logger)
migrator = SQLiteMigrator(db=db) migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1) migrator.register_migration(build_migration_1())
migrator.register_migration(migration_2) migrator.register_migration(build_migration_2())
migrator.run_migrations() migrator.run_migrations()
``` ```
""" """
@ -76,7 +76,7 @@ class SQLiteMigrator:
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}") self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
# Run the actual migration # Run the actual migration
migration.run(cursor) migration.callback(cursor)
# Update the version # Update the version
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,)) cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))

View File

@ -29,7 +29,7 @@ from invokeai.app.services.shared.graph import (
LibraryGraph, LibraryGraph,
) )
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction from tests.fixtures.sqlite_database import create_mock_sqlite_database
from .test_invoker import create_edge from .test_invoker import create_edge
@ -47,10 +47,10 @@ def simple_graph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations. # the test invocations.
@pytest.fixture @pytest.fixture
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices: def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_sqlite_database(configuration, logger) db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices( return InvocationServices(

View File

@ -4,7 +4,7 @@ import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction from tests.fixtures.sqlite_database import create_mock_sqlite_database
# This import must happen before other invoke imports or test in other files(!!) break # This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split from .test_nodes import ( # isort: split
@ -51,10 +51,10 @@ def graph_with_subgraph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations. # the test invocations.
@pytest.fixture @pytest.fixture
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices: def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_sqlite_database(configuration, logger) db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")

View File

@ -20,7 +20,7 @@ from invokeai.app.services.model_install import (
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelType from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture @pytest.fixture
@ -38,10 +38,10 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
@pytest.fixture @pytest.fixture
def store( def store(
app_config: InvokeAIAppConfig, create_sqlite_database: CreateSqliteDatabaseFunction app_config: InvokeAIAppConfig,
) -> ModelRecordServiceBase: ) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config) logger = InvokeAILogger.get_logger(config=app_config)
db = create_sqlite_database(app_config, logger) db = create_mock_sqlite_database(app_config, logger)
store: ModelRecordServiceBase = ModelRecordServiceSQL(db) store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store return store

View File

@ -23,14 +23,16 @@ from invokeai.backend.model_manager.config import (
VaeDiffusersConfig, VaeDiffusersConfig,
) )
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture @pytest.fixture
def store(datadir: Any, create_sqlite_database: CreateSqliteDatabaseFunction) -> ModelRecordServiceBase: def store(
datadir: Any,
) -> ModelRecordServiceBase:
config = InvokeAIAppConfig(root=datadir) config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db = create_sqlite_database(config, logger) db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db) return ModelRecordServiceSQL(db)

View File

@ -4,21 +4,13 @@ from unittest import mock
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
def create_sqlite_database(config: InvokeAIAppConfig, logger: Logger) -> SqliteDatabase: def create_mock_sqlite_database(
db_path = None if config.use_memory_db else config.db_path config: InvokeAIAppConfig,
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) logger: Logger,
) -> SqliteDatabase:
image_files = mock.Mock(spec=ImageFileStorageBase) image_files = mock.Mock(spec=ImageFileStorageBase)
db = init_db(config=config, logger=logger, image_files=image_files)
migrator = SQLiteMigrator(db=db)
migration_2.provide_dependency("logger", logger)
migration_2.provide_dependency("image_files", image_files)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
return db return db

View File

@ -1,5 +1,4 @@
import sqlite3 import sqlite3
from abc import ABC, abstractmethod
from contextlib import closing from contextlib import closing
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
@ -12,7 +11,6 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import ( from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
MigrateCallback, MigrateCallback,
Migration, Migration,
MigrationDependency,
MigrationError, MigrationError,
MigrationSet, MigrationSet,
MigrationVersionError, MigrationVersionError,
@ -53,7 +51,7 @@ def no_op_migrate_callback() -> MigrateCallback:
@pytest.fixture @pytest.fixture
def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration: def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback) return Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
@pytest.fixture @pytest.fixture
@ -75,7 +73,7 @@ def migrate_callback_create_test_table() -> MigrateCallback:
@pytest.fixture @pytest.fixture
def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration: def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=migrate_callback_create_test_table) return Migration(from_version=0, to_version=1, callback=migrate_callback_create_test_table)
@pytest.fixture @pytest.fixture
@ -83,7 +81,7 @@ def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None: def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None:
raise Exception("Bad migration") raise Exception("Bad migration")
return Migration(from_version=0, to_version=1, migrate_callback=failing_migration) return Migration(from_version=0, to_version=1, callback=failing_migration)
@pytest.fixture @pytest.fixture
@ -101,40 +99,15 @@ def create_migrate(i: int) -> MigrateCallback:
return migrate return migrate
def test_migration_dependency_sets_value_primitive() -> None:
dependency = MigrationDependency(name="test_dependency", dependency_type=str)
dependency.set_value("test")
assert dependency.value == "test"
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*str"):
dependency.set_value(1)
def test_migration_dependency_sets_value_complex() -> None:
class SomeBase(ABC):
@abstractmethod
def some_method(self) -> None:
pass
class SomeImpl(SomeBase):
def some_method(self) -> None:
return
dependency = MigrationDependency(name="test_dependency", dependency_type=SomeBase)
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*SomeBase"):
dependency.set_value(1)
# not throwing is sufficient
dependency.set_value(SomeImpl())
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"): with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback) Migration(from_version=0, to_version=2, callback=no_op_migrate_callback)
# not raising is sufficient # not raising is sufficient
Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback) Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None:
migration = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback) migration = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
assert hash(migration) == hash((0, 1)) assert hash(migration) == hash((0, 1))
@ -147,13 +120,13 @@ def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op:
def test_migration_set_may_not_register_dupes( def test_migration_set_may_not_register_dupes(
migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback
) -> None: ) -> None:
migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback) migrate_0_to_1_a = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback) migrate_0_to_1_b = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_0_to_1_a) migrator._migration_set.register(migrate_0_to_1_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"): with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_0_to_1_b) migrator._migration_set.register(migrate_0_to_1_b)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback) migrate_1_to_2_a = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback) migrate_1_to_2_b = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_1_to_2_a) migrator._migration_set.register(migrate_1_to_2_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"): with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_1_to_2_b) migrator._migration_set.register(migrate_1_to_2_b)
@ -168,15 +141,15 @@ def test_migration_set_gets_migration(migration_no_op: Migration) -> None:
def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet() migration_set = MigrationSet()
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)) migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"): with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 0 to 1 # no migration from 0 to 1
migration_set.validate_migration_chain() migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)) migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
migration_set.validate_migration_chain() migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=2, to_version=3, migrate_callback=no_op_migrate_callback)) migration_set.register(Migration(from_version=2, to_version=3, callback=no_op_migrate_callback))
migration_set.validate_migration_chain() migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=4, to_version=5, migrate_callback=no_op_migrate_callback)) migration_set.register(Migration(from_version=4, to_version=5, callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"): with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 3 to 4 # no migration from 3 to 4
migration_set.validate_migration_chain() migration_set.validate_migration_chain()
@ -185,64 +158,32 @@ def test_migration_set_validates_migration_chain(no_op_migrate_callback: Migrate
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet() migration_set = MigrationSet()
assert migration_set.count == 0 assert migration_set.count == 0
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)) migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
assert migration_set.count == 1 assert migration_set.count == 1
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)) migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
assert migration_set.count == 2 assert migration_set.count == 2
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet() migration_set = MigrationSet()
assert migration_set.latest_version == 0 assert migration_set.latest_version == 0
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)) migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
assert migration_set.latest_version == 2 assert migration_set.latest_version == 2
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)) migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
assert migration_set.latest_version == 2 assert migration_set.latest_version == 2
def test_migration_provide_dependency_validates_name(no_op_migrate_callback: MigrateCallback) -> None: def test_migration_runs(memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback) -> None:
dependency = MigrationDependency(name="my_dependency", dependency_type=str)
migration = Migration( migration = Migration(
from_version=0, from_version=0,
to_version=1, to_version=1,
migrate_callback=no_op_migrate_callback, callback=migrate_callback_create_test_table,
dependencies={dependency.name: dependency},
) )
with pytest.raises(ValueError, match="is not a dependency of this migration"): migration.callback(memory_db_cursor)
migration.provide_dependency("unknown_dependency_name", "banana_sushi")
def test_migration_runs_without_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback
) -> None:
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_test_table,
)
migration.run(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';") memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert memory_db_cursor.fetchone() is not None assert memory_db_cursor.fetchone() is not None
def test_migration_runs_with_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_table_of_name: MigrateCallback
) -> None:
dependency = MigrationDependency(name="table_name", dependency_type=str)
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_table_of_name,
dependencies={dependency.name: dependency},
)
with pytest.raises(MigrationError, match="Missing migration dependencies"):
migration.run(memory_db_cursor)
migration.provide_dependency(dependency.name, "banana_sushi")
migration.run(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='banana_sushi';")
assert memory_db_cursor.fetchone() is not None
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None: def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
migration = migration_no_op migration = migration_no_op
migrator.register_migration(migration) migrator.register_migration(migration)
@ -286,7 +227,7 @@ def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_crea
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None: def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
cursor = migrator._db.conn.cursor() cursor = migrator._db.conn.cursor()
migrations = [Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)] migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations: for migration in migrations:
migrator.register_migration(migration) migrator.register_migration(migration)
migrator.run_migrations() migrator.run_migrations()
@ -299,9 +240,7 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
# The Migrator closes the database when it finishes; we cannot use a context manager. # The Migrator closes the database when it finishes; we cannot use a context manager.
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False) db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
migrator = SQLiteMigrator(db=db) migrator = SQLiteMigrator(db=db)
migrations = [ migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)
]
for migration in migrations: for migration in migrations:
migrator.register_migration(migration) migrator.register_migration(migration)
migrator.run_migrations() migrator.run_migrations()
@ -319,7 +258,7 @@ def test_migrator_makes_no_changes_on_failed_migration(
migrator.register_migration(migration_no_op) migrator.register_migration(migration_no_op)
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1 assert migrator._get_current_version(cursor) == 1
migrator.register_migration(Migration(from_version=1, to_version=2, migrate_callback=failing_migrate_callback)) migrator.register_migration(Migration(from_version=1, to_version=2, callback=failing_migrate_callback))
with pytest.raises(MigrationError, match="Bad migration"): with pytest.raises(MigrationError, match="Bad migration"):
migrator.run_migrations() migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1 assert migrator._get_current_version(cursor) == 1