mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): reduce power MigrateCallback, only gets cursor
use partial to provide extra dependencies for the image workflow migration function
This commit is contained in:
parent
83e820d721
commit
abeb1bd3b3
@ -1,9 +1,11 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from functools import partial
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
|
||||
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
|
||||
from invokeai.app.services.shared.sqlite.migrations.migration_2_post import migrate_embedded_workflows
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import SQLiteMigrator
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
@ -73,7 +75,8 @@ class ApiDependencies:
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
db = SqliteDatabase(config, logger)
|
||||
migrator = SQLiteMigrator(database=db.database, lock=db.lock, image_files=image_files, logger=logger)
|
||||
migrator = SQLiteMigrator(database=db.database, lock=db.lock, logger=logger)
|
||||
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
|
@ -1,11 +1,9 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
|
||||
|
||||
def _migrate(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||
"""Migration callback for database version 1."""
|
||||
|
||||
_create_board_images(cursor)
|
||||
|
@ -1,13 +1,9 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
|
||||
|
||||
def _migrate(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
|
||||
def _migrate(cursor: sqlite3.Cursor) -> None:
|
||||
"""Migration callback for database version 2."""
|
||||
|
||||
_add_images_has_workflow(cursor)
|
||||
@ -15,7 +11,6 @@ def _migrate(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger:
|
||||
_drop_old_workflow_tables(cursor)
|
||||
_add_workflow_library(cursor)
|
||||
_drop_model_manager_metadata(cursor)
|
||||
_migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=logger)
|
||||
|
||||
|
||||
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
|
||||
@ -94,37 +89,6 @@ def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||
|
||||
|
||||
def _migrate_embedded_workflows(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
|
||||
|
||||
This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry
|
||||
in the database accordingly.
|
||||
"""
|
||||
# Get the total number of images and chunk it into pages
|
||||
cursor.execute("SELECT image_name FROM images")
|
||||
image_names: list[str] = [image[0] for image in cursor.fetchall()]
|
||||
total_image_names = len(image_names)
|
||||
|
||||
if not total_image_names:
|
||||
return
|
||||
|
||||
logger.info(f"Migrating workflows for {total_image_names} images")
|
||||
|
||||
# Migrate the images
|
||||
to_migrate: list[tuple[bool, str]] = []
|
||||
pbar = tqdm(image_names)
|
||||
for idx, image_name in enumerate(pbar):
|
||||
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
|
||||
pil_image = image_files.get(image_name)
|
||||
if "invokeai_workflow" in pil_image.info:
|
||||
to_migrate.append((True, image_name))
|
||||
|
||||
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
|
||||
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
|
||||
|
||||
|
||||
migration_2 = Migration(
|
||||
from_version=1,
|
||||
to_version=2,
|
||||
@ -140,5 +104,4 @@ Migration:
|
||||
- Add `workflow` column to `session_queue` table
|
||||
- Drop `workflows` and `workflow_images` tables
|
||||
- Add `workflow_library` table
|
||||
- Updates `has_workflow` for all images
|
||||
"""
|
||||
|
@ -0,0 +1,41 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
|
||||
|
||||
def migrate_embedded_workflows(
|
||||
cursor: sqlite3.Cursor,
|
||||
logger: Logger,
|
||||
image_files: ImageFileStorageBase,
|
||||
) -> None:
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
|
||||
|
||||
This migrate callbakc checks each image for the presence of an embedded workflow, then updates its entry
|
||||
in the database accordingly.
|
||||
"""
|
||||
# Get the total number of images and chunk it into pages
|
||||
cursor.execute("SELECT image_name FROM images")
|
||||
image_names: list[str] = [image[0] for image in cursor.fetchall()]
|
||||
total_image_names = len(image_names)
|
||||
|
||||
if not total_image_names:
|
||||
return
|
||||
|
||||
logger.info(f"Migrating workflows for {total_image_names} images")
|
||||
|
||||
# Migrate the images
|
||||
to_migrate: list[tuple[bool, str]] = []
|
||||
pbar = tqdm(image_names)
|
||||
for idx, image_name in enumerate(pbar):
|
||||
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
|
||||
pil_image = image_files.get(image_name)
|
||||
if "invokeai_workflow" in pil_image.info:
|
||||
to_migrate.append((True, image_name))
|
||||
|
||||
logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
|
||||
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
|
@ -8,10 +8,9 @@ from typing import Callable, Optional, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||
|
||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor, ImageFileStorageBase, Logger], None]
|
||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
||||
|
||||
|
||||
class MigrationError(Exception):
|
||||
@ -50,6 +49,14 @@ class Migration(BaseModel):
|
||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||
return hash((self.from_version, self.to_version))
|
||||
|
||||
def register_pre_callback(self, callback: MigrateCallback) -> None:
|
||||
"""Registers a pre-migration callback."""
|
||||
self.pre_migrate.append(callback)
|
||||
|
||||
def register_post_callback(self, callback: MigrateCallback) -> None:
|
||||
"""Registers a post-migration callback."""
|
||||
self.post_migrate.append(callback)
|
||||
|
||||
|
||||
class MigrationSet:
|
||||
"""A set of Migrations. Performs validation during migration registration and provides utility methods."""
|
||||
@ -102,12 +109,10 @@ class SQLiteMigrator:
|
||||
database: Path | str,
|
||||
lock: threading.RLock,
|
||||
logger: Logger,
|
||||
image_files: ImageFileStorageBase,
|
||||
) -> None:
|
||||
self._lock = lock
|
||||
self._database = database
|
||||
self._is_memory = database == sqlite_memory
|
||||
self._image_files = image_files
|
||||
self._logger = logger
|
||||
self._conn = sqlite3.connect(database)
|
||||
self._cursor = self._conn.cursor()
|
||||
@ -168,17 +173,24 @@ class SQLiteMigrator:
|
||||
f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
|
||||
)
|
||||
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
||||
|
||||
# Run pre-migration callbacks
|
||||
if migration.pre_migrate:
|
||||
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
||||
for callback in migration.pre_migrate:
|
||||
callback(self._cursor, self._image_files, self._logger)
|
||||
migration.migrate(self._cursor, self._image_files, self._logger)
|
||||
callback(self._cursor)
|
||||
|
||||
# Run the actual migration
|
||||
migration.migrate(self._cursor)
|
||||
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||
|
||||
# Run post-migration callbacks
|
||||
if migration.post_migrate:
|
||||
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
||||
for callback in migration.post_migrate:
|
||||
callback(self._cursor, self._image_files, self._logger)
|
||||
# Migration callbacks only get a cursor; they cannot commit the transaction.
|
||||
callback(self._cursor)
|
||||
|
||||
# Migration callbacks only get a cursor. Commit this migration.
|
||||
self._conn.commit()
|
||||
self._logger.debug(
|
||||
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
||||
|
Loading…
Reference in New Issue
Block a user