feat(db): decouple from SqliteDatabase

This commit is contained in:
psychedelicious 2023-12-10 22:17:03 +11:00
parent f8e4b93a74
commit 83e820d721
4 changed files with 46 additions and 39 deletions

View File

@ -73,7 +73,7 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images") image_files = DiskImageFileStorage(f"{output_folder}/images")
db = SqliteDatabase(config, logger) db = SqliteDatabase(config, logger)
migrator = SQLiteMigrator(db=db, image_files=image_files) migrator = SQLiteMigrator(database=db.database, lock=db.lock, image_files=image_files, logger=logger)
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()

View File

@ -1,14 +1,13 @@
import sqlite3 import sqlite3
from logging import Logger
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None: def _migrate(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
"""Migration callback for database version 1.""" """Migration callback for database version 1."""
cursor = db.conn.cursor()
_create_board_images(cursor) _create_board_images(cursor)
_create_boards(cursor) _create_boards(cursor)
_create_images(cursor) _create_images(cursor)

View File

@ -4,20 +4,18 @@ from logging import Logger
from tqdm import tqdm from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(db: SqliteDatabase, image_files: ImageFileStorageBase) -> None: def _migrate(cursor: sqlite3.Cursor, image_files: ImageFileStorageBase, logger: Logger) -> None:
"""Migration callback for database version 2.""" """Migration callback for database version 2."""
cursor = db.conn.cursor()
_add_images_has_workflow(cursor) _add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor) _add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor) _drop_old_workflow_tables(cursor)
_add_workflow_library(cursor) _add_workflow_library(cursor)
_drop_model_manager_metadata(cursor) _drop_model_manager_metadata(cursor)
_migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=db._logger) _migrate_embedded_workflows(cursor=cursor, image_files=image_files, logger=logger)
def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None: def _add_images_has_workflow(cursor: sqlite3.Cursor) -> None:

View File

