feat(db): incorporate feedback

This commit is contained in:
psychedelicious
2023-12-10 19:51:45 +11:00
parent c382329e8c
commit 0710ec30cf
6 changed files with 229 additions and 215 deletions

View File

@ -70,9 +70,10 @@ class ApiDependencies:
logger.debug(f"Internet connectivity is {config.internet_available}")
output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images")
db = SqliteDatabase(config, logger)
migrator = SQLiteMigrator(conn=db.conn, database=db.database, lock=db.lock, logger=logger)
migrator = SQLiteMigrator(db=db, image_files=image_files)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
@ -87,7 +88,6 @@ class ApiDependencies:
events = FastAPIEventService(event_handler_id)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
image_files = DiskImageFileStorage(f"{output_folder}/images")
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)

View File

@ -1,11 +1,14 @@
import sqlite3
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
"""Migration callback for database version 1."""
cursor = db.conn.cursor()
_create_board_images(cursor)
_create_boards(cursor)
_create_images(cursor)
@ -350,7 +353,11 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
cursor.execute(stmt)
migration_1 = Migration(db_version=1, app_version="3.4.0", migrate=_migrate)
migration_1 = Migration(
from_version=0,
to_version=1,
migrate=_migrate,
)
"""
Database version 1 (initial state).

View File

@ -1,39 +1,55 @@
import sqlite3
from logging import Logger
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
"""Migration callback for database version 2."""
cursor = db.conn.cursor()
_add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor)
_add_workflow_library(cursor)
_drop_model_manager_metadata(cursor)
_migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=db._logger)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table."""
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "has_workflow" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table."""
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in cursor.fetchall()]
if "workflow" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE workflow_images;")
cursor.execute("DROP TABLE workflows;")
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
cursor.execute("DROP TABLE IF EXISTS workflows;")
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [
"""--sql
CREATE TABLE workflow_library (
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@ -50,17 +66,17 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
]
indices = [
"CREATE INDEX idx_workflow_library_created_at ON workflow_library(created_at);",
"CREATE INDEX idx_workflow_library_updated_at ON workflow_library(updated_at);",
"CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);",
"CREATE INDEX idx_workflow_library_category ON workflow_library(category);",
"CREATE INDEX idx_workflow_library_name ON workflow_library(name);",
"CREATE INDEX idx_workflow_library_description ON workflow_library(description);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
]
triggers = [
"""--sql
CREATE TRIGGER tg_workflow_library_updated_at
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE
ON workflow_library FOR EACH ROW
BEGIN
@ -77,12 +93,43 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE model_manager_metadata;")
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _migrate_embedded_workflows(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
"""
# Get the total number of images and chunk it into pages
cursor.execute("SELECT image_name FROM images")
image_names: list[str] = [image[0] for image in cursor.fetchall()]
total_image_names = len(image_names)
if not total_image_names:
return
logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name)
if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
migration_2 = Migration(
db_version=2,
app_version="3.5.0",
from_version=1,
to_version=2,
migrate=_migrate,
)
"""
@ -90,8 +137,10 @@ Database version 2.
Introduced in v3.5.0 for the new workflow library.
Migration:
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
- Updates `has_workflow` for all images
"""

View File

@ -8,13 +8,15 @@ from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
class SqliteDatabase:
database: Path | str
database: Path | str # Must declare this here to satisfy type checker
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config
self.is_memory = False
if self._config.use_memory_db:
self.database = sqlite_memory
self.is_memory = True
logger.info("Using in-memory database")
else:
self.database = self._config.db_path

View File

