feat(db): incorporate feedback

This commit is contained in:
psychedelicious 2023-12-10 19:51:45 +11:00
parent c382329e8c
commit 0710ec30cf
6 changed files with 229 additions and 215 deletions

View File

@ -70,9 +70,10 @@ class ApiDependencies:
logger.debug(f"Internet connectivity is {config.internet_available}") logger.debug(f"Internet connectivity is {config.internet_available}")
output_folder = config.output_path output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images")
db = SqliteDatabase(config, logger) db = SqliteDatabase(config, logger)
migrator = SQLiteMigrator(conn=db.conn, database=db.database, lock=db.lock, logger=logger) migrator = SQLiteMigrator(db=db, image_files=image_files)
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()
@ -87,7 +88,6 @@ class ApiDependencies:
events = FastAPIEventService(event_handler_id) events = FastAPIEventService(event_handler_id)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs") graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
image_files = DiskImageFileStorage(f"{output_folder}/images")
image_records = SqliteImageRecordStorage(db=db) image_records = SqliteImageRecordStorage(db=db)
images = ImageService() images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)

View File

@ -1,11 +1,14 @@
import sqlite3 import sqlite3
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None: def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
"""Migration callback for database version 1.""" """Migration callback for database version 1."""
cursor = db.conn.cursor()
_create_board_images(cursor) _create_board_images(cursor)
_create_boards(cursor) _create_boards(cursor)
_create_images(cursor) _create_images(cursor)
@ -350,7 +353,11 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
cursor.execute(stmt) cursor.execute(stmt)
migration_1 = Migration(db_version=1, app_version="3.4.0", migrate=_migrate) migration_1 = Migration(
from_version=0,
to_version=1,
migrate=_migrate,
)
""" """
Database version 1 (initial state). Database version 1 (initial state).

View File

