feat(db): address feedback, cleanup

- use simpler pattern for migration dependencies
- move SqliteDatabase & migration to utility method `init_db`, use this in both the app and tests, ensuring the same db schema is used in both
This commit is contained in:
psychedelicious 2023-12-13 11:19:59 +11:00
parent 386b656530
commit ebf5f5d418
12 changed files with 615 additions and 722 deletions

View File

@ -2,9 +2,7 @@
from logging import Logger
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -33,7 +31,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -72,17 +69,7 @@ class ApiDependencies:
output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images")
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
# This migration requires an ImageFileStorageBase service and logger
migration_2.provide_dependency("image_files", image_files)
migration_2.provide_dependency("logger", logger)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config
logger = logger

View File

@ -0,0 +1,32 @@
from logging import Logger
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileStorageBase) -> SqliteDatabase:
"""
Initializes the SQLite database.
:param config: The app config
:param logger: The logger
:param image_files: The image files service (used by migration 2)
This function:
- Instantiates a :class:`SqliteDatabase`
- Instantiates a :class:`SQLiteMigrator` and registers all migrations
- Runs all migrations
"""
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.run_migrations()
return db

View File

@ -3,19 +3,19 @@ import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
class Migration1Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 1."""
_create_board_images(cursor)
_create_boards(cursor)
_create_images(cursor)
_create_model_config(cursor)
_create_session_queue(cursor)
_create_workflow_images(cursor)
_create_workflows(cursor)
self._create_board_images(cursor)
self._create_boards(cursor)
self._create_images(cursor)
self._create_model_config(cursor)
self._create_session_queue(cursor)
self._create_workflow_images(cursor)
self._create_workflows(cursor)
def _create_board_images(cursor: sqlite3.Cursor) -> None:
def _create_board_images(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `board_images` table, indices and triggers."""
tables = [
"""--sql
@ -56,8 +56,7 @@ def _create_board_images(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_boards(cursor: sqlite3.Cursor) -> None:
def _create_boards(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `boards` table, indices and triggers."""
tables = [
"""--sql
@ -92,8 +91,7 @@ def _create_boards(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_images(cursor: sqlite3.Cursor) -> None:
def _create_images(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `images` table, indices and triggers. Adds the `starred` column."""
tables = [
@ -149,8 +147,7 @@ def _create_images(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_model_config(cursor: sqlite3.Cursor) -> None:
def _create_model_config(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""
tables = [
@ -205,8 +202,7 @@ def _create_model_config(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_session_queue(cursor: sqlite3.Cursor) -> None:
def _create_session_queue(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
@ -279,8 +275,7 @@ def _create_session_queue(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
def _create_workflow_images(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
@ -320,8 +315,7 @@ def _create_workflow_images(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflows(cursor: sqlite3.Cursor) -> None:
def _create_workflows(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflows (
@ -350,25 +344,29 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
cursor.execute(stmt)
migration_1 = Migration(
def build_migration_1() -> Migration:
"""
Builds the migration from database version 0 (init) to 1.
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
version to not use migrations to manage the database.
As such, this migration does include some ALTER statements, and the SQL statements are written
to be idempotent.
- Create `board_images` junction table
- Create `boards` table
- Create `images` table, add `starred` column
- Create `model_config` table
- Create `session_queue` table
- Create `workflow_images` junction table
- Create `workflows` table
"""
migration_1 = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
)
"""
Database version 1 (initial state).
callback=Migration1Callback(),
)
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
version to not use migrations to manage the database.
As such, this migration does include some ALTER statements, and the SQL statements are written
to be idempotent.
- Create `board_images` junction table
- Create `boards` table
- Create `images` table, add `starred` column
- Create `model_config` table
- Create `session_queue` table
- Create `workflow_images` junction table
- Create `workflows` table
"""
return migration_1

View File

@ -4,29 +4,24 @@ from logging import Logger
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationDependency
# This migration requires an ImageFileStorageBase service and logger
image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase)
logger_dependency = MigrationDependency(name="logger", dependency_type=Logger)
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
"""Migration callback for database version 2."""
class Migration2Callback:
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
self._image_files = image_files
self._logger = logger
logger = kwargs[logger_dependency.name]
image_files = kwargs[image_files_dependency.name]
def __call__(self, cursor: sqlite3.Cursor):
self._add_images_has_workflow(cursor)
self._add_session_queue_workflow(cursor)
self._drop_old_workflow_tables(cursor)
self._add_workflow_library(cursor)
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_embedded_workflows(cursor)
_add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor)
_add_workflow_library(cursor)
_drop_model_manager_metadata(cursor)
_recreate_model_config(cursor)
_migrate_embedded_workflows(cursor, logger, image_files)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table."""
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
@ -34,8 +29,7 @@ def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
if "has_workflow" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
def _add_session_queue_workflow(self, cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table."""
cursor.execute("PRAGMA table_info(session_queue)")
@ -44,14 +38,12 @@ def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
if "workflow" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
def _drop_old_workflow_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
cursor.execute("DROP TABLE IF EXISTS workflows;")
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
def _add_workflow_library(self, cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [
"""--sql
@ -96,13 +88,11 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(cursor: sqlite3.Cursor) -> None:
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
@ -136,12 +126,7 @@ def _recreate_model_config(cursor: sqlite3.Cursor) -> None:
"""
)
def _migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,
image_files: ImageFileStorageBase,
) -> None:
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
@ -157,40 +142,43 @@ def _migrate_embedded_workflows(
if not total_image_names:
return
logger.info(f"Migrating workflows for {total_image_names} images")
self._logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name)
pil_image = self._image_files.get(image_name)
if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
self._logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
migration_2 = Migration(
def build_migration_2(image_files: ImageFileStorageBase, logger: Logger) -> Migration:
"""
Builds the migration from database version 1 to 2.
Introduced in v3.5.0 for the new workflow library.
:param image_files: The image files service, used to check for embedded workflows
:param logger: The logger, used to log progress during embedded workflows handling
This migration does the following:
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
- Drops the `model_manager_metadata` table
- Drops the `model_config` table, recreating it (at this point, there is no user data in this table)
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
"""
migration_2 = Migration(
from_version=1,
to_version=2,
migrate_callback=migrate_callback,
dependencies={image_files_dependency.name: image_files_dependency, logger_dependency.name: logger_dependency},
)
"""
Database version 2.
callback=Migration2Callback(image_files=image_files, logger=logger),
)
Introduced in v3.5.0 for the new workflow library.
Dependencies:
- image_files: ImageFileStorageBase
- logger: Logger
Migration:
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
"""
return migration_2

View File

@ -1,6 +1,5 @@
import sqlite3
from functools import partial
from typing import Any, Optional, Protocol, runtime_checkable
from typing import Optional, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -18,7 +17,7 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor, **kwargs: Any) -> None:
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
@ -30,96 +29,69 @@ class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid."""
class MigrationDependency:
"""
Represents a dependency for a migration.
:param name: The name of the dependency
:param dependency_type: The type of the dependency (e.g. `str`, `int`, `SomeClass`, etc.)
"""
def __init__(
self,
name: str,
dependency_type: Any,
) -> None:
self.name = name
self.dependency_type = dependency_type
self.value = None
def set_value(self, value: Any) -> None:
"""
Sets the value of the dependency.
If the value is not of the correct type, a TypeError is raised.
"""
if not isinstance(value, self.dependency_type):
raise TypeError(f"Dependency {self.name} must be of type {self.dependency_type}")
self.value = value
class Migration(BaseModel):
"""
Represents a migration for a SQLite database.
:param from_version: The database version on which this migration may be run
:param to_version: The database version that results from this migration
:param migrate: The callback to run to perform the migration
:param dependencies: A dict of dependencies that must be provided to the migration
:param migrate_callback: The callback to run to perform the migration
Migration callbacks will be provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator.
Example Usage:
It is suggested to use a class to define the migration callback and a builder function to create
the :class:`Migration`. This allows the callback to be provided with additional dependencies and
keeps things tidy, as all migration logic is self-contained.
Example:
```py
# Define the migrate callback. This migration adds a column to the sushi table.
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
# Define the migration callback class
class Migration1Callback:
# This migration needs a logger, so we define a class that accepts a logger in its constructor.
def __init__(self, image_files: ImageFileStorageBase) -> None:
self._image_files = ImageFileStorageBase
# This dunder method allows the instance of the class to be called like a function.
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_with_banana_column(cursor)
self._do_something_with_images(cursor)
def _add_with_banana_column(self, cursor: sqlite3.Cursor) -> None:
\"""Adds the with_banana column to the sushi table.\"""
# Execute SQL using the cursor, taking care to *not commit* a transaction
cursor.execute('ALTER TABLE sushi ADD COLUMN with_banana BOOLEAN DEFAULT TRUE;')
...
# Instantiate the migration
migration = Migration(
def _do_something_with_images(self, cursor: sqlite3.Cursor) -> None:
\"""Does something with the image files service.\"""
self._image_files.get(...)
# Define the migration builder function. This function creates an instance of the migration callback
# class and returns a Migration.
def build_migration_1(image_files: ImageFileStorageBase) -> Migration:
\"""Builds the migration from database version 0 to 1.
Requires the image files service to...
\"""
migration_1 = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
)
```
If a migration needs an additional dependency, it must be provided with :meth:`provide_dependency`
before the migration is run. The migrator provides dependencies to the migrate callback,
raising an error if a dependency is missing or was provided the wrong type.
Example Usage:
```py
# Create a migration dependency. This migration needs access the image files service, so we set the type to the ABC of that service.
image_files_dependency = MigrationDependency(name="image_files", dependency_type=ImageFileStorageBase)
# Define the migrate callback. The dependency may be accessed by name in the kwargs. The migrator will ensure that the dependency is of the required type.
def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
image_files = kwargs[image_files_dependency.name]
# Do something with image_files
...
# Instantiate the migration, including the dependency.
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback,
dependencies={image_files_dependency.name: image_files_dependency},
migrate_callback=Migration1Callback(image_files=image_files),
)
# Provide the dependency before registering the migration.
# (DiskImageFileStorage is an implementation of ImageFileStorageBase)
migration.provide_dependency(name="image_files", value=DiskImageFileStorage())
return migration_1
# Register the migration after all dependencies have been initialized
db = SqliteDatabase(db_path, logger)
migrator = SqliteMigrator(db)
migrator.register_migration(build_migration_1(image_files))
migrator.run_migrations()
```
"""
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
migrate_callback: MigrateCallback = Field(description="The callback to run to perform the migration")
dependencies: dict[str, MigrationDependency] = Field(
default={}, description="A dict of dependencies that must be provided to the migration"
)
callback: MigrateCallback = Field(description="The callback to run to perform the migration")
@model_validator(mode="after")
def validate_to_version(self) -> "Migration":
@ -132,23 +104,6 @@ class Migration(BaseModel):
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))
def provide_dependency(self, name: str, value: Any) -> None:
"""Provides a dependency for this migration."""
if name not in self.dependencies:
raise ValueError(f"{name} of type {type(value)} is not a dependency of this migration")
self.dependencies[name].set_value(value)
def run(self, cursor: sqlite3.Cursor) -> None:
"""
Runs the migration.
If any dependencies are missing, a MigrationError is raised.
"""
missing_dependencies = [d.name for d in self.dependencies.values() if d.value is None]
if missing_dependencies:
raise MigrationError(f"Missing migration dependencies: {', '.join(missing_dependencies)}")
self.migrate_callback = partial(self.migrate_callback, **{d.name: d.value for d in self.dependencies.values()})
self.migrate_callback(cursor=cursor)
model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@ -20,8 +20,8 @@ class SQLiteMigrator:
```py
db = SqliteDatabase(db_path="my_db.db", logger=logger)
migrator = SQLiteMigrator(db=db)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2())
migrator.run_migrations()
```
"""
@ -76,7 +76,7 @@ class SQLiteMigrator:
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
# Run the actual migration
migration.run(cursor)
migration.callback(cursor)
# Update the version
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))

View File

@ -29,7 +29,7 @@ from invokeai.app.services.shared.graph import (
LibraryGraph,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from .test_invoker import create_edge
@ -47,10 +47,10 @@ def simple_graph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_sqlite_database(configuration, logger)
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices(

View File

@ -4,7 +4,7 @@ import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from tests.fixtures.sqlite_database import create_mock_sqlite_database
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
@ -51,10 +51,10 @@ def graph_with_subgraph():
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services(create_sqlite_database: CreateSqliteDatabaseFunction) -> InvocationServices:
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_sqlite_database(configuration, logger)
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")

View File

@ -20,7 +20,7 @@ from invokeai.app.services.model_install import (
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
@ -38,10 +38,10 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
@pytest.fixture
def store(
app_config: InvokeAIAppConfig, create_sqlite_database: CreateSqliteDatabaseFunction
app_config: InvokeAIAppConfig,
) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config)
db = create_sqlite_database(app_config, logger)
db = create_mock_sqlite_database(app_config, logger)
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store

View File

@ -23,14 +23,16 @@ from invokeai.backend.model_manager.config import (
VaeDiffusersConfig,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import CreateSqliteDatabaseFunction
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
def store(datadir: Any, create_sqlite_database: CreateSqliteDatabaseFunction) -> ModelRecordServiceBase:
def store(
datadir: Any,
) -> ModelRecordServiceBase:
config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config)
db = create_sqlite_database(config, logger)
db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db)

View File

@ -4,21 +4,13 @@ from unittest import mock
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
def create_sqlite_database(config: InvokeAIAppConfig, logger: Logger) -> SqliteDatabase:
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
def create_mock_sqlite_database(
config: InvokeAIAppConfig,
logger: Logger,
) -> SqliteDatabase:
image_files = mock.Mock(spec=ImageFileStorageBase)
migrator = SQLiteMigrator(db=db)
migration_2.provide_dependency("logger", logger)
migration_2.provide_dependency("image_files", image_files)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
db = init_db(config=config, logger=logger, image_files=image_files)
return db

View File

@ -1,5 +1,4 @@
import sqlite3
from abc import ABC, abstractmethod
from contextlib import closing
from logging import Logger
from pathlib import Path
@ -12,7 +11,6 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import (
MigrateCallback,
Migration,
MigrationDependency,
MigrationError,
MigrationSet,
MigrationVersionError,
@ -53,7 +51,7 @@ def no_op_migrate_callback() -> MigrateCallback:
@pytest.fixture
def migration_no_op(no_op_migrate_callback: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
return Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
@pytest.fixture
@ -75,7 +73,7 @@ def migrate_callback_create_test_table() -> MigrateCallback:
@pytest.fixture
def migration_create_test_table(migrate_callback_create_test_table: MigrateCallback) -> Migration:
return Migration(from_version=0, to_version=1, migrate_callback=migrate_callback_create_test_table)
return Migration(from_version=0, to_version=1, callback=migrate_callback_create_test_table)
@pytest.fixture
@ -83,7 +81,7 @@ def failing_migration() -> Migration:
def failing_migration(cursor: sqlite3.Cursor, **kwargs) -> None:
raise Exception("Bad migration")
return Migration(from_version=0, to_version=1, migrate_callback=failing_migration)
return Migration(from_version=0, to_version=1, callback=failing_migration)
@pytest.fixture
@ -101,40 +99,15 @@ def create_migrate(i: int) -> MigrateCallback:
return migrate
def test_migration_dependency_sets_value_primitive() -> None:
dependency = MigrationDependency(name="test_dependency", dependency_type=str)
dependency.set_value("test")
assert dependency.value == "test"
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*str"):
dependency.set_value(1)
def test_migration_dependency_sets_value_complex() -> None:
class SomeBase(ABC):
@abstractmethod
def some_method(self) -> None:
pass
class SomeImpl(SomeBase):
def some_method(self) -> None:
return
dependency = MigrationDependency(name="test_dependency", dependency_type=SomeBase)
with pytest.raises(TypeError, match=r"Dependency test_dependency must be of type.*SomeBase"):
dependency.set_value(1)
# not throwing is sufficient
dependency.set_value(SomeImpl())
def test_migration_to_version_is_one_gt_from_version(no_op_migrate_callback: MigrateCallback) -> None:
with pytest.raises(ValidationError, match="to_version must be one greater than from_version"):
Migration(from_version=0, to_version=2, migrate_callback=no_op_migrate_callback)
Migration(from_version=0, to_version=2, callback=no_op_migrate_callback)
# not raising is sufficient
Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
def test_migration_hash(no_op_migrate_callback: MigrateCallback) -> None:
migration = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migration = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
assert hash(migration) == hash((0, 1))
@ -147,13 +120,13 @@ def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op:
def test_migration_set_may_not_register_dupes(
migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback
) -> None:
migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback)
migrate_0_to_1_a = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
migrate_0_to_1_b = Migration(from_version=0, to_version=1, callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_0_to_1_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_0_to_1_b)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback)
migrate_1_to_2_a = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
migrate_1_to_2_b = Migration(from_version=1, to_version=2, callback=no_op_migrate_callback)
migrator._migration_set.register(migrate_1_to_2_a)
with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"):
migrator._migration_set.register(migrate_1_to_2_b)
@ -168,15 +141,15 @@ def test_migration_set_gets_migration(migration_no_op: Migration) -> None:
def test_migration_set_validates_migration_chain(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 0 to 1
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=2, to_version=3, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=2, to_version=3, callback=no_op_migrate_callback))
migration_set.validate_migration_chain()
migration_set.register(Migration(from_version=4, to_version=5, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=4, to_version=5, callback=no_op_migrate_callback))
with pytest.raises(MigrationError, match="Migration chain is fragmented"):
# no migration from 3 to 4
migration_set.validate_migration_chain()
@ -185,64 +158,32 @@ def test_migration_set_validates_migration_chain(no_op_migrate_callback: Migrate
def test_migration_set_counts_migrations(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
assert migration_set.count == 0
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
assert migration_set.count == 1
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
assert migration_set.count == 2
def test_migration_set_gets_latest_version(no_op_migrate_callback: MigrateCallback) -> None:
migration_set = MigrationSet()
assert migration_set.latest_version == 0
migration_set.register(Migration(from_version=1, to_version=2, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=1, to_version=2, callback=no_op_migrate_callback))
assert migration_set.latest_version == 2
migration_set.register(Migration(from_version=0, to_version=1, migrate_callback=no_op_migrate_callback))
migration_set.register(Migration(from_version=0, to_version=1, callback=no_op_migrate_callback))
assert migration_set.latest_version == 2
def test_migration_provide_dependency_validates_name(no_op_migrate_callback: MigrateCallback) -> None:
dependency = MigrationDependency(name="my_dependency", dependency_type=str)
def test_migration_runs(memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback) -> None:
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=no_op_migrate_callback,
dependencies={dependency.name: dependency},
callback=migrate_callback_create_test_table,
)
with pytest.raises(ValueError, match="is not a dependency of this migration"):
migration.provide_dependency("unknown_dependency_name", "banana_sushi")
def test_migration_runs_without_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_test_table: MigrateCallback
) -> None:
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_test_table,
)
migration.run(memory_db_cursor)
migration.callback(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert memory_db_cursor.fetchone() is not None
def test_migration_runs_with_dependencies(
memory_db_cursor: sqlite3.Cursor, migrate_callback_create_table_of_name: MigrateCallback
) -> None:
dependency = MigrationDependency(name="table_name", dependency_type=str)
migration = Migration(
from_version=0,
to_version=1,
migrate_callback=migrate_callback_create_table_of_name,
dependencies={dependency.name: dependency},
)
with pytest.raises(MigrationError, match="Missing migration dependencies"):
migration.run(memory_db_cursor)
migration.provide_dependency(dependency.name, "banana_sushi")
migration.run(memory_db_cursor)
memory_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='banana_sushi';")
assert memory_db_cursor.fetchone() is not None
def test_migrator_registers_migration(migrator: SQLiteMigrator, migration_no_op: Migration) -> None:
migration = migration_no_op
migrator.register_migration(migration)
@ -286,7 +227,7 @@ def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_crea
def test_migrator_runs_all_migrations_in_memory(migrator: SQLiteMigrator) -> None:
cursor = migrator._db.conn.cursor()
migrations = [Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)]
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
@ -299,9 +240,7 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
# The Migrator closes the database when it finishes; we cannot use a context manager.
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
migrator = SQLiteMigrator(db=db)
migrations = [
Migration(from_version=i, to_version=i + 1, migrate_callback=create_migrate(i)) for i in range(0, 3)
]
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
@ -319,7 +258,7 @@ def test_migrator_makes_no_changes_on_failed_migration(
migrator.register_migration(migration_no_op)
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1
migrator.register_migration(Migration(from_version=1, to_version=2, migrate_callback=failing_migrate_callback))
migrator.register_migration(Migration(from_version=1, to_version=2, callback=failing_migrate_callback))
with pytest.raises(MigrationError, match="Bad migration"):
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1