From 83e820d721f88f8909ae82bb4270b0296fe3e351 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 10 Dec 2023 22:17:03 +1100 Subject: [PATCH] feat(db): decouple from SqliteDatabase --- invokeai/app/api/dependencies.py | 2 +- .../shared/sqlite/migrations/migration_1.py | 5 +- .../shared/sqlite/migrations/migration_2.py | 6 +- .../services/shared/sqlite/sqlite_migrator.py | 72 +++++++++++-------- 4 files changed, 46 insertions(+), 39 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 6b3e2d6226..6a6b37378f 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -73,7 +73,7 @@ class ApiDependencies: image_files = DiskImageFileStorage(f"{output_folder}/images") db = SqliteDatabase(config, logger) - migrator = SQLiteMigrator(db=db, image_files=image_files) + migrator = SQLiteMigrator(database=db.database, lock=db.lock, image_files=image_files, logger=logger) migrator.register_migration(migration_1) migrator.register_migration(migration_2) migrator.run_migrations() diff --git a/invokeai/app/services/shared/sqlite/migrations/migration_1.py b/invokeai/app/services/shared/sqlite/migrations/migration_1.py index f9be92badc..0cfe53d651 100644 --- a/invokeai/app/services/shared/sqlite/migrations/migration_1.py +++ b/invokeai/app/services/shared/sqlite/migrations/migration_1.py @@ -1,14 +1,13 @@ import sqlite3 +from logging import Logger from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration -def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None: +def _migrate(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None: """Migration callback for database version 1.""" - cursor = db.conn.cursor() _create_board_images(cursor) _create_boards(cursor) _create_images(cursor) diff --git a/invokeai/app/services/shared/sqlite/migrations/migration_2.py b/invokeai/app/services/shared/sqlite/migrations/migration_2.py index b6a697844b..0d3c10a629 100644 --- a/invokeai/app/services/shared/sqlite/migrations/migration_2.py +++ b/invokeai/app/services/shared/sqlite/migrations/migration_2.py @@ -4,20 +4,18 @@ from logging import Logger from tqdm import tqdm from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration -def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None: +def _migrate(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None: """Migration callback for database version 2.""" - cursor = db.conn.cursor() _add_images_has_workflow(cursor) _add_session_queue_workflow(cursor) _drop_old_workflow_tables(cursor) _add_workflow_library(cursor) _drop_model_manager_metadata(cursor) - _migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=db._logger) + _migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=logger) def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None: diff --git a/invokeai/app/services/shared/sqlite/sqlite_migrator.py b/invokeai/app/services/shared/sqlite/sqlite_migrator.py index 9cbdf028d5..e9895c7de6 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_migrator.py +++ b/invokeai/app/services/shared/sqlite/sqlite_migrator.py @@ -1,15 +1,17 @@ import shutil import sqlite3 +import threading from datetime import datetime +from logging import Logger from pathlib import Path from typing import Callable, Optional, TypeAlias from pydantic import BaseModel, Field, model_validator from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory -MigrateCallback: TypeAlias = Callable[[SqliteDatabase, ImageFileStorageBase], None] +MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor, ImageFileStorageBase, Logger], None] class MigrationError(Exception): @@ -95,17 +97,25 @@ class SQLiteMigrator: backup_path: Optional[Path] = None - def __init__(self, db: SqliteDatabase, image_files: ImageFileStorageBase) -> None: + def __init__( + self, + database: Path | str, + lock: threading.RLock, + logger: Logger, + image_files: ImageFileStorageBase, + ) -> None: + self._lock = lock + self._database = database + self._is_memory = database == sqlite_memory self._image_files = image_files - self._db = db - self._logger = self._db._logger - self._config = self._db._config - self._cursor = self._db.conn.cursor() + self._logger = logger + self._conn = sqlite3.connect(database) + self._cursor = self._conn.cursor() self._migrations = MigrationSet() # Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure. self._migration_lock_file_path = ( - self._db.database.parent / ".migration_in_progress" if isinstance(self._db.database, Path) else None + self._database.parent / ".migration_in_progress" if isinstance(self._database, Path) else None ) if self._unlink_lock_file(): @@ -121,7 +131,7 @@ class SQLiteMigrator: def run_migrations(self) -> None: """Migrates the database to the latest version.""" - with self._db.lock: + with self._lock: self._create_version_table() current_version = self._get_current_version() @@ -151,7 +161,7 @@ class SQLiteMigrator: def _run_migration(self, migration: Migration) -> None: """Runs a single migration.""" - with self._db.lock: + with self._lock: try: if self._get_current_version() != migration.from_version: raise MigrationError( @@ -161,27 +171,27 @@ class SQLiteMigrator: if migration.pre_migrate: self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks") for callback in migration.pre_migrate: - callback(self._db, self._image_files) - migration.migrate(self._db, self._image_files) + callback(self._cursor, self._image_files, self._logger) + migration.migrate(self._cursor, self._image_files, self._logger) self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,)) if migration.post_migrate: self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks") for callback in migration.post_migrate: - callback(self._db, self._image_files) + callback(self._cursor, self._image_files, self._logger) # Migration callbacks only get a cursor; they cannot commit the transaction. - self._db.conn.commit() + self._conn.commit() self._logger.debug( f"Successfully migrated database from {migration.from_version} to {migration.to_version}" ) except Exception as e: msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}" - self._db.conn.rollback() + self._conn.rollback() self._logger.error(msg) raise MigrationError(msg) from e def _create_version_table(self) -> None: """Creates a version table for the database, if one does not already exist.""" - with self._db.lock: + with self._lock: try: self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';") if self._cursor.fetchone() is not None: @@ -195,17 +205,17 @@ class SQLiteMigrator: """ ) self._cursor.execute("INSERT INTO migrations (version) VALUES (0);") - self._db.conn.commit() + self._conn.commit() self._logger.debug("Created migrations table") except sqlite3.Error as e: msg = f"Problem creating migrations table: {e}" self._logger.error(msg) - self._db.conn.rollback() + self._conn.rollback() raise MigrationError(msg) from e def _get_current_version(self) -> int: """Gets the current version of the database, or 0 if the version table does not exist.""" - with self._db.lock: + with self._lock: try: self._cursor.execute("SELECT MAX(version) FROM migrations;") version = self._cursor.fetchone()[0] @@ -219,19 +229,19 @@ class SQLiteMigrator: def _backup_db(self) -> None: """Backs up the databse, returning the path to the backup file.""" - if self._db.is_memory: + if self._is_memory: self._logger.debug("Using memory database, skipping backup") # Sanity check! - assert isinstance(self._db.database, Path) - with self._db.lock: + assert isinstance(self._database, Path) + with self._lock: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - backup_path = self._db.database.parent / f"{self._db.database.stem}_{timestamp}.db" + backup_path = self._database.parent / f"{self._database.stem}_{timestamp}.db" self._logger.info(f"Backing up database to {backup_path}") # Use SQLite's built in backup capabilities so we don't need to worry about locking and such. backup_conn = sqlite3.connect(backup_path) with backup_conn: - self._db.conn.backup(backup_conn) + self._conn.backup(backup_conn) backup_conn.close() # Sanity check! @@ -243,18 +253,18 @@ class SQLiteMigrator: self, ) -> None: """Restores the database from a backup file, unless the database is a memory database.""" - if self._db.is_memory: + if self._is_memory: return - with self._db.lock: + with self._lock: self._logger.info(f"Restoring database from {self.backup_path}") - self._db.conn.close() + self._conn.close() assert isinstance(self.backup_path, Path) - shutil.copy2(self.backup_path, self._db.database) + shutil.copy2(self.backup_path, self._database) def _unlink_lock_file(self) -> bool: """Unlinks the migration lock file, returning True if it existed.""" - if self._db.is_memory or self._migration_lock_file_path is None: + if self._is_memory or self._migration_lock_file_path is None: return False if self._migration_lock_file_path.is_file(): self._migration_lock_file_path.unlink() @@ -263,8 +273,8 @@ class SQLiteMigrator: def _write_migration_lock_file(self) -> None: """Writes a file to indicate that a migration is in progress.""" - if self._db.is_memory or self._migration_lock_file_path is None: + if self._is_memory or self._migration_lock_file_path is None: return - assert isinstance(self._db.database, Path) + assert isinstance(self._database, Path) with open(self._migration_lock_file_path, "w") as f: f.write("1")