@ -1,39 +1,55 @@
import sqlite3 import sqlite3
from logging import Logger
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None: def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
"""Migration callback for database version 2.""" """Migration callback for database version 2."""
cursor = db.conn.cursor()
_add_images_has_workflow(cursor) _add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor) _add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor) _drop_old_workflow_tables(cursor)
_add_workflow_library(cursor) _add_workflow_library(cursor)
_drop_model_manager_metadata(cursor) _drop_model_manager_metadata(cursor)
_migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=db._logger)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None: def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table.""" """Add the `has_workflow` column to `images` table."""
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;") cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "has_workflow" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None: def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table.""" """Add the `workflow` column to `session_queue` table."""
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in cursor.fetchall()]
if "workflow" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None: def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables.""" """Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE workflow_images;") cursor.execute("DROP TABLE IF EXISTS workflow_images;")
cursor.execute("DROP TABLE workflows;") cursor.execute("DROP TABLE IF EXISTS workflows;")
def _add_workflow_library(cursor: sqlite3.Cursor) -> None: def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables.""" """Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [ tables = [
"""--sql """--sql
CREATE TABLE workflow_library ( CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY, workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL, workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@ -50,17 +66,17 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
] ]
indices = [ indices = [
"CREATE INDEX idx_workflow_library_created_at ON workflow_library(created_at);", "CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
"CREATE INDEX idx_workflow_library_updated_at ON workflow_library(updated_at);", "CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
"CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);", "CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
"CREATE INDEX idx_workflow_library_category ON workflow_library(category);", "CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
"CREATE INDEX idx_workflow_library_name ON workflow_library(name);", "CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
"CREATE INDEX idx_workflow_library_description ON workflow_library(description);", "CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
] ]
triggers = [ triggers = [
"""--sql """--sql
CREATE TRIGGER tg_workflow_library_updated_at CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE AFTER UPDATE
ON workflow_library FOR EACH ROW ON workflow_library FOR EACH ROW
BEGIN BEGIN
@ -77,12 +93,43 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None: def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table.""" """Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE model_manager_metadata;") cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _migrate_embedded_workflows(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
"""
# Get the total number of images and chunk it into pages
cursor.execute("SELECT image_name FROM images")
image_names: list[str] = [image[0] for image in cursor.fetchall()]
total_image_names = len(image_names)
if not total_image_names:
return
logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
pil_image = image_files.get(image_name)
if "invokeai_workflow" in pil_image.info:
to_migrate.append((True, image_name))
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
migration_2 = Migration( migration_2 = Migration(
db_version=2, from_version=1,
app_version="3.5.0", to_version=2,
migrate=_migrate, migrate=_migrate,
) )
""" """
@ -90,8 +137,10 @@ Database version 2.
Introduced in v3.5.0 for the new workflow library. Introduced in v3.5.0 for the new workflow library.
Migration:
- Add `has_workflow` column to `images` table - Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table - Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables - Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table - Add `workflow_library` table
- Updates `has_workflow` for all images
""" """

View File

@ -8,13 +8,15 @@ from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
class SqliteDatabase: class SqliteDatabase:
database: Path | str database: Path | str # Must declare this here to satisfy type checker
def __init__(self, config: InvokeAIAppConfig, logger: Logger): def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger self._logger = logger
self._config = config self._config = config
self.is_memory = False
if self._config.use_memory_db: if self._config.use_memory_db:
self.database = sqlite_memory self.database = sqlite_memory
self.is_memory = True
logger.info("Using in-memory database") logger.info("Using in-memory database")
else: else:
self.database = self._config.db_path self.database = self._config.db_path

View File

@ -1,14 +1,15 @@
import shutil import shutil
import sqlite3 import sqlite3
import threading
from datetime import datetime from datetime import datetime
from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, TypeAlias from typing import Callable, Optional, TypeAlias
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory from pydantic import BaseModel, Field, model_validator
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None] from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
MigrateCallback: TypeAlias = Callable[[SqliteDatabase, ImageFileStorageBase], None]
class MigrationError(Exception): class MigrationError(Exception):
@ -19,146 +20,186 @@ class MigrationVersionError(ValueError, MigrationError):
"""Raised when a migration version is invalid.""" """Raised when a migration version is invalid."""
class Migration: class Migration(BaseModel):
"""Represents a migration for a SQLite database. """
Represents a migration for a SQLite database.
:param db_version: The database schema version this migration results in. Migration callbacks will be provided an instance of SqliteDatabase.
:param app_version: The app version this migration is introduced in. Migration callbacks should not commit; the migrator will commit the transaction.
:param migrate: The callback to run to perform the migration. The callback will be passed a
cursor to the database. The migrator will manage locking database access and committing the
transaction; the callback should not do either of these things.
""" """
def __init__( from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
self, to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
db_version: int, migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
app_version: str, pre_migrate: list[MigrateCallback] = Field(
migrate: MigrateCallback, default=[], description="A list of callbacks to run before the migration"
) -> None: )
self.db_version = db_version post_migrate: list[MigrateCallback] = Field(
self.app_version = app_version default=[], description="A list of callbacks to run after the migration"
self.migrate = migrate )
@model_validator(mode="after")
def validate_to_version(self) -> "Migration":
if self.to_version <= self.from_version:
raise ValueError("to_version must be greater than from_version")
return self
def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))
class MigrationSet:
"""A set of Migrations. Performs validation during migration registration and provides utility methods."""
def __init__(self) -> None:
self._migrations: set[Migration] = set()
def register(self, migration: Migration) -> None:
"""Registers a migration."""
if any(m.from_version == migration.from_version for m in self._migrations):
raise MigrationVersionError(f"Migration from {migration.from_version} already registered")
if any(m.to_version == migration.to_version for m in self._migrations):
raise MigrationVersionError(f"Migration to {migration.to_version} already registered")
self._migrations.add(migration)
def get(self, from_version: int) -> Optional[Migration]:
"""Gets the migration that may be run on the given database version."""
# register() ensures that there is only one migration with a given from_version, so this is safe.
return next((m for m in self._migrations if m.from_version == from_version), None)
@property
def count(self) -> int:
"""The count of registered migrations."""
return len(self._migrations)
@property
def latest_version(self) -> int:
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
if self.count == 0:
return 0
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
class SQLiteMigrator: class SQLiteMigrator:
""" """
Manages migrations for a SQLite database. Manages migrations for a SQLite database.
:param conn: The database connection. :param db: The SqliteDatabase, representing the database on which to run migrations.
:param database: The path to the database file, or ":memory:" for an in-memory database. :param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files.
:param lock: A lock to use when accessing the database.
:param logger: The logger to use.
Migrations should be registered with :meth:`register_migration`. Migrations will be run in Migrations should be registered with :meth:`register_migration`. Migrations will be run in
order of their version number. If the database is already at the latest version, no migrations order of their version number. If the database is already at the latest version, no migrations
will be run. will be run.
""" """
def __init__(self, conn: sqlite3.Connection, database: Path | str, lock: threading.RLock, logger: Logger) -> None: backup_path: Optional[Path] = None
self._logger = logger
self._conn = conn def __init__(self, db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
self._cursor = self._conn.cursor() self._image_files = image_files
self._lock = lock self._db = db
self._database = database self._logger = self._db._logger
self._migrations: set[Migration] = set() self._config = self._db._config
self._cursor = self._db.conn.cursor()
self._migrations = MigrationSet()
def register_migration(self, migration: Migration) -> None: def register_migration(self, migration: Migration) -> None:
"""Registers a migration.""" """
if not isinstance(migration.db_version, int) or migration.db_version < 1: Registers a migration.
raise MigrationVersionError(f"Invalid migration version {migration.db_version}") Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
if any(m.db_version == migration.db_version for m in self._migrations): """
raise MigrationVersionError(f"Migration version {migration.db_version} already registered") self._migrations.register(migration)
self._migrations.add(migration) self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
self._logger.debug(f"Registered migration {migration.db_version}")
def run_migrations(self) -> None: def run_migrations(self) -> None:
"""Migrates the database to the latest version.""" """Migrates the database to the latest version."""
with self._lock: with self._db.lock:
self._create_version_table() self._create_version_table()
sorted_migrations = sorted(self._migrations, key=lambda m: m.db_version)
current_version = self._get_current_version() current_version = self._get_current_version()
if len(sorted_migrations) == 0: if self._migrations.count == 0:
self._logger.debug("No migrations registered") self._logger.debug("No migrations registered")
return return
latest_version = sorted_migrations[-1].db_version latest_version = self._migrations.latest_version
if current_version == latest_version: if current_version == latest_version:
self._logger.debug("Database is up to date, no migrations to run") self._logger.debug("Database is up to date, no migrations to run")
return return
if current_version > latest_version:
raise MigrationError(
f"Database version {current_version} is greater than the latest migration version {latest_version}"
)
self._logger.info("Database update needed") self._logger.info("Database update needed")
# Only make a backup if using a file database (not memory) # Only make a backup if using a file database (not memory)
backup_path: Optional[Path] = None self._backup_db()
if isinstance(self._database, Path):
backup_path = self._backup_db(self._database)
else:
self._logger.info("Using in-memory database, skipping backup")
for migration in sorted_migrations: next_migration = self._migrations.get(from_version=current_version)
while next_migration is not None:
try: try:
self._run_migration(migration) self._run_migration(next_migration)
next_migration = self._migrations.get(self._get_current_version())
except MigrationError: except MigrationError:
if backup_path is not None: self._restore_db()
self._logger.error(f" Restoring from {backup_path}")
self._restore_db(backup_path)
raise raise
self._logger.info("Database updated successfully") self._logger.info("Database updated successfully")
def _run_migration(self, migration: Migration) -> None: def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration.""" """Runs a single migration."""
with self._lock: with self._db.lock:
current_version = self._get_current_version()
try: try:
if current_version >= migration.db_version: if self._get_current_version() != migration.from_version:
return raise MigrationError(
migration.migrate(self._cursor) f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
)
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate:
callback(self._db, self._image_files)
migration.migrate(self._db, self._image_files)
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate:
callback(self._db, self._image_files)
# Migration callbacks only get a cursor; they cannot commit the transaction. # Migration callbacks only get a cursor; they cannot commit the transaction.
self._conn.commit() self._db.conn.commit()
self._set_version(db_version=migration.db_version, app_version=migration.app_version) self._logger.debug(
self._logger.debug(f"Successfully migrated database from {current_version} to {migration.db_version}") f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
)
except Exception as e: except Exception as e:
msg = f"Error migrating database from {current_version} to {migration.db_version}: {e}" msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
self._conn.rollback() self._db.conn.rollback()
self._logger.error(msg) self._logger.error(msg)
raise MigrationError(msg) from e raise MigrationError(msg) from e
def _create_version_table(self) -> None: def _create_version_table(self) -> None:
"""Creates a version table for the database, if one does not already exist.""" """Creates a version table for the database, if one does not already exist."""
with self._lock: with self._db.lock:
try: try:
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='version';") self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if self._cursor.fetchone() is not None: if self._cursor.fetchone() is not None:
return return
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
CREATE TABLE IF NOT EXISTS version ( CREATE TABLE migrations (
db_version INTEGER PRIMARY KEY, version INTEGER PRIMARY KEY,
app_version TEXT NOT NULL,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
); );
""" """
) )
self._cursor.execute("INSERT INTO version (db_version, app_version) VALUES (?,?);", (0, "0.0.0")) self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
self._conn.commit() self._db.conn.commit()
self._logger.debug("Created version table") self._logger.debug("Created migrations table")
except sqlite3.Error as e: except sqlite3.Error as e:
msg = f"Problem creation version table: {e}" msg = f"Problem creating migrations table: {e}"
self._logger.error(msg) self._logger.error(msg)
self._conn.rollback() self._db.conn.rollback()
raise MigrationError(msg) from e raise MigrationError(msg) from e
def _get_current_version(self) -> int: def _get_current_version(self) -> int:
"""Gets the current version of the database, or 0 if the version table does not exist.""" """Gets the current version of the database, or 0 if the version table does not exist."""
with self._lock: with self._db.lock:
try: try:
self._cursor.execute("SELECT MAX(db_version) FROM version;") self._cursor.execute("SELECT MAX(version) FROM migrations;")
version = self._cursor.fetchone()[0] version = self._cursor.fetchone()[0]
if version is None: if version is None:
return 0 return 0
@ -168,43 +209,42 @@ class SQLiteMigrator:
return 0 return 0
raise raise
def _set_version(self, db_version: int, app_version: str) -> None: def _backup_db(self) -> None:
"""Adds a version entry to the table's version table."""
with self._lock:
try:
self._cursor.execute(
"INSERT INTO version (db_version, app_version) VALUES (?,?);", (db_version, app_version)
)
self._conn.commit()
except sqlite3.Error as e:
msg = f"Problem setting database version: {e}"
self._logger.error(msg)
self._conn.rollback()
raise MigrationError(msg) from e
def _backup_db(self, db_path: Path | str) -> Path:
"""Backs up the databse, returning the path to the backup file.""" """Backs up the databse, returning the path to the backup file."""
if db_path == sqlite_memory: if self._db.is_memory:
raise MigrationError("Cannot back up memory database") self._logger.debug("Using memory database, skipping backup")
if not isinstance(db_path, Path): # Sanity check!
raise MigrationError(f'Database path must be "{sqlite_memory}" or a Path') if not isinstance(self._db.database, Path):
with self._lock: raise MigrationError(f"Database path must be a Path, got {self._db.database} ({type(self._db.database)})")
with self._db.lock:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db" backup_path = self._db.database.parent / f"{self._db.database.stem}_{timestamp}.db"
self._logger.info(f"Backing up database to {backup_path}") self._logger.info(f"Backing up database to {backup_path}")
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
backup_conn = sqlite3.connect(backup_path) backup_conn = sqlite3.connect(backup_path)
with backup_conn: with backup_conn:
self._conn.backup(backup_conn) self._db.conn.backup(backup_conn)
backup_conn.close() backup_conn.close()
return backup_path
def _restore_db(self, backup_path: Path) -> None: # Sanity check!
if not backup_path.is_file():
raise MigrationError("Unable to back up database")
self.backup_path = backup_path
def _restore_db(
self,
) -> None:
"""Restores the database from a backup file, unless the database is a memory database.""" """Restores the database from a backup file, unless the database is a memory database."""
if self._database == sqlite_memory: # We don't need to restore a memory database.
if self._db.is_memory:
return return
with self._lock:
self._logger.info(f"Restoring database from {backup_path}") with self._db.lock:
self._conn.close() self._logger.info(f"Restoring database from {self.backup_path}")
if not Path(backup_path).is_file(): self._db.conn.close()
raise FileNotFoundError(f"Backup file {backup_path} does not exist") if self.backup_path is None:
shutil.copy2(backup_path, self._database) raise FileNotFoundError("No backup path set")
if not Path(self.backup_path).is_file():
raise FileNotFoundError(f"Backup file {self.backup_path} does not exist")
shutil.copy2(self.backup_path, self._db.database)

View File

@ -1,84 +0,0 @@
import sqlite3
from datetime import datetime
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from invokeai.app.services.config.config_default import InvokeAIAppConfig
def migrate_image_workflows(output_path: Path, database: Path, page_size=100):
"""
In the v3.5.0 release, InvokeAI changed how it handles image workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This script checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
1) Check if the database is updated to support image workflows. Aborts if it doesn't have the
`has_workflow` column yet.
2) Backs up the database.
3) Opens each image in the `images` table via PIL
4) Checks if the `"invokeai_workflow"` attribute its in the image's embedded metadata, indicating
that it has a workflow.
5) If it does, updates the `has_workflow` column for that image to `TRUE`.
If there are any problems, the script immediately aborts. Because the processing happens in chunks,
if there is a problem, it is suggested that you restore the database from the backup and try again.
"""
output_path = output_path
database = database
conn = sqlite3.connect(database)
cursor = conn.cursor()
# We can only migrate if the `images` table has the `has_workflow` column
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "has_workflow" not in columns:
raise Exception("Database needs to be updated to support image workflows")
# Back up the database before we start
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = database.parent / f"{database.stem}_migrate-image-workflows_{timestamp}.db"
print(f"Backing up database to {backup_path}")
backup_conn = sqlite3.connect(backup_path)
with backup_conn:
conn.backup(backup_conn)
backup_conn.close()
# Get the total number of images and chunk it into pages
cursor.execute("SELECT COUNT(*) FROM images")
total_images = cursor.fetchone()[0]
total_pages = (total_images + page_size - 1) // page_size
print(f"Processing {total_images} images in chunks of {page_size} images...")
# Migrate the images
migrated_count = 0
pbar = tqdm(range(total_pages))
for page in pbar:
pbar.set_description(f"Migrating page {page + 1}/{total_pages}")
offset = page * page_size
cursor.execute("SELECT image_name FROM images LIMIT ? OFFSET ?", (page_size, offset))
images = cursor.fetchall()
for image_name in images:
image_path = output_path / "images" / image_name[0]
with Image.open(image_path) as img:
if "invokeai_workflow" in img.info:
cursor.execute("UPDATE images SET has_workflow = TRUE WHERE image_name = ?", (image_name[0],))
migrated_count += 1
conn.commit()
conn.close()
print(f"Migrated workflows for {migrated_count} images.")
if __name__ == "__main__":
config = InvokeAIAppConfig.get_config()
output_path = config.output_path
database = config.db_path
assert output_path is not None
assert output_path.exists()
assert database.exists()
migrate_image_workflows(output_path=output_path, database=database, page_size=100)