mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): incorporate feedback
This commit is contained in:
parent
c382329e8c
commit
0710ec30cf
invokeai
app
api
services/shared/sqlite
backend/util
@ -70,9 +70,10 @@ class ApiDependencies:
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
output_folder = config.output_path
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
db = SqliteDatabase(config, logger)
|
||||
migrator = SQLiteMigrator(conn=db.conn, database=db.database, lock=db.lock, logger=logger)
|
||||
migrator = SQLiteMigrator(db=db, image_files=image_files)
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
@ -87,7 +88,6 @@ class ApiDependencies:
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
|
@ -1,11 +1,14 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
|
||||
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
|
||||
"""Migration callback for database version 1."""
|
||||
|
||||
cursor = db.conn.cursor()
|
||||
_create_board_images(cursor)
|
||||
_create_boards(cursor)
|
||||
_create_images(cursor)
|
||||
@ -350,7 +353,11 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
migration_1 = Migration(db_version=1, app_version="3.4.0", migrate=_migrate)
|
||||
migration_1 = Migration(
|
||||
from_version=0,
|
||||
to_version=1,
|
||||
migrate=_migrate,
|
||||
)
|
||||
"""
|
||||
Database version 1 (initial state).
|
||||
|
||||
|
@ -1,39 +1,55 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
|
||||
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
|
||||
"""Migration callback for database version 2."""
|
||||
|
||||
cursor = db.conn.cursor()
|
||||
_add_images_has_workflow(cursor)
|
||||
_add_session_queue_workflow(cursor)
|
||||
_drop_old_workflow_tables(cursor)
|
||||
_add_workflow_library(cursor)
|
||||
_drop_model_manager_metadata(cursor)
|
||||
_migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=db._logger)
|
||||
|
||||
|
||||
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
|
||||
"""Add the `has_workflow` column to `images` table."""
|
||||
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
|
||||
cursor.execute("PRAGMA table_info(images)")
|
||||
columns = [column[1] for column in cursor.fetchall()]
|
||||
|
||||
if "has_workflow" not in columns:
|
||||
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
|
||||
|
||||
|
||||
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
|
||||
"""Add the `workflow` column to `session_queue` table."""
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
|
||||
|
||||
cursor.execute("PRAGMA table_info(session_queue)")
|
||||
columns = [column[1] for column in cursor.fetchall()]
|
||||
|
||||
if "workflow" not in columns:
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
|
||||
|
||||
|
||||
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the `workflows` and `workflow_images` tables."""
|
||||
cursor.execute("DROP TABLE workflow_images;")
|
||||
cursor.execute("DROP TABLE workflows;")
|
||||
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
|
||||
cursor.execute("DROP TABLE IF EXISTS workflows;")
|
||||
|
||||
|
||||
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
|
||||
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE workflow_library (
|
||||
CREATE TABLE IF NOT EXISTS workflow_library (
|
||||
workflow_id TEXT NOT NULL PRIMARY KEY,
|
||||
workflow TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
@ -50,17 +66,17 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
|
||||
]
|
||||
|
||||
indices = [
|
||||
"CREATE INDEX idx_workflow_library_created_at ON workflow_library(created_at);",
|
||||
"CREATE INDEX idx_workflow_library_updated_at ON workflow_library(updated_at);",
|
||||
"CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);",
|
||||
"CREATE INDEX idx_workflow_library_category ON workflow_library(category);",
|
||||
"CREATE INDEX idx_workflow_library_name ON workflow_library(name);",
|
||||
"CREATE INDEX idx_workflow_library_description ON workflow_library(description);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
|
||||
]
|
||||
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER tg_workflow_library_updated_at
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflow_library FOR EACH ROW
|
||||
BEGIN
|
||||
@ -77,12 +93,43 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
|
||||
|
||||
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the `model_manager_metadata` table."""
|
||||
cursor.execute("DROP TABLE model_manager_metadata;")
|
||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||
|
||||
|
||||
def _migrate_embedded_workflows(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
|
||||
|
||||
This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry
|
||||
in the database accordingly.
|
||||
"""
|
||||
# Get the total number of images and chunk it into pages
|
||||
cursor.execute("SELECT image_name FROM images")
|
||||
image_names: list[str] = [image[0] for image in cursor.fetchall()]
|
||||
total_image_names = len(image_names)
|
||||
|
||||
if not total_image_names:
|
||||
return
|
||||
|
||||
logger.info(f"Migrating workflows for {total_image_names} images")
|
||||
|
||||
# Migrate the images
|
||||
to_migrate: list[tuple[bool, str]] = []
|
||||
pbar = tqdm(image_names)
|
||||
for idx, image_name in enumerate(pbar):
|
||||
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
|
||||
pil_image = image_files.get(image_name)
|
||||
if "invokeai_workflow" in pil_image.info:
|
||||
to_migrate.append((True, image_name))
|
||||
|
||||
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
|
||||
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
|
||||
|
||||
|
||||
migration_2 = Migration(
|
||||
db_version=2,
|
||||
app_version="3.5.0",
|
||||
from_version=1,
|
||||
to_version=2,
|
||||
migrate=_migrate,
|
||||
)
|
||||
"""
|
||||
@ -90,8 +137,10 @@ Database version 2.
|
||||
|
||||
Introduced in v3.5.0 for the new workflow library.
|
||||
|
||||
Migration:
|
||||
- Add `has_workflow` column to `images` table
|
||||
- Add `workflow` column to `session_queue` table
|
||||
- Drop `workflows` and `workflow_images` tables
|
||||
- Add `workflow_library` table
|
||||
- Updates `has_workflow` for all images
|
||||
"""
|
||||
|
@ -8,13 +8,15 @@ from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||
|
||||
|
||||
class SqliteDatabase:
|
||||
database: Path | str
|
||||
database: Path | str # Must declare this here to satisfy type checker
|
||||
|
||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
||||
self._logger = logger
|
||||
self._config = config
|
||||
self.is_memory = False
|
||||
if self._config.use_memory_db:
|
||||
self.database = sqlite_memory
|
||||
self.is_memory = True
|
||||
logger.info("Using in-memory database")
|
||||
else:
|
||||
self.database = self._config.db_path
|
||||
|
@ -1,14 +1,15 @@
|
||||
import shutil
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, TypeAlias
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
MigrateCallback: TypeAlias = Callable[[SqliteDatabase, ImageFileStorageBase], None]
|
||||
|
||||
|
||||
class MigrationError(Exception):
|
||||
@ -19,146 +20,186 @@ class MigrationVersionError(ValueError, MigrationError):
|
||||
"""Raised when a migration version is invalid."""
|
||||
|
||||
|
||||
class Migration:
|
||||
"""Represents a migration for a SQLite database.
|
||||
class Migration(BaseModel):
|
||||
"""
|
||||
Represents a migration for a SQLite database.
|
||||
|
||||
:param db_version: The database schema version this migration results in.
|
||||
:param app_version: The app version this migration is introduced in.
|
||||
:param migrate: The callback to run to perform the migration. The callback will be passed a
|
||||
cursor to the database. The migrator will manage locking database access and committing the
|
||||
transaction; the callback should not do either of these things.
|
||||
Migration callbacks will be provided an instance of SqliteDatabase.
|
||||
Migration callbacks should not commit; the migrator will commit the transaction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_version: int,
|
||||
app_version: str,
|
||||
migrate: MigrateCallback,
|
||||
) -> None:
|
||||
self.db_version = db_version
|
||||
self.app_version = app_version
|
||||
self.migrate = migrate
|
||||
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
||||
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
|
||||
migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
|
||||
pre_migrate: list[MigrateCallback] = Field(
|
||||
default=[], description="A list of callbacks to run before the migration"
|
||||
)
|
||||
post_migrate: list[MigrateCallback] = Field(
|
||||
default=[], description="A list of callbacks to run after the migration"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_to_version(self) -> "Migration":
|
||||
if self.to_version <= self.from_version:
|
||||
raise ValueError("to_version must be greater than from_version")
|
||||
return self
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||
return hash((self.from_version, self.to_version))
|
||||
|
||||
|
||||
class MigrationSet:
|
||||
"""A set of Migrations. Performs validation during migration registration and provides utility methods."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._migrations: set[Migration] = set()
|
||||
|
||||
def register(self, migration: Migration) -> None:
|
||||
"""Registers a migration."""
|
||||
if any(m.from_version == migration.from_version for m in self._migrations):
|
||||
raise MigrationVersionError(f"Migration from {migration.from_version} already registered")
|
||||
if any(m.to_version == migration.to_version for m in self._migrations):
|
||||
raise MigrationVersionError(f"Migration to {migration.to_version} already registered")
|
||||
self._migrations.add(migration)
|
||||
|
||||
def get(self, from_version: int) -> Optional[Migration]:
|
||||
"""Gets the migration that may be run on the given database version."""
|
||||
# register() ensures that there is only one migration with a given from_version, so this is safe.
|
||||
return next((m for m in self._migrations if m.from_version == from_version), None)
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""The count of registered migrations."""
|
||||
return len(self._migrations)
|
||||
|
||||
@property
|
||||
def latest_version(self) -> int:
|
||||
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
|
||||
if self.count == 0:
|
||||
return 0
|
||||
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
||||
|
||||
|
||||
class SQLiteMigrator:
|
||||
"""
|
||||
Manages migrations for a SQLite database.
|
||||
|
||||
:param conn: The database connection.
|
||||
:param database: The path to the database file, or ":memory:" for an in-memory database.
|
||||
:param lock: A lock to use when accessing the database.
|
||||
:param logger: The logger to use.
|
||||
:param db: The SqliteDatabase, representing the database on which to run migrations.
|
||||
:param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files.
|
||||
|
||||
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
|
||||
order of their version number. If the database is already at the latest version, no migrations
|
||||
will be run.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, database: Path | str, lock: threading.RLock, logger: Logger) -> None:
|
||||
self._logger = logger
|
||||
self._conn = conn
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = lock
|
||||
self._database = database
|
||||
self._migrations: set[Migration] = set()
|
||||
backup_path: Optional[Path] = None
|
||||
|
||||
def __init__(self, db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
|
||||
self._image_files = image_files
|
||||
self._db = db
|
||||
self._logger = self._db._logger
|
||||
self._config = self._db._config
|
||||
self._cursor = self._db.conn.cursor()
|
||||
self._migrations = MigrationSet()
|
||||
|
||||
def register_migration(self, migration: Migration) -> None:
|
||||
"""Registers a migration."""
|
||||
if not isinstance(migration.db_version, int) or migration.db_version < 1:
|
||||
raise MigrationVersionError(f"Invalid migration version {migration.db_version}")
|
||||
if any(m.db_version == migration.db_version for m in self._migrations):
|
||||
raise MigrationVersionError(f"Migration version {migration.db_version} already registered")
|
||||
self._migrations.add(migration)
|
||||
self._logger.debug(f"Registered migration {migration.db_version}")
|
||||
"""
|
||||
Registers a migration.
|
||||
Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
|
||||
"""
|
||||
self._migrations.register(migration)
|
||||
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
|
||||
|
||||
def run_migrations(self) -> None:
|
||||
"""Migrates the database to the latest version."""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
self._create_version_table()
|
||||
sorted_migrations = sorted(self._migrations, key=lambda m: m.db_version)
|
||||
current_version = self._get_current_version()
|
||||
|
||||
if len(sorted_migrations) == 0:
|
||||
if self._migrations.count == 0:
|
||||
self._logger.debug("No migrations registered")
|
||||
return
|
||||
|
||||
latest_version = sorted_migrations[-1].db_version
|
||||
latest_version = self._migrations.latest_version
|
||||
if current_version == latest_version:
|
||||
self._logger.debug("Database is up to date, no migrations to run")
|
||||
return
|
||||
|
||||
if current_version > latest_version:
|
||||
raise MigrationError(
|
||||
f"Database version {current_version} is greater than the latest migration version {latest_version}"
|
||||
)
|
||||
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Only make a backup if using a file database (not memory)
|
||||
backup_path: Optional[Path] = None
|
||||
if isinstance(self._database, Path):
|
||||
backup_path = self._backup_db(self._database)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, skipping backup")
|
||||
self._backup_db()
|
||||
|
||||
for migration in sorted_migrations:
|
||||
next_migration = self._migrations.get(from_version=current_version)
|
||||
while next_migration is not None:
|
||||
try:
|
||||
self._run_migration(migration)
|
||||
self._run_migration(next_migration)
|
||||
next_migration = self._migrations.get(self._get_current_version())
|
||||
except MigrationError:
|
||||
if backup_path is not None:
|
||||
self._logger.error(f" Restoring from {backup_path}")
|
||||
self._restore_db(backup_path)
|
||||
self._restore_db()
|
||||
raise
|
||||
self._logger.info("Database updated successfully")
|
||||
|
||||
def _run_migration(self, migration: Migration) -> None:
|
||||
"""Runs a single migration."""
|
||||
with self._lock:
|
||||
current_version = self._get_current_version()
|
||||
with self._db.lock:
|
||||
try:
|
||||
if current_version >= migration.db_version:
|
||||
return
|
||||
migration.migrate(self._cursor)
|
||||
if self._get_current_version() != migration.from_version:
|
||||
raise MigrationError(
|
||||
f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
|
||||
)
|
||||
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
||||
if migration.pre_migrate:
|
||||
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
||||
for callback in migration.pre_migrate:
|
||||
callback(self._db, self._image_files)
|
||||
migration.migrate(self._db, self._image_files)
|
||||
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||
if migration.post_migrate:
|
||||
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
||||
for callback in migration.post_migrate:
|
||||
callback(self._db, self._image_files)
|
||||
# Migration callbacks only get a cursor; they cannot commit the transaction.
|
||||
self._conn.commit()
|
||||
self._set_version(db_version=migration.db_version, app_version=migration.app_version)
|
||||
self._logger.debug(f"Successfully migrated database from {current_version} to {migration.db_version}")
|
||||
self._db.conn.commit()
|
||||
self._logger.debug(
|
||||
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
||||
)
|
||||
except Exception as e:
|
||||
msg = f"Error migrating database from {current_version} to {migration.db_version}: {e}"
|
||||
self._conn.rollback()
|
||||
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
||||
self._db.conn.rollback()
|
||||
self._logger.error(msg)
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
def _create_version_table(self) -> None:
|
||||
"""Creates a version table for the database, if one does not already exist."""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='version';")
|
||||
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if self._cursor.fetchone() is not None:
|
||||
return
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS version (
|
||||
db_version INTEGER PRIMARY KEY,
|
||||
app_version TEXT NOT NULL,
|
||||
CREATE TABLE migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute("INSERT INTO version (db_version, app_version) VALUES (?,?);", (0, "0.0.0"))
|
||||
self._conn.commit()
|
||||
self._logger.debug("Created version table")
|
||||
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||
self._db.conn.commit()
|
||||
self._logger.debug("Created migrations table")
|
||||
except sqlite3.Error as e:
|
||||
msg = f"Problem creation version table: {e}"
|
||||
msg = f"Problem creating migrations table: {e}"
|
||||
self._logger.error(msg)
|
||||
self._conn.rollback()
|
||||
self._db.conn.rollback()
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
def _get_current_version(self) -> int:
|
||||
"""Gets the current version of the database, or 0 if the version table does not exist."""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute("SELECT MAX(db_version) FROM version;")
|
||||
self._cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||
version = self._cursor.fetchone()[0]
|
||||
if version is None:
|
||||
return 0
|
||||
@ -168,43 +209,42 @@ class SQLiteMigrator:
|
||||
return 0
|
||||
raise
|
||||
|
||||
def _set_version(self, db_version: int, app_version: str) -> None:
|
||||
"""Adds a version entry to the table's version table."""
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"INSERT INTO version (db_version, app_version) VALUES (?,?);", (db_version, app_version)
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
msg = f"Problem setting database version: {e}"
|
||||
self._logger.error(msg)
|
||||
self._conn.rollback()
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
def _backup_db(self, db_path: Path | str) -> Path:
|
||||
def _backup_db(self) -> None:
|
||||
"""Backs up the databse, returning the path to the backup file."""
|
||||
if db_path == sqlite_memory:
|
||||
raise MigrationError("Cannot back up memory database")
|
||||
if not isinstance(db_path, Path):
|
||||
raise MigrationError(f'Database path must be "{sqlite_memory}" or a Path')
|
||||
with self._lock:
|
||||
if self._db.is_memory:
|
||||
self._logger.debug("Using memory database, skipping backup")
|
||||
# Sanity check!
|
||||
if not isinstance(self._db.database, Path):
|
||||
raise MigrationError(f"Database path must be a Path, got {self._db.database} ({type(self._db.database)})")
|
||||
with self._db.lock:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db"
|
||||
backup_path = self._db.database.parent / f"{self._db.database.stem}_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {backup_path}")
|
||||
|
||||
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
|
||||
backup_conn = sqlite3.connect(backup_path)
|
||||
with backup_conn:
|
||||
self._conn.backup(backup_conn)
|
||||
self._db.conn.backup(backup_conn)
|
||||
backup_conn.close()
|
||||
return backup_path
|
||||
|
||||
def _restore_db(self, backup_path: Path) -> None:
|
||||
# Sanity check!
|
||||
if not backup_path.is_file():
|
||||
raise MigrationError("Unable to back up database")
|
||||
self.backup_path = backup_path
|
||||
|
||||
def _restore_db(
|
||||
self,
|
||||
) -> None:
|
||||
"""Restores the database from a backup file, unless the database is a memory database."""
|
||||
if self._database == sqlite_memory:
|
||||
# We don't need to restore a memory database.
|
||||
if self._db.is_memory:
|
||||
return
|
||||
with self._lock:
|
||||
self._logger.info(f"Restoring database from {backup_path}")
|
||||
self._conn.close()
|
||||
if not Path(backup_path).is_file():
|
||||
raise FileNotFoundError(f"Backup file {backup_path} does not exist")
|
||||
shutil.copy2(backup_path, self._database)
|
||||
|
||||
with self._db.lock:
|
||||
self._logger.info(f"Restoring database from {self.backup_path}")
|
||||
self._db.conn.close()
|
||||
if self.backup_path is None:
|
||||
raise FileNotFoundError("No backup path set")
|
||||
if not Path(self.backup_path).is_file():
|
||||
raise FileNotFoundError(f"Backup file {self.backup_path} does not exist")
|
||||
shutil.copy2(self.backup_path, self._db.database)
|
||||
|
@ -1,84 +0,0 @@
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
|
||||
|
||||
def migrate_image_workflows(output_path: Path, database: Path, page_size=100):
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles image workflows. The `images` table in
|
||||
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
|
||||
|
||||
This script checks each image for the presence of an embedded workflow, then updates its entry
|
||||
in the database accordingly.
|
||||
|
||||
1) Check if the database is updated to support image workflows. Aborts if it doesn't have the
|
||||
`has_workflow` column yet.
|
||||
2) Backs up the database.
|
||||
3) Opens each image in the `images` table via PIL
|
||||
4) Checks if the `"invokeai_workflow"` attribute its in the image's embedded metadata, indicating
|
||||
that it has a workflow.
|
||||
5) If it does, updates the `has_workflow` column for that image to `TRUE`.
|
||||
|
||||
If there are any problems, the script immediately aborts. Because the processing happens in chunks,
|
||||
if there is a problem, it is suggested that you restore the database from the backup and try again.
|
||||
"""
|
||||
output_path = output_path
|
||||
database = database
|
||||
conn = sqlite3.connect(database)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# We can only migrate if the `images` table has the `has_workflow` column
|
||||
cursor.execute("PRAGMA table_info(images)")
|
||||
columns = [column[1] for column in cursor.fetchall()]
|
||||
if "has_workflow" not in columns:
|
||||
raise Exception("Database needs to be updated to support image workflows")
|
||||
|
||||
# Back up the database before we start
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = database.parent / f"{database.stem}_migrate-image-workflows_{timestamp}.db"
|
||||
print(f"Backing up database to {backup_path}")
|
||||
backup_conn = sqlite3.connect(backup_path)
|
||||
with backup_conn:
|
||||
conn.backup(backup_conn)
|
||||
backup_conn.close()
|
||||
|
||||
# Get the total number of images and chunk it into pages
|
||||
cursor.execute("SELECT COUNT(*) FROM images")
|
||||
total_images = cursor.fetchone()[0]
|
||||
total_pages = (total_images + page_size - 1) // page_size
|
||||
print(f"Processing {total_images} images in chunks of {page_size} images...")
|
||||
|
||||
# Migrate the images
|
||||
migrated_count = 0
|
||||
pbar = tqdm(range(total_pages))
|
||||
for page in pbar:
|
||||
pbar.set_description(f"Migrating page {page + 1}/{total_pages}")
|
||||
offset = page * page_size
|
||||
cursor.execute("SELECT image_name FROM images LIMIT ? OFFSET ?", (page_size, offset))
|
||||
images = cursor.fetchall()
|
||||
for image_name in images:
|
||||
image_path = output_path / "images" / image_name[0]
|
||||
with Image.open(image_path) as img:
|
||||
if "invokeai_workflow" in img.info:
|
||||
cursor.execute("UPDATE images SET has_workflow = TRUE WHERE image_name = ?", (image_name[0],))
|
||||
migrated_count += 1
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
print(f"Migrated workflows for {migrated_count} images.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
output_path = config.output_path
|
||||
database = config.db_path
|
||||
|
||||
assert output_path is not None
|
||||
assert output_path.exists()
|
||||
assert database.exists()
|
||||
migrate_image_workflows(output_path=output_path, database=database, page_size=100)
|
Loading…
x
Reference in New Issue
Block a user