@ -1,14 +1,15 @@
import shutil
import sqlite3
import threading
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Callable, Optional, TypeAlias
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
from pydantic import BaseModel, Field, model_validator
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
MigrateCallback: TypeAlias = Callable[[SqliteDatabase, ImageFileStorageBase], None]
class MigrationError(Exception):
@ -19,146 +20,186 @@ class MigrationVersionError(ValueError, MigrationError):
"""Raised when a migration version is invalid."""
class Migration:
"""Represents a migration for a SQLite database.
class Migration(BaseModel):
"""
Represents a migration for a SQLite database.
:param db_version: The database schema version this migration results in.
:param app_version: The app version this migration is introduced in.
:param migrate: The callback to run to perform the migration. The callback will be passed a
cursor to the database. The migrator will manage locking database access and committing the
transaction; the callback should not do either of these things.
Migration callbacks will be provided an instance of SqliteDatabase.
Migration callbacks should not commit; the migrator will commit the transaction.
"""
def __init__(
self,
db_version: int,
app_version: str,
migrate: MigrateCallback,
) -> None:
self.db_version = db_version
self.app_version = app_version
self.migrate = migrate
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
pre_migrate: list[MigrateCallback] = Field(
default=[], description="A list of callbacks to run before the migration"
)
post_migrate: list[MigrateCallback] = Field(
default=[], description="A list of callbacks to run after the migration"
)
@model_validator(mode="after")
def validate_to_version(self) -> "Migration":
if self.to_version <= self.from_version:
raise ValueError("to_version must be greater than from_version")
return self
def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))
class MigrationSet:
"""A set of Migrations. Performs validation during migration registration and provides utility methods."""
def __init__(self) -> None:
self._migrations: set[Migration] = set()
def register(self, migration: Migration) -> None:
"""Registers a migration."""
if any(m.from_version == migration.from_version for m in self._migrations):
raise MigrationVersionError(f"Migration from {migration.from_version} already registered")
if any(m.to_version == migration.to_version for m in self._migrations):
raise MigrationVersionError(f"Migration to {migration.to_version} already registered")
self._migrations.add(migration)
def get(self, from_version: int) -> Optional[Migration]:
"""Gets the migration that may be run on the given database version."""
# register() ensures that there is only one migration with a given from_version, so this is safe.
return next((m for m in self._migrations if m.from_version == from_version), None)
@property
def count(self) -> int:
"""The count of registered migrations."""
return len(self._migrations)
@property
def latest_version(self) -> int:
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
if self.count == 0:
return 0
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
class SQLiteMigrator:
"""
Manages migrations for a SQLite database.
:param conn: The database connection.
:param database: The path to the database file, or ":memory:" for an in-memory database.
:param lock: A lock to use when accessing the database.
:param logger: The logger to use.
:param db: The SqliteDatabase, representing the database on which to run migrations.
:param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files.
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
order of their version number. If the database is already at the latest version, no migrations
will be run.
"""
def __init__(self, conn: sqlite3.Connection, database: Path | str, lock: threading.RLock, logger: Logger) -> None:
self._logger = logger
self._conn = conn
self._cursor = self._conn.cursor()
self._lock = lock
self._database = database
self._migrations: set[Migration] = set()
backup_path: Optional[Path] = None
def __init__(self, db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
self._image_files = image_files
self._db = db
self._logger = self._db._logger
self._config = self._db._config
self._cursor = self._db.conn.cursor()
self._migrations = MigrationSet()
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
if not isinstance(migration.db_version, int) or migration.db_version < 1:
raise MigrationVersionError(f"Invalid migration version {migration.db_version}")
if any(m.db_version == migration.db_version for m in self._migrations):
raise MigrationVersionError(f"Migration version {migration.db_version} already registered")
self._migrations.add(migration)
self._logger.debug(f"Registered migration {migration.db_version}")
"""
Registers a migration.
Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
"""
self._migrations.register(migration)
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
def run_migrations(self) -> None:
"""Migrates the database to the latest version."""
with self._lock:
with self._db.lock:
self._create_version_table()
sorted_migrations = sorted(self._migrations, key=lambda m: m.db_version)
current_version = self._get_current_version()
if len(sorted_migrations) == 0:
if self._migrations.count == 0:
self._logger.debug("No migrations registered")
return
latest_version = sorted_migrations[-1].db_version
latest_version = self._migrations.latest_version
if current_version == latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return
if current_version > latest_version:
raise MigrationError(
f"Database version {current_version} is greater than the latest migration version {latest_version}"
)
self._logger.info("Database update needed")
# Only make a backup if using a file database (not memory)
backup_path: Optional[Path] = None
if isinstance(self._database, Path):
backup_path = self._backup_db(self._database)
else:
self._logger.info("Using in-memory database, skipping backup")
self._backup_db()
for migration in sorted_migrations:
next_migration = self._migrations.get(from_version=current_version)
while next_migration is not None:
try:
self._run_migration(migration)
self._run_migration(next_migration)
next_migration = self._migrations.get(self._get_current_version())
except MigrationError:
if backup_path is not None:
self._logger.error(f" Restoring from {backup_path}")
self._restore_db(backup_path)
self._restore_db()
raise
self._logger.info("Database updated successfully")
def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration."""
with self._lock:
current_version = self._get_current_version()
with self._db.lock:
try:
if current_version >= migration.db_version:
return
migration.migrate(self._cursor)
if self._get_current_version() != migration.from_version:
raise MigrationError(
f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
)
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate:
callback(self._db, self._image_files)
migration.migrate(self._db, self._image_files)
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate:
callback(self._db, self._image_files)
# Migration callbacks only get a cursor; they cannot commit the transaction.
self._conn.commit()
self._set_version(db_version=migration.db_version, app_version=migration.app_version)
self._logger.debug(f"Successfully migrated database from {current_version} to {migration.db_version}")
self._db.conn.commit()
self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
)
except Exception as e:
msg = f"Error migrating database from {current_version} to {migration.db_version}: {e}"
self._conn.rollback()
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
self._db.conn.rollback()
self._logger.error(msg)
raise MigrationError(msg) from e
def _create_version_table(self) -> None:
"""Creates a version table for the database, if one does not already exist."""
with self._lock:
with self._db.lock:
try:
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='version';")
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if self._cursor.fetchone() is not None:
return
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS version (
db_version INTEGER PRIMARY KEY,
app_version TEXT NOT NULL,
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
self._cursor.execute("INSERT INTO version (db_version, app_version) VALUES (?,?);", (0, "0.0.0"))
self._conn.commit()
self._logger.debug("Created version table")
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
self._db.conn.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creation version table: {e}"
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
self._conn.rollback()
self._db.conn.rollback()
raise MigrationError(msg) from e
def _get_current_version(self) -> int:
"""Gets the current version of the database, or 0 if the version table does not exist."""
with self._lock:
with self._db.lock:
try:
self._cursor.execute("SELECT MAX(db_version) FROM version;")
self._cursor.execute("SELECT MAX(version) FROM migrations;")
version = self._cursor.fetchone()[0]
if version is None:
return 0
@ -168,43 +209,42 @@ class SQLiteMigrator:
return 0
raise
def _set_version(self, db_version: int, app_version: str) -> None:
"""Adds a version entry to the table's version table."""
with self._lock:
try:
self._cursor.execute(
"INSERT INTO version (db_version, app_version) VALUES (?,?);", (db_version, app_version)
)
self._conn.commit()
except sqlite3.Error as e:
msg = f"Problem setting database version: {e}"
self._logger.error(msg)
self._conn.rollback()
raise MigrationError(msg) from e
def _backup_db(self, db_path: Path | str) -> Path:
def _backup_db(self) -> None:
"""Backs up the databse, returning the path to the backup file."""
if db_path == sqlite_memory:
raise MigrationError("Cannot back up memory database")
if not isinstance(db_path, Path):
raise MigrationError(f'Database path must be "{sqlite_memory}" or a Path')
with self._lock:
if self._db.is_memory:
self._logger.debug("Using memory database, skipping backup")
# Sanity check!
if not isinstance(self._db.database, Path):
raise MigrationError(f"Database path must be a Path, got {self._db.database} ({type(self._db.database)})")
with self._db.lock:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db"
backup_path = self._db.database.parent / f"{self._db.database.stem}_{timestamp}.db"
self._logger.info(f"Backing up database to {backup_path}")
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
backup_conn = sqlite3.connect(backup_path)
with backup_conn:
self._conn.backup(backup_conn)
self._db.conn.backup(backup_conn)
backup_conn.close()
return backup_path
def _restore_db(self, backup_path: Path) -> None:
# Sanity check!
if not backup_path.is_file():
raise MigrationError("Unable to back up database")
self.backup_path = backup_path
def _restore_db(
self,
) -> None:
"""Restores the database from a backup file, unless the database is a memory database."""
if self._database == sqlite_memory:
# We don't need to restore a memory database.
if self._db.is_memory:
return
with self._lock:
self._logger.info(f"Restoring database from {backup_path}")
self._conn.close()
if not Path(backup_path).is_file():
raise FileNotFoundError(f"Backup file {backup_path} does not exist")
shutil.copy2(backup_path, self._database)
with self._db.lock:
self._logger.info(f"Restoring database from {self.backup_path}")
self._db.conn.close()
if self.backup_path is None:
raise FileNotFoundError("No backup path set")
if not Path(self.backup_path).is_file():
raise FileNotFoundError(f"Backup file {self.backup_path} does not exist")
shutil.copy2(self.backup_path, self._db.database)