feat(db): incorporate feedback

This commit is contained in:
psychedelicious 2023-12-10 19:51:45 +11:00
parent c382329e8c
commit 0710ec30cf
6 changed files with 229 additions and 215 deletions

@ -70,9 +70,10 @@ class ApiDependencies:
logger.debug(f"Internet connectivity is {config.internet_available}")
output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images")
db = SqliteDatabase(config, logger)
migrator = SQLiteMigrator(conn=db.conn, database=db.database, lock=db.lock, logger=logger)
migrator = SQLiteMigrator(db=db, image_files=image_files)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
@ -87,7 +88,6 @@ class ApiDependencies:
events = FastAPIEventService(event_handler_id)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
image_files = DiskImageFileStorage(f"{output_folder}/images")
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)

@ -1,11 +1,14 @@
import sqlite3
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
"""Migration callback for database version 1."""
cursor = db.conn.cursor()
_create_board_images(cursor)
_create_boards(cursor)
_create_images(cursor)
@ -350,7 +353,11 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
cursor.execute(stmt)
# Version 1 migration: runs on a fresh (version 0) database and brings it to
# version 1. The flattened diff retained the superseded pre-image assignment
# (db_version/app_version kwargs), which the new Migration model rejects; only
# the post-image form is kept.
migration_1 = Migration(
    from_version=0,
    to_version=1,
    migrate=_migrate,
)
"""
Database version 1 (initial state).

@ -1,39 +1,55 @@
import sqlite3
from logging import Logger
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
    """Migration callback for database version 2.

    Runs each schema change on a cursor from the provided database, then scans
    existing image files for embedded workflows. The migrator owns the
    transaction; this callback must not commit.
    """
    cursor = db.conn.cursor()
    _add_images_has_workflow(cursor)
    _add_session_queue_workflow(cursor)
    _drop_old_workflow_tables(cursor)
    _add_workflow_library(cursor)
    _drop_model_manager_metadata(cursor)
    # NOTE(review): reaches into SqliteDatabase's private `_logger` attribute —
    # consider exposing a public logger accessor instead.
    _migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=db._logger)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table."""
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "has_workflow" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table."""
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in cursor.fetchall()]
if "workflow" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE workflow_images;")
cursor.execute("DROP TABLE workflows;")
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
cursor.execute("DROP TABLE IF EXISTS workflows;")
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [
"""--sql
CREATE TABLE workflow_library (
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@ -50,17 +66,17 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
]
indices = [
"CREATE INDEX idx_workflow_library_created_at ON workflow_library(created_at);",
"CREATE INDEX idx_workflow_library_updated_at ON workflow_library(updated_at);",
"CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);",
"CREATE INDEX idx_workflow_library_category ON workflow_library(category);",
"CREATE INDEX idx_workflow_library_name ON workflow_library(name);",
"CREATE INDEX idx_workflow_library_description ON workflow_library(description);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
]
triggers = [
"""--sql
CREATE TRIGGER tg_workflow_library_updated_at
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE
ON workflow_library FOR EACH ROW
BEGIN
@ -77,12 +93,43 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE model_manager_metadata;")
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _migrate_embedded_workflows(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
    """
    In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
    the database now has a `has_workflow` column, indicating if an image has a workflow embedded.

    This migration callback checks each image for the presence of an embedded workflow, then updates its
    entry in the database accordingly.

    :param cursor: An open cursor on the database being migrated.
    :param image_files: Storage used to load each image from disk.
    :param logger: Logger for progress messages.
    """
    # Gather every image name up front so we know the total for progress reporting.
    cursor.execute("SELECT image_name FROM images")
    image_names: list[str] = [image[0] for image in cursor.fetchall()]
    total_image_names = len(image_names)
    if not total_image_names:
        # Fresh/empty database — nothing to scan.
        return
    logger.info(f"Migrating workflows for {total_image_names} images")
    # Collect (has_workflow, image_name) pairs, then apply them in one batch.
    to_migrate: list[tuple[bool, str]] = []
    pbar = tqdm(image_names)
    for idx, image_name in enumerate(pbar):
        pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
        pil_image = image_files.get(image_name)
        # The embedded workflow, when present, lives in the image's metadata
        # under the "invokeai_workflow" key.
        if "invokeai_workflow" in pil_image.info:
            to_migrate.append((True, image_name))
    logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
    # Tuple order matches the two placeholders: (has_workflow, image_name).
    cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
# Version 2 migration: runs on a v1 database and brings it to v2. The flattened
# diff retained the superseded pre-image kwargs (db_version/app_version), which
# the new Migration model rejects; only the post-image form is kept.
migration_2 = Migration(
    from_version=1,
    to_version=2,
    migrate=_migrate,
)
"""
@ -90,8 +137,10 @@ Database version 2.
Introduced in v3.5.0 for the new workflow library.
Migration:
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
- Updates `has_workflow` for all images
"""

@ -8,13 +8,15 @@ from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
class SqliteDatabase:
database: Path | str
database: Path | str # Must declare this here to satisfy type checker
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config
self.is_memory = False
if self._config.use_memory_db:
self.database = sqlite_memory
self.is_memory = True
logger.info("Using in-memory database")
else:
self.database = self._config.db_path

@ -1,14 +1,15 @@
import shutil
import sqlite3
import threading
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Callable, Optional, TypeAlias
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
from pydantic import BaseModel, Field, model_validator
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
MigrateCallback: TypeAlias = Callable[[SqliteDatabase, ImageFileStorageBase], None]
class MigrationError(Exception):
@ -19,146 +20,186 @@ class MigrationVersionError(ValueError, MigrationError):
"""Raised when a migration version is invalid."""
class Migration:
"""Represents a migration for a SQLite database.
class Migration(BaseModel):
"""
Represents a migration for a SQLite database.
:param db_version: The database schema version this migration results in.
:param app_version: The app version this migration is introduced in.
:param migrate: The callback to run to perform the migration. The callback will be passed a
cursor to the database. The migrator will manage locking database access and committing the
transaction; the callback should not do either of these things.
Migration callbacks will be provided an instance of SqliteDatabase.
Migration callbacks should not commit; the migrator will commit the transaction.
"""
def __init__(
self,
db_version: int,
app_version: str,
migrate: MigrateCallback,
) -> None:
self.db_version = db_version
self.app_version = app_version
self.migrate = migrate
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
pre_migrate: list[MigrateCallback] = Field(
default=[], description="A list of callbacks to run before the migration"
)
post_migrate: list[MigrateCallback] = Field(
default=[], description="A list of callbacks to run after the migration"
)
@model_validator(mode="after")
def validate_to_version(self) -> "Migration":
if self.to_version <= self.from_version:
raise ValueError("to_version must be greater than from_version")
return self
def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))
class MigrationSet:
    """A set of Migrations. Performs validation during migration registration and provides utility methods."""

    def __init__(self) -> None:
        # Migrations hash on (from_version, to_version), so a set rejects exact duplicates;
        # register() additionally rejects any overlap on either version.
        self._migrations: set[Migration] = set()

    def register(self, migration: Migration) -> None:
        """Registers a migration.

        :param migration: The migration to add to the set.
        :raises MigrationVersionError: if another registered migration already uses this
            migration's from_version or to_version — each version may appear exactly once,
            which keeps the migration chain unambiguous.
        """
        if any(m.from_version == migration.from_version for m in self._migrations):
            raise MigrationVersionError(f"Migration from {migration.from_version} already registered")
        if any(m.to_version == migration.to_version for m in self._migrations):
            raise MigrationVersionError(f"Migration to {migration.to_version} already registered")
        self._migrations.add(migration)

    def get(self, from_version: int) -> Optional[Migration]:
        """Gets the migration that may be run on the given database version, or None if there is none."""
        # register() ensures that there is only one migration with a given from_version, so this is safe.
        return next((m for m in self._migrations if m.from_version == from_version), None)

    @property
    def count(self) -> int:
        """The count of registered migrations."""
        return len(self._migrations)

    @property
    def latest_version(self) -> int:
        """Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
        if self.count == 0:
            return 0
        # Sort ascending by to_version; the last element is the highest reachable version.
        return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
class SQLiteMigrator:
"""
Manages migrations for a SQLite database.
:param conn: The database connection.
:param database: The path to the database file, or ":memory:" for an in-memory database.
:param lock: A lock to use when accessing the database.
:param logger: The logger to use.
:param db: The SqliteDatabase, representing the database on which to run migrations.
:param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files.
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
order of their version number. If the database is already at the latest version, no migrations
will be run.
"""
def __init__(self, conn: sqlite3.Connection, database: Path | str, lock: threading.RLock, logger: Logger) -> None:
self._logger = logger
self._conn = conn
self._cursor = self._conn.cursor()
self._lock = lock
self._database = database
self._migrations: set[Migration] = set()
backup_path: Optional[Path] = None
def __init__(self, db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
self._image_files = image_files
self._db = db
self._logger = self._db._logger
self._config = self._db._config
self._cursor = self._db.conn.cursor()
self._migrations = MigrationSet()
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
if not isinstance(migration.db_version, int) or migration.db_version < 1:
raise MigrationVersionError(f"Invalid migration version {migration.db_version}")
if any(m.db_version == migration.db_version for m in self._migrations):
raise MigrationVersionError(f"Migration version {migration.db_version} already registered")
self._migrations.add(migration)
self._logger.debug(f"Registered migration {migration.db_version}")
"""
Registers a migration.
Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
"""
self._migrations.register(migration)
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
def run_migrations(self) -> None:
"""Migrates the database to the latest version."""
with self._lock:
with self._db.lock:
self._create_version_table()
sorted_migrations = sorted(self._migrations, key=lambda m: m.db_version)
current_version = self._get_current_version()
if len(sorted_migrations) == 0:
if self._migrations.count == 0:
self._logger.debug("No migrations registered")
return
latest_version = sorted_migrations[-1].db_version
latest_version = self._migrations.latest_version
if current_version == latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return
if current_version > latest_version:
raise MigrationError(
f"Database version {current_version} is greater than the latest migration version {latest_version}"
)
self._logger.info("Database update needed")
# Only make a backup if using a file database (not memory)
backup_path: Optional[Path] = None
if isinstance(self._database, Path):
backup_path = self._backup_db(self._database)
else:
self._logger.info("Using in-memory database, skipping backup")
self._backup_db()
for migration in sorted_migrations:
next_migration = self._migrations.get(from_version=current_version)
while next_migration is not None:
try:
self._run_migration(migration)
self._run_migration(next_migration)
next_migration = self._migrations.get(self._get_current_version())
except MigrationError:
if backup_path is not None:
self._logger.error(f" Restoring from {backup_path}")
self._restore_db(backup_path)
self._restore_db()
raise
self._logger.info("Database updated successfully")
def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration."""
with self._lock:
current_version = self._get_current_version()
with self._db.lock:
try:
if current_version >= migration.db_version:
return
migration.migrate(self._cursor)
if self._get_current_version() != migration.from_version:
raise MigrationError(
f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
)
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate:
callback(self._db, self._image_files)
migration.migrate(self._db, self._image_files)
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate:
callback(self._db, self._image_files)
# Migration callbacks only get a cursor; they cannot commit the transaction.
self._conn.commit()
self._set_version(db_version=migration.db_version, app_version=migration.app_version)
self._logger.debug(f"Successfully migrated database from {current_version} to {migration.db_version}")
self._db.conn.commit()
self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
)
except Exception as e:
msg = f"Error migrating database from {current_version} to {migration.db_version}: {e}"
self._conn.rollback()
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
self._db.conn.rollback()
self._logger.error(msg)
raise MigrationError(msg) from e
def _create_version_table(self) -> None:
"""Creates a version table for the database, if one does not already exist."""
with self._lock:
with self._db.lock:
try:
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='version';")
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if self._cursor.fetchone() is not None:
return
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS version (
db_version INTEGER PRIMARY KEY,
app_version TEXT NOT NULL,
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
self._cursor.execute("INSERT INTO version (db_version, app_version) VALUES (?,?);", (0, "0.0.0"))
self._conn.commit()
self._logger.debug("Created version table")
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
self._db.conn.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creation version table: {e}"
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
self._conn.rollback()
self._db.conn.rollback()
raise MigrationError(msg) from e
def _get_current_version(self) -> int:
"""Gets the current version of the database, or 0 if the version table does not exist."""
with self._lock:
with self._db.lock:
try:
self._cursor.execute("SELECT MAX(db_version) FROM version;")
self._cursor.execute("SELECT MAX(version) FROM migrations;")
version = self._cursor.fetchone()[0]
if version is None:
return 0
@ -168,43 +209,42 @@ class SQLiteMigrator:
return 0
raise
def _set_version(self, db_version: int, app_version: str) -> None:
"""Adds a version entry to the table's version table."""
with self._lock:
try:
self._cursor.execute(
"INSERT INTO version (db_version, app_version) VALUES (?,?);", (db_version, app_version)
)
self._conn.commit()
except sqlite3.Error as e:
msg = f"Problem setting database version: {e}"
self._logger.error(msg)
self._conn.rollback()
raise MigrationError(msg) from e
def _backup_db(self, db_path: Path | str) -> Path:
def _backup_db(self) -> None:
"""Backs up the databse, returning the path to the backup file."""
if db_path == sqlite_memory:
raise MigrationError("Cannot back up memory database")
if not isinstance(db_path, Path):
raise MigrationError(f'Database path must be "{sqlite_memory}" or a Path')
with self._lock:
if self._db.is_memory:
self._logger.debug("Using memory database, skipping backup")
# Sanity check!
if not isinstance(self._db.database, Path):
raise MigrationError(f"Database path must be a Path, got {self._db.database} ({type(self._db.database)})")
with self._db.lock:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db"
backup_path = self._db.database.parent / f"{self._db.database.stem}_{timestamp}.db"
self._logger.info(f"Backing up database to {backup_path}")
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
backup_conn = sqlite3.connect(backup_path)
with backup_conn:
self._conn.backup(backup_conn)
self._db.conn.backup(backup_conn)
backup_conn.close()
return backup_path
def _restore_db(self, backup_path: Path) -> None:
# Sanity check!
if not backup_path.is_file():
raise MigrationError("Unable to back up database")
self.backup_path = backup_path
def _restore_db(
self,
) -> None:
"""Restores the database from a backup file, unless the database is a memory database."""
if self._database == sqlite_memory:
# We don't need to restore a memory database.
if self._db.is_memory:
return
with self._lock:
self._logger.info(f"Restoring database from {backup_path}")
self._conn.close()
if not Path(backup_path).is_file():
raise FileNotFoundError(f"Backup file {backup_path} does not exist")
shutil.copy2(backup_path, self._database)
with self._db.lock:
self._logger.info(f"Restoring database from {self.backup_path}")
self._db.conn.close()
if self.backup_path is None:
raise FileNotFoundError("No backup path set")
if not Path(self.backup_path).is_file():
raise FileNotFoundError(f"Backup file {self.backup_path} does not exist")
shutil.copy2(self.backup_path, self._db.database)

@ -1,84 +0,0 @@
import sqlite3
from datetime import datetime
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from invokeai.app.services.config.config_default import InvokeAIAppConfig
def migrate_image_workflows(output_path: Path, database: Path, page_size: int = 100) -> None:
    """
    In the v3.5.0 release, InvokeAI changed how it handles image workflows. The `images` table in
    the database now has a `has_workflow` column, indicating if an image has a workflow embedded.

    This script checks each image for the presence of an embedded workflow, then updates its entry
    in the database accordingly.

    1) Check if the database is updated to support image workflows. Aborts if it doesn't have the
    `has_workflow` column yet.
    2) Backs up the database.
    3) Opens each image in the `images` table via PIL
    4) Checks if the `"invokeai_workflow"` attribute is in the image's embedded metadata, indicating
    that it has a workflow.
    5) If it does, updates the `has_workflow` column for that image to `TRUE`.

    If there are any problems, the script immediately aborts. Because the processing happens in chunks,
    if there is a problem, it is suggested that you restore the database from the backup and try again.

    :param output_path: The InvokeAI output directory; images live under `output_path / "images"`.
    :param database: Path to the SQLite database file.
    :param page_size: Number of image rows fetched and processed per chunk.
    """
    # NOTE(review): these two self-assignments are no-ops left over from a refactor
    # and could be removed.
    output_path = output_path
    database = database
    conn = sqlite3.connect(database)
    cursor = conn.cursor()
    # We can only migrate if the `images` table has the `has_workflow` column
    cursor.execute("PRAGMA table_info(images)")
    columns = [column[1] for column in cursor.fetchall()]
    if "has_workflow" not in columns:
        raise Exception("Database needs to be updated to support image workflows")
    # Back up the database before we start. SQLite's backup API produces a
    # consistent copy even with the connection open.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = database.parent / f"{database.stem}_migrate-image-workflows_{timestamp}.db"
    print(f"Backing up database to {backup_path}")
    backup_conn = sqlite3.connect(backup_path)
    with backup_conn:
        conn.backup(backup_conn)
    backup_conn.close()
    # Get the total number of images and chunk it into pages
    cursor.execute("SELECT COUNT(*) FROM images")
    total_images = cursor.fetchone()[0]
    total_pages = (total_images + page_size - 1) // page_size  # ceiling division
    print(f"Processing {total_images} images in chunks of {page_size} images...")
    # Migrate the images
    migrated_count = 0
    pbar = tqdm(range(total_pages))
    for page in pbar:
        pbar.set_description(f"Migrating page {page + 1}/{total_pages}")
        offset = page * page_size
        cursor.execute("SELECT image_name FROM images LIMIT ? OFFSET ?", (page_size, offset))
        images = cursor.fetchall()
        for image_name in images:
            # Each fetched row is a 1-tuple; index 0 is the image's file name.
            image_path = output_path / "images" / image_name[0]
            with Image.open(image_path) as img:
                if "invokeai_workflow" in img.info:
                    cursor.execute("UPDATE images SET has_workflow = TRUE WHERE image_name = ?", (image_name[0],))
                    migrated_count += 1
        # Commit after each page so completed chunks survive a mid-run failure.
        conn.commit()
    conn.close()
    print(f"Migrated workflows for {migrated_count} images.")
if __name__ == "__main__":
config = InvokeAIAppConfig.get_config()
output_path = config.output_path
database = config.db_path
assert output_path is not None
assert output_path.exists()
assert database.exists()
migrate_image_workflows(output_path=output_path, database=database, page_size=100)