mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): decouple from SqliteDatabase
This commit is contained in:
parent
f8e4b93a74
commit
83e820d721
@ -73,7 +73,7 @@ class ApiDependencies:
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
db = SqliteDatabase(config, logger)
|
||||
migrator = SQLiteMigrator(db=db, image_files=image_files)
|
||||
migrator = SQLiteMigrator(database=db.database, lock=db.lock, image_files=image_files, logger=logger)
|
||||
migrator.register_migration(migration_1)
|
||||
migrator.register_migration(migration_2)
|
||||
migrator.run_migrations()
|
||||
|
@ -1,14 +1,13 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
|
||||
|
||||
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
|
||||
def _migrate(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
|
||||
"""Migration callback for database version 1."""
|
||||
|
||||
cursor = db.conn.cursor()
|
||||
_create_board_images(cursor)
|
||||
_create_boards(cursor)
|
||||
_create_images(cursor)
|
||||
|
@ -4,20 +4,18 @@ from logging import Logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
|
||||
|
||||
|
||||
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
|
||||
def _migrate(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
|
||||
"""Migration callback for database version 2."""
|
||||
|
||||
cursor = db.conn.cursor()
|
||||
_add_images_has_workflow(cursor)
|
||||
_add_session_queue_workflow(cursor)
|
||||
_drop_old_workflow_tables(cursor)
|
||||
_add_workflow_library(cursor)
|
||||
_drop_model_manager_metadata(cursor)
|
||||
_migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=db._logger)
|
||||
_migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=logger)
|
||||
|
||||
|
||||
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:
|
||||
|
@ -1,15 +1,17 @@
|
||||
import shutil
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||
|
||||
MigrateCallback: TypeAlias = Callable[[SqliteDatabase, ImageFileStorageBase], None]
|
||||
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor, ImageFileStorageBase, Logger], None]
|
||||
|
||||
|
||||
class MigrationError(Exception):
|
||||
@ -95,17 +97,25 @@ class SQLiteMigrator:
|
||||
|
||||
backup_path: Optional[Path] = None
|
||||
|
||||
def __init__(self, db: SqliteDatabase, image_files: ImageFileStorageBase) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
database: Path | str,
|
||||
lock: threading.RLock,
|
||||
logger: Logger,
|
||||
image_files: ImageFileStorageBase,
|
||||
) -> None:
|
||||
self._lock = lock
|
||||
self._database = database
|
||||
self._is_memory = database == sqlite_memory
|
||||
self._image_files = image_files
|
||||
self._db = db
|
||||
self._logger = self._db._logger
|
||||
self._config = self._db._config
|
||||
self._cursor = self._db.conn.cursor()
|
||||
self._logger = logger
|
||||
self._conn = sqlite3.connect(database)
|
||||
self._cursor = self._conn.cursor()
|
||||
self._migrations = MigrationSet()
|
||||
|
||||
# Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure.
|
||||
self._migration_lock_file_path = (
|
||||
self._db.database.parent / ".migration_in_progress" if isinstance(self._db.database, Path) else None
|
||||
self._database.parent / ".migration_in_progress" if isinstance(self._database, Path) else None
|
||||
)
|
||||
|
||||
if self._unlink_lock_file():
|
||||
@ -121,7 +131,7 @@ class SQLiteMigrator:
|
||||
|
||||
def run_migrations(self) -> None:
|
||||
"""Migrates the database to the latest version."""
|
||||
with self._db.lock:
|
||||
with self._lock:
|
||||
self._create_version_table()
|
||||
current_version = self._get_current_version()
|
||||
|
||||
@ -151,7 +161,7 @@ class SQLiteMigrator:
|
||||
|
||||
def _run_migration(self, migration: Migration) -> None:
|
||||
"""Runs a single migration."""
|
||||
with self._db.lock:
|
||||
with self._lock:
|
||||
try:
|
||||
if self._get_current_version() != migration.from_version:
|
||||
raise MigrationError(
|
||||
@ -161,27 +171,27 @@ class SQLiteMigrator:
|
||||
if migration.pre_migrate:
|
||||
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
||||
for callback in migration.pre_migrate:
|
||||
callback(self._db, self._image_files)
|
||||
migration.migrate(self._db, self._image_files)
|
||||
callback(self._cursor, self._image_files, self._logger)
|
||||
migration.migrate(self._cursor, self._image_files, self._logger)
|
||||
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||
if migration.post_migrate:
|
||||
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
||||
for callback in migration.post_migrate:
|
||||
callback(self._db, self._image_files)
|
||||
callback(self._cursor, self._image_files, self._logger)
|
||||
# Migration callbacks only get a cursor; they cannot commit the transaction.
|
||||
self._db.conn.commit()
|
||||
self._conn.commit()
|
||||
self._logger.debug(
|
||||
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
||||
)
|
||||
except Exception as e:
|
||||
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
||||
self._db.conn.rollback()
|
||||
self._conn.rollback()
|
||||
self._logger.error(msg)
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
def _create_version_table(self) -> None:
|
||||
"""Creates a version table for the database, if one does not already exist."""
|
||||
with self._db.lock:
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if self._cursor.fetchone() is not None:
|
||||
@ -195,17 +205,17 @@ class SQLiteMigrator:
|
||||
"""
|
||||
)
|
||||
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||
self._db.conn.commit()
|
||||
self._conn.commit()
|
||||
self._logger.debug("Created migrations table")
|
||||
except sqlite3.Error as e:
|
||||
msg = f"Problem creating migrations table: {e}"
|
||||
self._logger.error(msg)
|
||||
self._db.conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
def _get_current_version(self) -> int:
|
||||
"""Gets the current version of the database, or 0 if the version table does not exist."""
|
||||
with self._db.lock:
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||
version = self._cursor.fetchone()[0]
|
||||
@ -219,19 +229,19 @@ class SQLiteMigrator:
|
||||
|
||||
def _backup_db(self) -> None:
|
||||
"""Backs up the databse, returning the path to the backup file."""
|
||||
if self._db.is_memory:
|
||||
if self._is_memory:
|
||||
self._logger.debug("Using memory database, skipping backup")
|
||||
# Sanity check!
|
||||
assert isinstance(self._db.database, Path)
|
||||
with self._db.lock:
|
||||
assert isinstance(self._database, Path)
|
||||
with self._lock:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = self._db.database.parent / f"{self._db.database.stem}_{timestamp}.db"
|
||||
backup_path = self._database.parent / f"{self._database.stem}_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {backup_path}")
|
||||
|
||||
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
|
||||
backup_conn = sqlite3.connect(backup_path)
|
||||
with backup_conn:
|
||||
self._db.conn.backup(backup_conn)
|
||||
self._conn.backup(backup_conn)
|
||||
backup_conn.close()
|
||||
|
||||
# Sanity check!
|
||||
@ -243,18 +253,18 @@ class SQLiteMigrator:
|
||||
self,
|
||||
) -> None:
|
||||
"""Restores the database from a backup file, unless the database is a memory database."""
|
||||
if self._db.is_memory:
|
||||
if self._is_memory:
|
||||
return
|
||||
|
||||
with self._db.lock:
|
||||
with self._lock:
|
||||
self._logger.info(f"Restoring database from {self.backup_path}")
|
||||
self._db.conn.close()
|
||||
self._conn.close()
|
||||
assert isinstance(self.backup_path, Path)
|
||||
shutil.copy2(self.backup_path, self._db.database)
|
||||
shutil.copy2(self.backup_path, self._database)
|
||||
|
||||
def _unlink_lock_file(self) -> bool:
|
||||
"""Unlinks the migration lock file, returning True if it existed."""
|
||||
if self._db.is_memory or self._migration_lock_file_path is None:
|
||||
if self._is_memory or self._migration_lock_file_path is None:
|
||||
return False
|
||||
if self._migration_lock_file_path.is_file():
|
||||
self._migration_lock_file_path.unlink()
|
||||
@ -263,8 +273,8 @@ class SQLiteMigrator:
|
||||
|
||||
def _write_migration_lock_file(self) -> None:
|
||||
"""Writes a file to indicate that a migration is in progress."""
|
||||
if self._db.is_memory or self._migration_lock_file_path is None:
|
||||
if self._is_memory or self._migration_lock_file_path is None:
|
||||
return
|
||||
assert isinstance(self._db.database, Path)
|
||||
assert isinstance(self._database, Path)
|
||||
with open(self._migration_lock_file_path, "w") as f:
|
||||
f.write("1")
|
||||
|
Loading…
Reference in New Issue
Block a user