feat(db): address feedback, cleanup

- use simpler pattern for migration dependencies
- move SqliteDatabase & migration to utility method `init_db`, use this in both the app and tests, ensuring the same db schema is used in both
This commit is contained in:
psychedelicious 2023-12-13 11:19:59 +11:00
parent 386b656530
commit ebf5f5d418
12 changed files with 615 additions and 722 deletions

View File

@ -2,9 +2,7 @@
from logging import Logger
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -33,7 +31,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -72,17 +69,7 @@ class ApiDependencies:
output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images")
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
# This migration requires an ImageFileStorageBase service and logger
migration_2.provide_dependency("image_files", image_files)
migration_2.provide_dependency("logger", logger)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config
logger = logger

View File

@ -0,0 +1,32 @@
from logging import Logger
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileStorageBase) -> SqliteDatabase:
"""
Initializes the SQLite database.
:param config: The app config
:param logger: The logger
:param image_files: The image files service (used by migration 2)
This function:
- Instantiates a :class:`SqliteDatabase`
- Instantiates a :class:`SQLiteMigrator` and registers all migrations
- Runs all migrations
"""
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.run_migrations()
return db

View File

@ -3,19 +3,19 @@ import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
class Migration1Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 1."""
_create_board_images(cursor)
_create_boards(cursor)
_create_images(cursor)
_create_model_config(cursor)
_create_session_queue(cursor)
_create_workflow_images(cursor)
_create_workflows(cursor)
self._create_board_images(cursor)
self._create_boards(cursor)
self._create_images(cursor)
self._create_model_config(cursor)
self._create_session_queue(cursor)
self._create_workflow_images(cursor)
self._create_workflows(cursor)
def _create_board_images(cursor: sqlite3.Cursor) -> None:
def _create_board_images(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `board_images` table, indices and triggers."""
tables = [
"""--sql
@ -56,8 +56,7 @@ def _create_board_images(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_boards(cursor: sqlite3.Cursor) -> None:
def _create_boards(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `boards` table, indices and triggers."""
tables = [
"""--sql
@ -92,8 +91,7 @@ def _create_boards(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_images(cursor: sqlite3.Cursor) -> None:
def _create_images(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `images` table, indices and triggers. Adds the `starred` column."""
tables = [
@ -149,8 +147,7 @@ def _create_images(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_model_config(cursor: sqlite3.Cursor) -> None:
def _create_model_config(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""
tables = [
@ -205,8 +202,7 @@ def _create_model_config(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_session_queue(cursor: sqlite3.Cursor) -> None:
def _create_session_queue(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
@ -279,8 +275,7 @@ def _create_session_queue(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
def _create_workflow_images(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
@ -320,8 +315,7 @@ def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflows(cursor: sqlite3.Cursor) -> None:
def _create_workflows(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflows (
@ -350,13 +344,9 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
cursor.execute(stmt)
migration_1 = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
)
def build_migration_1() -> Migration:
"""
Database version 1 (initial state).
Builds the migration from database version 0 (init) to 1.
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
version to not use migrations to manage the database.
@ -372,3 +362,11 @@ to be idempotent.
- Create `workflow_images` junction table
- Create `workflows` table
"""
migration_1 = Migration(
from_version=0,
to_version=1,
callback=Migration1Callback(),
)
return migration_1

View File

@ -4,29 +4,24 @@ from logging import Logger
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationDependency
# This migration requires an ImageFileStorageBase service and logger
image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase)
logger_dependency = MigrationDependency(name="logger", dependency_type=Logger)
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
"""Migration callback for database version 2."""
class Migration2Callback:
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
self._image_files = image_files
self._logger = logger
logger = kwargs[logger_dependency.name]
image_files = kwargs[image_files_dependency.name]
def __call__(self, cursor: sqlite3.Cursor):
self._add_images_has_workflow(cursor)
self._add_session_queue_workflow(cursor)
self._drop_old_workflow_tables(cursor)
self._add_workflow_library(cursor)
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_embedded_workflows(cursor)
_add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor)
_add_workflow_library(cursor)
_drop_model_manager_metadata(cursor)
_recreate_model_config(cursor)
_migrate_embedded_workflows(cursor, logger, image_files)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table."""
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
@ -34,8 +29,7 @@ def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
if "has_workflow" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
def _add_session_queue_workflow(self, cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table."""
cursor.execute("PRAGMA table_info(session_queue)")
@ -44,14 +38,12 @@ def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
if "workflow" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
def _drop_old_workflow_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
cursor.execute("DROP TABLE IF EXISTS workflows;")
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
def _add_workflow_library(self, cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [
"""--sql
@ -96,13 +88,11 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(cursor: sqlite3.Cursor) -> None:
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
@ -136,12 +126,7 @@ def _recreate_model_config(cursor: sqlite3.Cursor) -> None:
"""
)
def _migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,
image_files: ImageFileStorageBase,
) -> None:
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
@ -157,40 +142,43 @@ def _migrate_embedded_workflows(
if not total_image_names:
return
logger.info(f"Migrating workflows for {total_image_names} images")
self._logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name)
pil_image = self._image_files.get(image_name)
if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
self._logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
migration_2 = Migration(
from_version=1,
to_version=2,
migrate_callback=migrate_callback,
dependencies={image_files_dependency.name: image_files_dependency, logger_dependency.name: logger_dependency},
)
def build_migration_2(image_files: ImageFileStorageBase, logger: Logger) -> Migration:
"""
Database version 2.
Builds the migration from database version 1 to 2.
Introduced in v3.5.0 for the new workflow library.
Dependencies:
- image_files: ImageFileStorageBase
- logger: Logger
:param image_files: The image files service, used to check for embedded workflows
:param logger: The logger, used to log progress during embedded workflows handling
Migration:
This migration does the following:
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
- Drops the `model_manager_metadata` table
- Drops the `model_config` table, recreating it (at this point, there is no user data in this table)
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
"""
migration_2 = Migration(
from_version=1,
to_version=2,
callback=Migration2Callback(image_files=image_files, logger=logger),
)
return migration_2

View File

@ -1,6 +1,5 @@
import sqlite3
from functools import partial
from typing import Any, Optional, Protocol, runtime_checkable
from typing import Optional, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -18,7 +17,7 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor, **kwargs: Any) -> None:
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
@ -30,96 +29,69 @@ class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid."""
class MigrationDependency:
"""
Represents a dependency for a migration.
:param name: The name of the dependency
:param dependency_type: The type of the dependency (e.g. `str`, `int`, `SomeClass`, etc.)
"""
def __init__(
self,
name: str,
dependency_type: Any,
) -> None:
self.name = name
self.dependency_type = dependency_type
self.value = None
def set_value(self, value: Any) -> None:
"""
Sets the value of the dependency.
If the value is not of the correct type, a TypeError is raised.
"""
if not isinstance(value, self.dependency_type):
raise TypeError(f"Dependency {self.name} must be of type {self.dependency_type}")
self.value = value
class Migration(BaseModel):
"""
Represents a migration for a SQLite database.
:param from_version: The database version on which this migration may be run
:param to_version: The database version that results from this migration
:param migrate: The callback to run to perform the migration
:param dependencies: A dict of dependencies that must be provided to the migration
:param migrate_callback: The callback to run to perform the migration
Migration callbacks will be provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator.
Example Usage:
It is suggested to use a class to define the migration callback and a builder function to create
the :class:`Migration`. This allows the callback to be provided with additional dependencies and
keeps things tidy, as all migration logic is self-contained.
Example:
```py
# Define the migrate callback. This migration adds a column to the sushi table.
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
# Define the migration callback class
class Migration1Callback:
# This migration needs a logger, so we define a class that accepts a logger in its constructor.
def __init__(self, image_files: ImageFileStorageBase) -> None:
self._image_files = ImageFileStorageBase
# This dunder method allows the instance of the class to be called like a function.
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_with_banana_column(cursor)
self._do_something_with_images(cursor)
def _add_with_banana_column(self, cursor: sqlite3.Cursor) -> None:
\"""Adds the with_banana column to the sushi table.\"""
# Execute SQL using the cursor, taking care to *not commit* a transaction
cursor.execute('ALTER TABLE sushi ADD COLUMN with_banana BOOLEAN DEFAULT TRUE;')
...
# Instantiate the migration
migration = Migration(
def _do_something_with_images(self, cursor: sqlite3.Cursor) -> None:
\"""Does something with the image files service.\"""
self._image_files.get(...)
# Define the migration builder function. This function creates an instance of the migration callback
# class and returns a Migration.
def build_migration_1(image_files: ImageFileStorageBase) -> Migration:
\"""Builds the migration from database version 0 to 1.
Requires the image files service to...
\"""
migration_1 = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
)
```
If a migration needs an additional dependency, it must be provided with :meth:`provide_dependency`
before the migration is run. The migrator provides dependencies to the migrate callback,
raising an error if a dependency is missing or was provided the wrong type.
Example Usage:
```py
# Create a migration dependency. This migration needs access the image files service, so we set the type to the ABC of that service.
image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase)
# Define the migrate callback. The dependency may be accessed by name in the kwargs. The migrator will ensure that the dependency is of the required type.
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
image_files = kwargs[image_files_dependency.name]
# Do something with image_files
...
# Instantiate the migration, including the dependency.
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
dependencies={image_files_dependency.name: image_files_dependency},
migrate_callback=Migration1Callback(image_files=image_files),
)
# Provide the dependency before registering the migration.
# (DiskImageFileStorage is an implementation of ImageFileStorageBase)
migration.provide_dependency(name="image_files", value=DiskImageFileStorage())
return migration_1
# Register the migration after all dependencies have been initialized
db = SqliteDatabase(db_path, logger)
migrator = SqliteMigrator(db)
migrator.register_migration(build_migration_1(image_files))
migrator.run_migrations()
```
"""
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
migrate_callback: MigrateCallback = Field(description="The callback to run to perform the migration")
dependencies: dict[str, MigrationDependency] = Field(
default={}, description="A dict of dependencies that must be provided to the migration"
)
callback: MigrateCallback = Field(description="The callback to run to perform the migration")
@model_validator(mode="after")
def validate_to_version(self) -> "Migration":
@ -132,23 +104,6 @@ class Migration(BaseModel):
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))
def provide_dependency(self, name: str, value: Any) -> None:
"""Provides a dependency for this migration."""
if name not in self.dependencies:
raise ValueError(f"{name} of type {type(value)} is not a dependency of this migration")
self.dependencies[name].set_value(value)
def run(self, cursor: sqlite3.Cursor) -> None:
"""
Runs the migration.
If any dependencies are missing, a MigrationError is raised.
"""
missing_dependencies = [d.name for d in self.dependencies.values() if d.value is None]
if missing_dependencies:
raise MigrationError(f"Missing migration dependencies: {', '.join(missing_dependencies)}")
self.migrate_callback = partial(self.migrate_callback, **{d.name: d.value for d in self.dependencies.values()})
self.migrate_callback(cursor=cursor)
model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@ -20,8 +20,8 @@ class SQLiteMigrator:
```py
db = SqliteDatabase(db_path="my_db.db", logger=logger)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2())
migrator.run_migrations()
```
"""
@ -76,7 +76,7 @@ class SQLiteMigrator:
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
# Run the actual migration
migration.run(cursor)
migration.callback(cursor)
# Update the version
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))

View File

@ -29,7 +29,7 @@ from invokeai.app.services.shared.graph import (
LibraryGraph,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from .test_invoker import create_edge
@ -47,10 +47,10 @@ def simple_graph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_sqlite_database(configuration, logger)
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices(

View File

@ -4,7 +4,7 @@ import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from tests.fixtures.sqlite_database import create_mock_sqlite_database
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
@ -51,10 +51,10 @@ def graph_with_subgraph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_sqlite_database(configuration, logger)
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")

View File

@ -20,7 +20,7 @@ from invokeai.app.services.model_install import (
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
@ -38,10 +38,10 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
@pytest.fixture
def store(
app_config: InvokeAIAppConfig, create_sqlite_database: CreateSqliteDatabaseFunction
app_config: InvokeAIAppConfig,
) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config)
db = create_sqlite_database(app_config, logger)
db = create_mock_sqlite_database(app_config, logger)
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store

View File

@ -23,14 +23,16 @@ from invokeai.backend.model_manager.config import (
VaeDiffusersConfig,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
def store(datadir: Any, create_sqlite_database: CreateSqliteDatabaseFunction) -> ModelRecordServiceBase:
def store(
datadir: Any,
) -> ModelRecordServiceBase:
config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config)
db = create_sqlite_database(config, logger)
db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db)

View File

@ -4,21 +4,13 @@ from unittest import mock
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
def create_sqlite_database(config: InvokeAIAppConfig, logger: Logger) -> SqliteDatabase:
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
def create_mock_sqlite_database(
config: InvokeAIAppConfig,
logger: Logger,
) -> SqliteDatabase:
image_files = mock.Mock(spec=ImageFileStorageBase)
migrator = SQLiteMigrator(db=db)
migration_2.provide_dependency("logger", logger)
migration_2.provide_dependency("image_files", image_files)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
db = init_db(config=config, logger=logger, image_files=image_files)
return db

View File

@ -1,5 +1,4 @@
import sqlite3
from abc import ABC, abstractmethod
from contextlib import closing
from logging import Logger
from pathlib import Path
@ -12,7 +11,6 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
MigrateCallback,
Migration,
MigrationDependency,
MigrationError,
MigrationSet,
MigrationVersionError,
@ -53,7 +51,7 @@ def no_op_migrate_callback() -> MigrateCallback:
@pytest.fixture
def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
return Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
@pytest.fixture
@ -75,7 +73,7 @@ def migrate_callback_create_test_table() -> MigrateCallback:
@pytest.fixture
def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=migrate_callback_create_test_table)
return Migration(from_version=0, to_version=1, callback=migrate_callback_create_test_table)
@pytest.fixture
@ -83,7 +81,7 @@ def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None:
raise Exception("Bad migration")
return Migration(from_version=0, to_version=1, migrate_callback=failing_migration)
return Migration(from_version=0, to_version=1, callback=failing_migration)
@pytest.fixture
@ -101,40 +99,15 @@ def create_migrate(i: int) -> MigrateCallback:
return migrate
def test_migration_dependency_sets_value_primitive() -> None:
dependency = MigrationDependency(name="test_dependency", dependency_type=str)
dependency.set_value("test")
assert dependency.value == "test"
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*str"):
dependency.set_value(1)
def test_migration_dependency_sets_value_complex() -> None:
class SomeBase(ABC):
@abstractmethod
def some_method(self) -> None:
pass
class SomeImpl(SomeBase):
def some_method(self) -> None:
return
dependency = MigrationDependency(name="test_dependency", dependency_type=SomeBase)
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*SomeBase"):
dependency.set_value(1)
# not throwing is sufficient
dependency.set_value(SomeImpl())
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
Migration(from_version=0, to_version=2, callback=no_op_migrate_callback)
# not raising is sufficient
Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None:
migration = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migration = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
assert hash(migration) == hash((0, 1))
@ -147,13 +120,13 @@ def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op:
def test_migration_set_may_not_register_dupes(
migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback
) -> None:
migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migrate_0_to_1_a = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_0_to_1_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_0_to_1_b)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_1_to_2_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_1_to_2_b)
@ -168,15 +141,15 @@ def test_migration_set_gets_migration(migration_no_op: Migration) -> None:
def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 0 to 1
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=2, to_version=3, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=2, to_version=3, callback=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=4, to_version=5, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=4, to_version=5, callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 3 to 4
migration_set.validate_migration_chain()
@ -185,64 +158,32 @@ def test_migration_set_validates_migration_chain(no_op_migrate_callback: Migrate
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
assert migration_set.count == 0
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
assert migration_set.count == 1
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
assert migration_set.count == 2
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
assert migration_set.latest_version == 0
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
assert migration_set.latest_version == 2
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
assert migration_set.latest_version == 2
def test_migration_provide_dependency_validates_name(no_op_migrate_callback: MigrateCallback) -> None:
dependency = MigrationDependency(name="my_dependency", dependency_type=str)
def test_migration_runs(memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback) -> None:
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=no_op_migrate_callback,
dependencies={dependency.name: dependency},
callback=migrate_callback_create_test_table,
)
with pytest.raises(ValueError, match="is not a dependency of this migration"):
migration.provide_dependency("unknown_dependency_name", "banana_sushi")
def test_migration_runs_without_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback
) -> None:
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_test_table,
)
migration.run(memory_db_cursor)
migration.callback(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert memory_db_cursor.fetchone() is not None
def test_migration_runs_with_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_table_of_name: MigrateCallback
) -> None:
dependency = MigrationDependency(name="table_name", dependency_type=str)
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_table_of_name,
dependencies={dependency.name: dependency},
)
with pytest.raises(MigrationError, match="Missing migration dependencies"):
migration.run(memory_db_cursor)
migration.provide_dependency(dependency.name, "banana_sushi")
migration.run(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='banana_sushi';")
assert memory_db_cursor.fetchone() is not None
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
migration = migration_no_op
migrator.register_migration(migration)
@ -286,7 +227,7 @@ def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_crea
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
cursor = migrator._db.conn.cursor()
migrations = [Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)]
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
@ -299,9 +240,7 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
# The Migrator closes the database when it finishes; we cannot use a context manager.
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
migrator = SQLiteMigrator(db=db)
migrations = [
Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)
]
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
@ -319,7 +258,7 @@ def test_migrator_makes_no_changes_on_failed_migration(
migrator.register_migration(migration_no_op)
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1
migrator.register_migration(Migration(from_version=1, to_version=2, migrate_callback=failing_migrate_callback))
migrator.register_migration(Migration(from_version=1, to_version=2, callback=failing_migrate_callback))
with pytest.raises(MigrationError, match="Bad migration"):
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1