@ -1,15 +1,17 @@
import shutil import shutil
import sqlite3 import sqlite3
import threading
from datetime import datetime from datetime import datetime
from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, TypeAlias from typing import Callable, Optional, TypeAlias
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
MigrateCallback: TypeAlias = Callable[[SqliteDatabase, ImageFileStorageBase], None] MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor, ImageFileStorageBase, Logger], None]
class MigrationError(Exception): class MigrationError(Exception):
@ -95,17 +97,25 @@ class SQLiteMigrator:
backup_path: Optional[Path] = None backup_path: Optional[Path] = None
def __init__(self, db: SqliteDatabase, image_files: ImageFileStorageBase) -> None: def __init__(
self,
database: Path | str,
lock: threading.RLock,
logger: Logger,
image_files: ImageFileStorageBase,
) -> None:
self._lock = lock
self._database = database
self._is_memory = database == sqlite_memory
self._image_files = image_files self._image_files = image_files
self._db = db self._logger = logger
self._logger = self._db._logger self._conn = sqlite3.connect(database)
self._config = self._db._config self._cursor = self._conn.cursor()
self._cursor = self._db.conn.cursor()
self._migrations = MigrationSet() self._migrations = MigrationSet()
# Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure. # Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure.
self._migration_lock_file_path = ( self._migration_lock_file_path = (
self._db.database.parent / ".migration_in_progress" if isinstance(self._db.database, Path) else None self._database.parent / ".migration_in_progress" if isinstance(self._database, Path) else None
) )
if self._unlink_lock_file(): if self._unlink_lock_file():
@ -121,7 +131,7 @@ class SQLiteMigrator:
def run_migrations(self) -> None: def run_migrations(self) -> None:
"""Migrates the database to the latest version.""" """Migrates the database to the latest version."""
with self._db.lock: with self._lock:
self._create_version_table() self._create_version_table()
current_version = self._get_current_version() current_version = self._get_current_version()
@ -151,7 +161,7 @@ class SQLiteMigrator:
def _run_migration(self, migration: Migration) -> None: def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration.""" """Runs a single migration."""
with self._db.lock: with self._lock:
try: try:
if self._get_current_version() != migration.from_version: if self._get_current_version() != migration.from_version:
raise MigrationError( raise MigrationError(
@ -161,27 +171,27 @@ class SQLiteMigrator:
if migration.pre_migrate: if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks") self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate: for callback in migration.pre_migrate:
callback(self._db, self._image_files) callback(self._cursor, self._image_files, self._logger)
migration.migrate(self._db, self._image_files) migration.migrate(self._cursor, self._image_files, self._logger)
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,)) self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
if migration.post_migrate: if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks") self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate: for callback in migration.post_migrate:
callback(self._db, self._image_files) callback(self._cursor, self._image_files, self._logger)
# Migration callbacks only get a cursor; they cannot commit the transaction. # Migration callbacks only get a cursor; they cannot commit the transaction.
self._db.conn.commit() self._conn.commit()
self._logger.debug( self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}" f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
) )
except Exception as e: except Exception as e:
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}" msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
self._db.conn.rollback() self._conn.rollback()
self._logger.error(msg) self._logger.error(msg)
raise MigrationError(msg) from e raise MigrationError(msg) from e
def _create_version_table(self) -> None: def _create_version_table(self) -> None:
"""Creates a version table for the database, if one does not already exist.""" """Creates a version table for the database, if one does not already exist."""
with self._db.lock: with self._lock:
try: try:
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';") self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if self._cursor.fetchone() is not None: if self._cursor.fetchone() is not None:
@ -195,17 +205,17 @@ class SQLiteMigrator:
""" """
) )
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);") self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
self._db.conn.commit() self._conn.commit()
self._logger.debug("Created migrations table") self._logger.debug("Created migrations table")
except sqlite3.Error as e: except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}" msg = f"Problem creating migrations table: {e}"
self._logger.error(msg) self._logger.error(msg)
self._db.conn.rollback() self._conn.rollback()
raise MigrationError(msg) from e raise MigrationError(msg) from e
def _get_current_version(self) -> int: def _get_current_version(self) -> int:
"""Gets the current version of the database, or 0 if the version table does not exist.""" """Gets the current version of the database, or 0 if the version table does not exist."""
with self._db.lock: with self._lock:
try: try:
self._cursor.execute("SELECT MAX(version) FROM migrations;") self._cursor.execute("SELECT MAX(version) FROM migrations;")
version = self._cursor.fetchone()[0] version = self._cursor.fetchone()[0]
@ -219,19 +229,19 @@ class SQLiteMigrator:
def _backup_db(self) -> None: def _backup_db(self) -> None:
"""Backs up the databse, returning the path to the backup file.""" """Backs up the databse, returning the path to the backup file."""
if self._db.is_memory: if self._is_memory:
self._logger.debug("Using memory database, skipping backup") self._logger.debug("Using memory database, skipping backup")
# Sanity check! # Sanity check!
assert isinstance(self._db.database, Path) assert isinstance(self._database, Path)
with self._db.lock: with self._lock:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = self._db.database.parent / f"{self._db.database.stem}_{timestamp}.db" backup_path = self._database.parent / f"{self._database.stem}_{timestamp}.db"
self._logger.info(f"Backing up database to {backup_path}") self._logger.info(f"Backing up database to {backup_path}")
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such. # Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
backup_conn = sqlite3.connect(backup_path) backup_conn = sqlite3.connect(backup_path)
with backup_conn: with backup_conn:
self._db.conn.backup(backup_conn) self._conn.backup(backup_conn)
backup_conn.close() backup_conn.close()
# Sanity check! # Sanity check!
@ -243,18 +253,18 @@ class SQLiteMigrator:
self, self,
) -> None: ) -> None:
"""Restores the database from a backup file, unless the database is a memory database.""" """Restores the database from a backup file, unless the database is a memory database."""
if self._db.is_memory: if self._is_memory:
return return
with self._db.lock: with self._lock:
self._logger.info(f"Restoring database from {self.backup_path}") self._logger.info(f"Restoring database from {self.backup_path}")
self._db.conn.close() self._conn.close()
assert isinstance(self.backup_path, Path) assert isinstance(self.backup_path, Path)
shutil.copy2(self.backup_path, self._db.database) shutil.copy2(self.backup_path, self._database)
def _unlink_lock_file(self) -> bool: def _unlink_lock_file(self) -> bool:
"""Unlinks the migration lock file, returning True if it existed.""" """Unlinks the migration lock file, returning True if it existed."""
if self._db.is_memory or self._migration_lock_file_path is None: if self._is_memory or self._migration_lock_file_path is None:
return False return False
if self._migration_lock_file_path.is_file(): if self._migration_lock_file_path.is_file():
self._migration_lock_file_path.unlink() self._migration_lock_file_path.unlink()
@ -263,8 +273,8 @@ class SQLiteMigrator:
def _write_migration_lock_file(self) -> None: def _write_migration_lock_file(self) -> None:
"""Writes a file to indicate that a migration is in progress.""" """Writes a file to indicate that a migration is in progress."""
if self._db.is_memory or self._migration_lock_file_path is None: if self._is_memory or self._migration_lock_file_path is None:
return return
assert isinstance(self._db.database, Path) assert isinstance(self._database, Path)
with open(self._migration_lock_file_path, "w") as f: with open(self._migration_lock_file_path, "w") as f:
f.write("1") f.write("1")