mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): incorporate feedback
This commit is contained in:
parent
c382329e8c
commit
0710ec30cf
@ -70,9 +70,10 @@ class ApiDependencies:
|
|||||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
|
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
|
||||||
db = SqliteDatabase(config, logger)
|
db = SqliteDatabase(config, logger)
|
||||||
migrator = SQLiteMigrator(conn=db.conn, database=db.database, lock=db.lock, logger=logger)
|
migrator = SQLiteMigrator(db=db, image_files=image_files)
|
||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
@ -87,7 +88,6 @@ class ApiDependencies:
|
|||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||||
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
|
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
|
||||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
|
||||||
image_records = SqliteImageRecordStorage(db=db)
|
image_records = SqliteImageRecordStorage(db=db)
|
||||||
images = ImageService()
|
images = ImageService()
|
||||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
|
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||||
|
|
||||||
|
|
||||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
|
||||||
"""Migration callback for database version 1."""
|
"""Migration callback for database version 1."""
|
||||||
|
|
||||||
|
cursor = db.conn.cursor()
|
||||||
_create_board_images(cursor)
|
_create_board_images(cursor)
|
||||||
_create_boards(cursor)
|
_create_boards(cursor)
|
||||||
_create_images(cursor)
|
_create_images(cursor)
|
||||||
@ -350,7 +353,11 @@ def _create_workflows(cursor: sqlite3.Cursor) -> None:
|
|||||||
cursor.execute(stmt)
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
migration_1 = Migration(db_version=1, app_version="3.4.0", migrate=_migrate)
|
migration_1 = Migration(
|
||||||
|
from_version=0,
|
||||||
|
to_version=1,
|
||||||
|
migrate=_migrate,
|
||||||
|
)
|
||||||
"""
|
"""
|
||||||
Database version 1 (initial state).
|
Database version 1 (initial state).
|
||||||
|
|
||||||
|
@ -1,39 +1,55 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||||
|
|
||||||
|
|
||||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
|
||||||
"""Migration callback for database version 2."""
|
"""Migration callback for database version 2."""
|
||||||
|
|
||||||
|
cursor = db.conn.cursor()
|
||||||
_add_images_has_workflow(cursor)
|
_add_images_has_workflow(cursor)
|
||||||
_add_session_queue_workflow(cursor)
|
_add_session_queue_workflow(cursor)
|
||||||
_drop_old_workflow_tables(cursor)
|
_drop_old_workflow_tables(cursor)
|
||||||
_add_workflow_library(cursor)
|
_add_workflow_library(cursor)
|
||||||
_drop_model_manager_metadata(cursor)
|
_drop_model_manager_metadata(cursor)
|
||||||
|
_migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=db._logger)
|
||||||
|
|
||||||
|
|
||||||
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
|
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
|
||||||
"""Add the `has_workflow` column to `images` table."""
|
"""Add the `has_workflow` column to `images` table."""
|
||||||
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
|
cursor.execute("PRAGMA table_info(images)")
|
||||||
|
columns = [column[1] for column in cursor.fetchall()]
|
||||||
|
|
||||||
|
if "has_workflow" not in columns:
|
||||||
|
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
|
||||||
|
|
||||||
|
|
||||||
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
|
def _add_session_queue_workflow(cursor: sqlite3.Cursor) -> None:
|
||||||
"""Add the `workflow` column to `session_queue` table."""
|
"""Add the `workflow` column to `session_queue` table."""
|
||||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
|
|
||||||
|
cursor.execute("PRAGMA table_info(session_queue)")
|
||||||
|
columns = [column[1] for column in cursor.fetchall()]
|
||||||
|
|
||||||
|
if "workflow" not in columns:
|
||||||
|
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
|
||||||
|
|
||||||
|
|
||||||
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
|
def _drop_old_workflow_tables(cursor: sqlite3.Cursor) -> None:
|
||||||
"""Drops the `workflows` and `workflow_images` tables."""
|
"""Drops the `workflows` and `workflow_images` tables."""
|
||||||
cursor.execute("DROP TABLE workflow_images;")
|
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
|
||||||
cursor.execute("DROP TABLE workflows;")
|
cursor.execute("DROP TABLE IF EXISTS workflows;")
|
||||||
|
|
||||||
|
|
||||||
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
|
def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
|
||||||
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
|
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
|
||||||
tables = [
|
tables = [
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE TABLE workflow_library (
|
CREATE TABLE IF NOT EXISTS workflow_library (
|
||||||
workflow_id TEXT NOT NULL PRIMARY KEY,
|
workflow_id TEXT NOT NULL PRIMARY KEY,
|
||||||
workflow TEXT NOT NULL,
|
workflow TEXT NOT NULL,
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
@ -50,17 +66,17 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
indices = [
|
indices = [
|
||||||
"CREATE INDEX idx_workflow_library_created_at ON workflow_library(created_at);",
|
"CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
|
||||||
"CREATE INDEX idx_workflow_library_updated_at ON workflow_library(updated_at);",
|
"CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
|
||||||
"CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);",
|
"CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
|
||||||
"CREATE INDEX idx_workflow_library_category ON workflow_library(category);",
|
"CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
|
||||||
"CREATE INDEX idx_workflow_library_name ON workflow_library(name);",
|
"CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
|
||||||
"CREATE INDEX idx_workflow_library_description ON workflow_library(description);",
|
"CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
|
||||||
]
|
]
|
||||||
|
|
||||||
triggers = [
|
triggers = [
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE TRIGGER tg_workflow_library_updated_at
|
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
|
||||||
AFTER UPDATE
|
AFTER UPDATE
|
||||||
ON workflow_library FOR EACH ROW
|
ON workflow_library FOR EACH ROW
|
||||||
BEGIN
|
BEGIN
|
||||||
@ -77,12 +93,43 @@ def _add_workflow_library(cursor: sqlite3.Cursor) -> None:
|
|||||||
|
|
||||||
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
|
def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
|
||||||
"""Drops the `model_manager_metadata` table."""
|
"""Drops the `model_manager_metadata` table."""
|
||||||
cursor.execute("DROP TABLE model_manager_metadata;")
|
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_embedded_workflows(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
|
||||||
|
"""
|
||||||
|
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||||
|
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
|
||||||
|
|
||||||
|
This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry
|
||||||
|
in the database accordingly.
|
||||||
|
"""
|
||||||
|
# Get the total number of images and chunk it into pages
|
||||||
|
cursor.execute("SELECT image_name FROM images")
|
||||||
|
image_names: list[str] = [image[0] for image in cursor.fetchall()]
|
||||||
|
total_image_names = len(image_names)
|
||||||
|
|
||||||
|
if not total_image_names:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Migrating workflows for {total_image_names} images")
|
||||||
|
|
||||||
|
# Migrate the images
|
||||||
|
to_migrate: list[tuple[bool, str]] = []
|
||||||
|
pbar = tqdm(image_names)
|
||||||
|
for idx, image_name in enumerate(pbar):
|
||||||
|
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
|
||||||
|
pil_image = image_files.get(image_name)
|
||||||
|
if "invokeai_workflow" in pil_image.info:
|
||||||
|
to_migrate.append((True, image_name))
|
||||||
|
|
||||||
|
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
|
||||||
|
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
|
||||||
|
|
||||||
|
|
||||||
migration_2 = Migration(
|
migration_2 = Migration(
|
||||||
db_version=2,
|
from_version=1,
|
||||||
app_version="3.5.0",
|
to_version=2,
|
||||||
migrate=_migrate,
|
migrate=_migrate,
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
@ -90,8 +137,10 @@ Database version 2.
|
|||||||
|
|
||||||
Introduced in v3.5.0 for the new workflow library.
|
Introduced in v3.5.0 for the new workflow library.
|
||||||
|
|
||||||
|
Migration:
|
||||||
- Add `has_workflow` column to `images` table
|
- Add `has_workflow` column to `images` table
|
||||||
- Add `workflow` column to `session_queue` table
|
- Add `workflow` column to `session_queue` table
|
||||||
- Drop `workflows` and `workflow_images` tables
|
- Drop `workflows` and `workflow_images` tables
|
||||||
- Add `workflow_library` table
|
- Add `workflow_library` table
|
||||||
|
- Updates `has_workflow` for all images
|
||||||
"""
|
"""
|
||||||
|
@ -8,13 +8,15 @@ from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
|||||||
|
|
||||||
|
|
||||||
class SqliteDatabase:
|
class SqliteDatabase:
|
||||||
database: Path | str
|
database: Path | str # Must declare this here to satisfy type checker
|
||||||
|
|
||||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._config = config
|
self._config = config
|
||||||
|
self.is_memory = False
|
||||||
if self._config.use_memory_db:
|
if self._config.use_memory_db:
|
||||||
self.database = sqlite_memory
|
self.database = sqlite_memory
|
||||||
|
self.is_memory = True
|
||||||
logger.info("Using in-memory database")
|
logger.info("Using in-memory database")
|
||||||
else:
|
else:
|
||||||
self.database = self._config.db_path
|
self.database = self._config.db_path
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from logging import Logger
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional, TypeAlias
|
from typing import Callable, Optional, TypeAlias
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
|
|
||||||
|
MigrateCallback: TypeAlias = Callable[[SqliteDatabase, ImageFileStorageBase], None]
|
||||||
|
|
||||||
|
|
||||||
class MigrationError(Exception):
|
class MigrationError(Exception):
|
||||||
@ -19,146 +20,186 @@ class MigrationVersionError(ValueError, MigrationError):
|
|||||||
"""Raised when a migration version is invalid."""
|
"""Raised when a migration version is invalid."""
|
||||||
|
|
||||||
|
|
||||||
class Migration:
|
class Migration(BaseModel):
|
||||||
"""Represents a migration for a SQLite database.
|
"""
|
||||||
|
Represents a migration for a SQLite database.
|
||||||
|
|
||||||
:param db_version: The database schema version this migration results in.
|
Migration callbacks will be provided an instance of SqliteDatabase.
|
||||||
:param app_version: The app version this migration is introduced in.
|
Migration callbacks should not commit; the migrator will commit the transaction.
|
||||||
:param migrate: The callback to run to perform the migration. The callback will be passed a
|
|
||||||
cursor to the database. The migrator will manage locking database access and committing the
|
|
||||||
transaction; the callback should not do either of these things.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
||||||
self,
|
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
|
||||||
db_version: int,
|
migrate: MigrateCallback = Field(description="The callback to run to perform the migration")
|
||||||
app_version: str,
|
pre_migrate: list[MigrateCallback] = Field(
|
||||||
migrate: MigrateCallback,
|
default=[], description="A list of callbacks to run before the migration"
|
||||||
) -> None:
|
)
|
||||||
self.db_version = db_version
|
post_migrate: list[MigrateCallback] = Field(
|
||||||
self.app_version = app_version
|
default=[], description="A list of callbacks to run after the migration"
|
||||||
self.migrate = migrate
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_to_version(self) -> "Migration":
|
||||||
|
if self.to_version <= self.from_version:
|
||||||
|
raise ValueError("to_version must be greater than from_version")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||||
|
return hash((self.from_version, self.to_version))
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationSet:
|
||||||
|
"""A set of Migrations. Performs validation during migration registration and provides utility methods."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._migrations: set[Migration] = set()
|
||||||
|
|
||||||
|
def register(self, migration: Migration) -> None:
|
||||||
|
"""Registers a migration."""
|
||||||
|
if any(m.from_version == migration.from_version for m in self._migrations):
|
||||||
|
raise MigrationVersionError(f"Migration from {migration.from_version} already registered")
|
||||||
|
if any(m.to_version == migration.to_version for m in self._migrations):
|
||||||
|
raise MigrationVersionError(f"Migration to {migration.to_version} already registered")
|
||||||
|
self._migrations.add(migration)
|
||||||
|
|
||||||
|
def get(self, from_version: int) -> Optional[Migration]:
|
||||||
|
"""Gets the migration that may be run on the given database version."""
|
||||||
|
# register() ensures that there is only one migration with a given from_version, so this is safe.
|
||||||
|
return next((m for m in self._migrations if m.from_version == from_version), None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
"""The count of registered migrations."""
|
||||||
|
return len(self._migrations)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latest_version(self) -> int:
|
||||||
|
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
|
||||||
|
if self.count == 0:
|
||||||
|
return 0
|
||||||
|
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
||||||
|
|
||||||
|
|
||||||
class SQLiteMigrator:
|
class SQLiteMigrator:
|
||||||
"""
|
"""
|
||||||
Manages migrations for a SQLite database.
|
Manages migrations for a SQLite database.
|
||||||
|
|
||||||
:param conn: The database connection.
|
:param db: The SqliteDatabase, representing the database on which to run migrations.
|
||||||
:param database: The path to the database file, or ":memory:" for an in-memory database.
|
:param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files.
|
||||||
:param lock: A lock to use when accessing the database.
|
|
||||||
:param logger: The logger to use.
|
|
||||||
|
|
||||||
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
|
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
|
||||||
order of their version number. If the database is already at the latest version, no migrations
|
order of their version number. If the database is already at the latest version, no migrations
|
||||||
will be run.
|
will be run.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, database: Path | str, lock: threading.RLock, logger: Logger) -> None:
|
backup_path: Optional[Path] = None
|
||||||
self._logger = logger
|
|
||||||
self._conn = conn
|
def __init__(self, db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
|
||||||
self._cursor = self._conn.cursor()
|
self._image_files = image_files
|
||||||
self._lock = lock
|
self._db = db
|
||||||
self._database = database
|
self._logger = self._db._logger
|
||||||
self._migrations: set[Migration] = set()
|
self._config = self._db._config
|
||||||
|
self._cursor = self._db.conn.cursor()
|
||||||
|
self._migrations = MigrationSet()
|
||||||
|
|
||||||
def register_migration(self, migration: Migration) -> None:
|
def register_migration(self, migration: Migration) -> None:
|
||||||
"""Registers a migration."""
|
"""
|
||||||
if not isinstance(migration.db_version, int) or migration.db_version < 1:
|
Registers a migration.
|
||||||
raise MigrationVersionError(f"Invalid migration version {migration.db_version}")
|
Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
|
||||||
if any(m.db_version == migration.db_version for m in self._migrations):
|
"""
|
||||||
raise MigrationVersionError(f"Migration version {migration.db_version} already registered")
|
self._migrations.register(migration)
|
||||||
self._migrations.add(migration)
|
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
|
||||||
self._logger.debug(f"Registered migration {migration.db_version}")
|
|
||||||
|
|
||||||
def run_migrations(self) -> None:
|
def run_migrations(self) -> None:
|
||||||
"""Migrates the database to the latest version."""
|
"""Migrates the database to the latest version."""
|
||||||
with self._lock:
|
with self._db.lock:
|
||||||
self._create_version_table()
|
self._create_version_table()
|
||||||
sorted_migrations = sorted(self._migrations, key=lambda m: m.db_version)
|
|
||||||
current_version = self._get_current_version()
|
current_version = self._get_current_version()
|
||||||
|
|
||||||
if len(sorted_migrations) == 0:
|
if self._migrations.count == 0:
|
||||||
self._logger.debug("No migrations registered")
|
self._logger.debug("No migrations registered")
|
||||||
return
|
return
|
||||||
|
|
||||||
latest_version = sorted_migrations[-1].db_version
|
latest_version = self._migrations.latest_version
|
||||||
if current_version == latest_version:
|
if current_version == latest_version:
|
||||||
self._logger.debug("Database is up to date, no migrations to run")
|
self._logger.debug("Database is up to date, no migrations to run")
|
||||||
return
|
return
|
||||||
|
|
||||||
if current_version > latest_version:
|
|
||||||
raise MigrationError(
|
|
||||||
f"Database version {current_version} is greater than the latest migration version {latest_version}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._logger.info("Database update needed")
|
self._logger.info("Database update needed")
|
||||||
|
|
||||||
# Only make a backup if using a file database (not memory)
|
# Only make a backup if using a file database (not memory)
|
||||||
backup_path: Optional[Path] = None
|
self._backup_db()
|
||||||
if isinstance(self._database, Path):
|
|
||||||
backup_path = self._backup_db(self._database)
|
|
||||||
else:
|
|
||||||
self._logger.info("Using in-memory database, skipping backup")
|
|
||||||
|
|
||||||
for migration in sorted_migrations:
|
next_migration = self._migrations.get(from_version=current_version)
|
||||||
|
while next_migration is not None:
|
||||||
try:
|
try:
|
||||||
self._run_migration(migration)
|
self._run_migration(next_migration)
|
||||||
|
next_migration = self._migrations.get(self._get_current_version())
|
||||||
except MigrationError:
|
except MigrationError:
|
||||||
if backup_path is not None:
|
self._restore_db()
|
||||||
self._logger.error(f" Restoring from {backup_path}")
|
|
||||||
self._restore_db(backup_path)
|
|
||||||
raise
|
raise
|
||||||
self._logger.info("Database updated successfully")
|
self._logger.info("Database updated successfully")
|
||||||
|
|
||||||
def _run_migration(self, migration: Migration) -> None:
|
def _run_migration(self, migration: Migration) -> None:
|
||||||
"""Runs a single migration."""
|
"""Runs a single migration."""
|
||||||
with self._lock:
|
with self._db.lock:
|
||||||
current_version = self._get_current_version()
|
|
||||||
try:
|
try:
|
||||||
if current_version >= migration.db_version:
|
if self._get_current_version() != migration.from_version:
|
||||||
return
|
raise MigrationError(
|
||||||
migration.migrate(self._cursor)
|
f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
|
||||||
|
)
|
||||||
|
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
||||||
|
if migration.pre_migrate:
|
||||||
|
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
||||||
|
for callback in migration.pre_migrate:
|
||||||
|
callback(self._db, self._image_files)
|
||||||
|
migration.migrate(self._db, self._image_files)
|
||||||
|
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||||
|
if migration.post_migrate:
|
||||||
|
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
||||||
|
for callback in migration.post_migrate:
|
||||||
|
callback(self._db, self._image_files)
|
||||||
# Migration callbacks only get a cursor; they cannot commit the transaction.
|
# Migration callbacks only get a cursor; they cannot commit the transaction.
|
||||||
self._conn.commit()
|
self._db.conn.commit()
|
||||||
self._set_version(db_version=migration.db_version, app_version=migration.app_version)
|
self._logger.debug(
|
||||||
self._logger.debug(f"Successfully migrated database from {current_version} to {migration.db_version}")
|
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"Error migrating database from {current_version} to {migration.db_version}: {e}"
|
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
||||||
self._conn.rollback()
|
self._db.conn.rollback()
|
||||||
self._logger.error(msg)
|
self._logger.error(msg)
|
||||||
raise MigrationError(msg) from e
|
raise MigrationError(msg) from e
|
||||||
|
|
||||||
def _create_version_table(self) -> None:
|
def _create_version_table(self) -> None:
|
||||||
"""Creates a version table for the database, if one does not already exist."""
|
"""Creates a version table for the database, if one does not already exist."""
|
||||||
with self._lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='version';")
|
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||||
if self._cursor.fetchone() is not None:
|
if self._cursor.fetchone() is not None:
|
||||||
return
|
return
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE TABLE IF NOT EXISTS version (
|
CREATE TABLE migrations (
|
||||||
db_version INTEGER PRIMARY KEY,
|
version INTEGER PRIMARY KEY,
|
||||||
app_version TEXT NOT NULL,
|
|
||||||
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
self._cursor.execute("INSERT INTO version (db_version, app_version) VALUES (?,?);", (0, "0.0.0"))
|
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||||
self._conn.commit()
|
self._db.conn.commit()
|
||||||
self._logger.debug("Created version table")
|
self._logger.debug("Created migrations table")
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
msg = f"Problem creation version table: {e}"
|
msg = f"Problem creating migrations table: {e}"
|
||||||
self._logger.error(msg)
|
self._logger.error(msg)
|
||||||
self._conn.rollback()
|
self._db.conn.rollback()
|
||||||
raise MigrationError(msg) from e
|
raise MigrationError(msg) from e
|
||||||
|
|
||||||
def _get_current_version(self) -> int:
|
def _get_current_version(self) -> int:
|
||||||
"""Gets the current version of the database, or 0 if the version table does not exist."""
|
"""Gets the current version of the database, or 0 if the version table does not exist."""
|
||||||
with self._lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
self._cursor.execute("SELECT MAX(db_version) FROM version;")
|
self._cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||||
version = self._cursor.fetchone()[0]
|
version = self._cursor.fetchone()[0]
|
||||||
if version is None:
|
if version is None:
|
||||||
return 0
|
return 0
|
||||||
@ -168,43 +209,42 @@ class SQLiteMigrator:
|
|||||||
return 0
|
return 0
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _set_version(self, db_version: int, app_version: str) -> None:
|
def _backup_db(self) -> None:
|
||||||
"""Adds a version entry to the table's version table."""
|
|
||||||
with self._lock:
|
|
||||||
try:
|
|
||||||
self._cursor.execute(
|
|
||||||
"INSERT INTO version (db_version, app_version) VALUES (?,?);", (db_version, app_version)
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
msg = f"Problem setting database version: {e}"
|
|
||||||
self._logger.error(msg)
|
|
||||||
self._conn.rollback()
|
|
||||||
raise MigrationError(msg) from e
|
|
||||||
|
|
||||||
def _backup_db(self, db_path: Path | str) -> Path:
|
|
||||||
"""Backs up the databse, returning the path to the backup file."""
|
"""Backs up the databse, returning the path to the backup file."""
|
||||||
if db_path == sqlite_memory:
|
if self._db.is_memory:
|
||||||
raise MigrationError("Cannot back up memory database")
|
self._logger.debug("Using memory database, skipping backup")
|
||||||
if not isinstance(db_path, Path):
|
# Sanity check!
|
||||||
raise MigrationError(f'Database path must be "{sqlite_memory}" or a Path')
|
if not isinstance(self._db.database, Path):
|
||||||
with self._lock:
|
raise MigrationError(f"Database path must be a Path, got {self._db.database} ({type(self._db.database)})")
|
||||||
|
with self._db.lock:
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db"
|
backup_path = self._db.database.parent / f"{self._db.database.stem}_{timestamp}.db"
|
||||||
self._logger.info(f"Backing up database to {backup_path}")
|
self._logger.info(f"Backing up database to {backup_path}")
|
||||||
|
|
||||||
|
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
|
||||||
backup_conn = sqlite3.connect(backup_path)
|
backup_conn = sqlite3.connect(backup_path)
|
||||||
with backup_conn:
|
with backup_conn:
|
||||||
self._conn.backup(backup_conn)
|
self._db.conn.backup(backup_conn)
|
||||||
backup_conn.close()
|
backup_conn.close()
|
||||||
return backup_path
|
|
||||||
|
|
||||||
def _restore_db(self, backup_path: Path) -> None:
|
# Sanity check!
|
||||||
|
if not backup_path.is_file():
|
||||||
|
raise MigrationError("Unable to back up database")
|
||||||
|
self.backup_path = backup_path
|
||||||
|
|
||||||
|
def _restore_db(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
"""Restores the database from a backup file, unless the database is a memory database."""
|
"""Restores the database from a backup file, unless the database is a memory database."""
|
||||||
if self._database == sqlite_memory:
|
# We don't need to restore a memory database.
|
||||||
|
if self._db.is_memory:
|
||||||
return
|
return
|
||||||
with self._lock:
|
|
||||||
self._logger.info(f"Restoring database from {backup_path}")
|
with self._db.lock:
|
||||||
self._conn.close()
|
self._logger.info(f"Restoring database from {self.backup_path}")
|
||||||
if not Path(backup_path).is_file():
|
self._db.conn.close()
|
||||||
raise FileNotFoundError(f"Backup file {backup_path} does not exist")
|
if self.backup_path is None:
|
||||||
shutil.copy2(backup_path, self._database)
|
raise FileNotFoundError("No backup path set")
|
||||||
|
if not Path(self.backup_path).is_file():
|
||||||
|
raise FileNotFoundError(f"Backup file {self.backup_path} does not exist")
|
||||||
|
shutil.copy2(self.backup_path, self._db.database)
|
||||||
|
@ -1,84 +0,0 @@
|
|||||||
import sqlite3
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_image_workflows(output_path: Path, database: Path, page_size=100):
|
|
||||||
"""
|
|
||||||
In the v3.5.0 release, InvokeAI changed how it handles image workflows. The `images` table in
|
|
||||||
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
|
|
||||||
|
|
||||||
This script checks each image for the presence of an embedded workflow, then updates its entry
|
|
||||||
in the database accordingly.
|
|
||||||
|
|
||||||
1) Check if the database is updated to support image workflows. Aborts if it doesn't have the
|
|
||||||
`has_workflow` column yet.
|
|
||||||
2) Backs up the database.
|
|
||||||
3) Opens each image in the `images` table via PIL
|
|
||||||
4) Checks if the `"invokeai_workflow"` attribute its in the image's embedded metadata, indicating
|
|
||||||
that it has a workflow.
|
|
||||||
5) If it does, updates the `has_workflow` column for that image to `TRUE`.
|
|
||||||
|
|
||||||
If there are any problems, the script immediately aborts. Because the processing happens in chunks,
|
|
||||||
if there is a problem, it is suggested that you restore the database from the backup and try again.
|
|
||||||
"""
|
|
||||||
output_path = output_path
|
|
||||||
database = database
|
|
||||||
conn = sqlite3.connect(database)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
# We can only migrate if the `images` table has the `has_workflow` column
|
|
||||||
cursor.execute("PRAGMA table_info(images)")
|
|
||||||
columns = [column[1] for column in cursor.fetchall()]
|
|
||||||
if "has_workflow" not in columns:
|
|
||||||
raise Exception("Database needs to be updated to support image workflows")
|
|
||||||
|
|
||||||
# Back up the database before we start
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
backup_path = database.parent / f"{database.stem}_migrate-image-workflows_{timestamp}.db"
|
|
||||||
print(f"Backing up database to {backup_path}")
|
|
||||||
backup_conn = sqlite3.connect(backup_path)
|
|
||||||
with backup_conn:
|
|
||||||
conn.backup(backup_conn)
|
|
||||||
backup_conn.close()
|
|
||||||
|
|
||||||
# Get the total number of images and chunk it into pages
|
|
||||||
cursor.execute("SELECT COUNT(*) FROM images")
|
|
||||||
total_images = cursor.fetchone()[0]
|
|
||||||
total_pages = (total_images + page_size - 1) // page_size
|
|
||||||
print(f"Processing {total_images} images in chunks of {page_size} images...")
|
|
||||||
|
|
||||||
# Migrate the images
|
|
||||||
migrated_count = 0
|
|
||||||
pbar = tqdm(range(total_pages))
|
|
||||||
for page in pbar:
|
|
||||||
pbar.set_description(f"Migrating page {page + 1}/{total_pages}")
|
|
||||||
offset = page * page_size
|
|
||||||
cursor.execute("SELECT image_name FROM images LIMIT ? OFFSET ?", (page_size, offset))
|
|
||||||
images = cursor.fetchall()
|
|
||||||
for image_name in images:
|
|
||||||
image_path = output_path / "images" / image_name[0]
|
|
||||||
with Image.open(image_path) as img:
|
|
||||||
if "invokeai_workflow" in img.info:
|
|
||||||
cursor.execute("UPDATE images SET has_workflow = TRUE WHERE image_name = ?", (image_name[0],))
|
|
||||||
migrated_count += 1
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
print(f"Migrated workflows for {migrated_count} images.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
output_path = config.output_path
|
|
||||||
database = config.db_path
|
|
||||||
|
|
||||||
assert output_path is not None
|
|
||||||
assert output_path.exists()
|
|
||||||
assert database.exists()
|
|
||||||
migrate_image_workflows(output_path=output_path, database=database, page_size=100)
|
|
Loading…
x
Reference in New Issue
Block a user