feat(db): invert backup/restore logic

Do the migration on a temp copy of the db, then back up the original and move the temp into its file.
This commit is contained in:
psychedelicious 2023-12-11 00:47:53 +11:00
parent abeb1bd3b3
commit e461f9925e
6 changed files with 167 additions and 126 deletions

View File

@ -2,6 +2,7 @@
from functools import partial
from logging import Logger
from pathlib import Path
from invokeai.app.services.shared.sqlite.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite.migrations.migration_2 import migration_2
@ -75,12 +76,22 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
db = SqliteDatabase(config, logger)
migrator = SQLiteMigrator(database=db.database, lock=db.lock, logger=logger)
migrator = SQLiteMigrator(
db_path=db.database if isinstance(db.database, Path) else None,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=config.log_sql,
)
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
if not db.is_memory:
db.reinitialize()
configuration = config
logger = logger

View File

@ -6,6 +6,8 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 1."""
print("migration 1!!!")
_create_board_images(cursor)
_create_boards(cursor)
_create_images(cursor)

View File

@ -6,6 +6,8 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration
def _migrate(cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 2."""
print("migration 2!!!")
_add_images_has_workflow(cursor)
_add_session_queue_workflow(cursor)
_drop_old_workflow_tables(cursor)

View File

@ -10,18 +10,21 @@ from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
class SqliteDatabase:
database: Path | str # Must declare this here to satisfy type checker
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
def __init__(self, config: InvokeAIAppConfig, logger: Logger) -> None:
self.initialize(config, logger)
def initialize(self, config: InvokeAIAppConfig, logger: Logger) -> None:
self._logger = logger
self._config = config
self.is_memory = False
if self._config.use_memory_db:
self.database = sqlite_memory
self.is_memory = True
logger.info("Using in-memory database")
logger.info("Initializing in-memory database")
else:
self.database = self._config.db_path
self.database.parent.mkdir(parents=True, exist_ok=True)
self._logger.info(f"Using database at {self.database}")
self._logger.info(f"Initializing database at {self.database}")
self.conn = sqlite3.connect(database=self.database, check_same_thread=False)
self.lock = threading.RLock()
@ -32,6 +35,13 @@ class SqliteDatabase:
self.conn.execute("PRAGMA foreign_keys = ON;")
def reinitialize(self) -> None:
"""Reinitializes the database. Needed after migration."""
self.initialize(self._config, self._logger)
def close(self) -> None:
self.conn.close()
def clean(self) -> None:
with self.lock:
try:

View File

@ -8,16 +8,14 @@ from typing import Callable, Optional, TypeAlias
from pydantic import BaseModel, Field, model_validator
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
class MigrationError(Exception):
class MigrationError(RuntimeError):
"""Raised when a migration fails."""
class MigrationVersionError(ValueError, MigrationError):
class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid."""
@ -25,8 +23,14 @@ class Migration(BaseModel):
"""
Represents a migration for a SQLite database.
Migration callbacks will be provided an instance of SqliteDatabase.
Migration callbacks should not commit; the migrator will commit the transaction.
Migration callbacks will be provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator.
Pre- and post-migration callbacks may be registered with :meth:`register_pre_callback` or
:meth:`register_post_callback`.
If a migration has additional dependencies, it is recommended to use functools.partial to provide
the dependencies and register the partial as the migration callback.
"""
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
@ -77,6 +81,28 @@ class MigrationSet:
# register() ensures that there is only one migration with a given from_version, so this is safe.
return next((m for m in self._migrations if m.from_version == from_version), None)
def validate_migration_path(self) -> None:
    """
    Checks that the registered migrations chain together into one unbroken path.

    The walk starts at version 0 and repeatedly follows the migration registered for the current
    version until the latest registered version is reached. Every registered migration must be used
    by that walk. Nothing is checked when no migrations are registered or the latest version is 0.

    Raises a MigrationError if a step is missing or if a registered migration is not on the path.
    """
    if self.count == 0 or self.latest_version == 0:
        return

    version = 0
    steps_taken = 0
    while version < self.latest_version:
        step = self.get(version)
        if step is None:
            raise MigrationError(f"Missing migration from {version}")
        version = step.to_version
        steps_taken += 1

    if version != self.latest_version:
        raise MigrationError(f"Missing migration to {self.latest_version}")
    # Any registered migration the walk never visited is off the path.
    if steps_taken != self.count:
        raise MigrationError("Migration path is not contiguous")
@property
def count(self) -> int:
"""The count of registered migrations."""
@ -90,87 +116,112 @@ class MigrationSet:
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
def get_temp_db_path(original_db_path: Path) -> Path:
    """
    Gets the path of the temporary database file used while migrating ``original_db_path``.

    The temp file sits next to the original and is named by appending ``.temp`` to the full file
    name, e.g. ``invokeai.db`` -> ``invokeai.db.temp``.

    The suffix is appended rather than substituted for ``.db`` inside the name. Substitution returned
    the original path unchanged for any file name that does not contain ``.db``, and rewrote every
    occurrence of ``.db`` when the name contained it more than once. Appending always yields a path
    that differs from the original.

    :param original_db_path: The path to the database file that is going to be migrated.
    :return: The sibling path to use for the temporary copy.
    """
    return original_db_path.with_name(f"{original_db_path.name}.temp")
class SQLiteMigrator:
"""
Manages migrations for a SQLite database.
:param db: The SqliteDatabase, representing the database on which to run migrations.
:param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files.
:param db_path: The path to the database to migrate, or None if using an in-memory database.
:param conn: The connection to the database.
:param lock: A lock to use when running migrations.
:param logger: A logger to use for logging.
:param log_sql: Whether to log SQL statements. Only used when the log level is set to debug.
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
order of their version number. If the database is already at the latest version, no migrations
will be run.
Migrations should be registered with :meth:`register_migration`.
During migration, a copy of the current database is made and the migrations are run on the copy. If the migration
is successful, the original database is backed up and the migrated database is moved to the original database's
path. If the migration fails, the original database is left untouched and the migrated database is deleted.
If the database is in-memory, no backup is made; the migration is run in-place.
"""
backup_path: Optional[Path] = None
def __init__(
self,
database: Path | str,
db_path: Path | None,
conn: sqlite3.Connection,
lock: threading.RLock,
logger: Logger,
log_sql: bool = False,
) -> None:
self._lock = lock
self._database = database
self._is_memory = database == sqlite_memory
self._db_path = db_path
self._logger = logger
self._conn = sqlite3.connect(database)
self._conn = conn
self._log_sql = log_sql
self._cursor = self._conn.cursor()
self._migrations = MigrationSet()
# Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure.
self._migration_lock_file_path = (
self._database.parent / ".migration_in_progress" if isinstance(self._database, Path) else None
)
if self._unlink_lock_file():
# The presence of a temp database file indicates a catastrophic failure of a previous migration.
if self._db_path and get_temp_db_path(self._db_path).is_file():
self._logger.warning("Previous migration failed! Trying again...")
get_temp_db_path(self._db_path).unlink()
def register_migration(self, migration: Migration) -> None:
"""
Registers a migration.
Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
"""
"""Registers a migration."""
self._migrations.register(migration)
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
def run_migrations(self) -> None:
"""Migrates the database to the latest version."""
with self._lock:
self._create_version_table()
current_version = self._get_current_version()
# This throws if there is a problem.
self._migrations.validate_migration_path()
self._create_migrations_table(cursor=self._cursor)
if self._migrations.count == 0:
self._logger.debug("No migrations registered")
return
latest_version = self._migrations.latest_version
if current_version == latest_version:
if self._get_current_version(self._cursor) == self._migrations.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return
self._logger.info("Database update needed")
# Only make a backup if using a file database (not memory)
self._backup_db()
if self._db_path:
# We are using a file database. Create a copy of the database to run the migrations on.
temp_db_path = self._create_temp_db(self._db_path)
temp_db_conn = sqlite3.connect(temp_db_path)
# We have to re-set this because we just created a new connection.
if self._log_sql:
temp_db_conn.set_trace_callback(self._logger.debug)
temp_db_cursor = temp_db_conn.cursor()
self._run_migrations(temp_db_cursor)
# Close the connections, copy the original database as a backup, and move the temp database to the
# original database's path.
self._finalize_migration(
temp_db_conn=temp_db_conn,
temp_db_path=temp_db_path,
original_db_path=self._db_path,
)
else:
# We are using a memory database. No special backup or special handling needed.
self._run_migrations(self._cursor)
return
next_migration = self._migrations.get(from_version=current_version)
while next_migration is not None:
try:
self._run_migration(next_migration)
next_migration = self._migrations.get(self._get_current_version())
except MigrationError:
self._restore_db()
raise
self._logger.info("Database updated successfully")
return
def _run_migration(self, migration: Migration) -> None:
def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None:
    """
    Runs registered migrations one after another using the given cursor.

    The database version is re-read through the cursor before each step, and the migration
    registered for that version is run. The loop ends when no migration is registered for the
    current version.
    """
    while True:
        version = self._get_current_version(temp_db_cursor)
        pending = self._migrations.get(from_version=version)
        if pending is None:
            return
        self._run_migration(pending, temp_db_cursor)
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
"""Runs a single migration."""
with self._lock:
try:
if self._get_current_version() != migration.from_version:
if self._get_current_version(temp_db_cursor) != migration.from_version:
raise MigrationError(
f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}"
)
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
@ -178,37 +229,37 @@ class SQLiteMigrator:
if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate:
callback(self._cursor)
callback(temp_db_cursor)
# Run the actual migration
migration.migrate(self._cursor)
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
migration.migrate(temp_db_cursor)
temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
# Run post-migration callbacks
if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate:
callback(self._cursor)
callback(temp_db_cursor)
# Migration callbacks only get a cursor. Commit this migration.
self._conn.commit()
temp_db_cursor.connection.commit()
self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
)
except Exception as e:
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
self._conn.rollback()
temp_db_cursor.connection.rollback()
self._logger.error(msg)
raise MigrationError(msg) from e
def _create_version_table(self) -> None:
"""Creates a version table for the database, if one does not already exist."""
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist."""
with self._lock:
try:
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if self._cursor.fetchone() is not None:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
self._cursor.execute(
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
@ -216,21 +267,21 @@ class SQLiteMigrator:
);
"""
)
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
self._conn.commit()
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
self._conn.rollback()
cursor.connection.rollback()
raise MigrationError(msg) from e
def _get_current_version(self) -> int:
def _get_current_version(self, cursor: sqlite3.Cursor) -> int:
"""Gets the current version of the database, or 0 if the version table does not exist."""
with self._lock:
try:
self._cursor.execute("SELECT MAX(version) FROM migrations;")
version = self._cursor.fetchone()[0]
cursor.execute("SELECT MAX(version) FROM migrations;")
version = cursor.fetchone()[0]
if version is None:
return 0
return version
@ -239,54 +290,19 @@ class SQLiteMigrator:
return 0
raise
def _backup_db(self) -> None:
"""Backs up the databse, returning the path to the backup file."""
if self._is_memory:
self._logger.debug("Using memory database, skipping backup")
# Sanity check!
assert isinstance(self._database, Path)
with self._lock:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = self._database.parent / f"{self._database.stem}_{timestamp}.db"
self._logger.info(f"Backing up database to {backup_path}")
def _create_temp_db(self, current_db_path: Path) -> Path:
    """Makes a copy of the database file for the migrations to run against and returns the copy's path."""
    destination = get_temp_db_path(current_db_path)
    shutil.copy2(current_db_path, destination)
    self._logger.info(f"Copied database to {destination} for migration")
    return destination
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
backup_conn = sqlite3.connect(backup_path)
with backup_conn:
self._conn.backup(backup_conn)
backup_conn.close()
# Sanity check!
if not backup_path.is_file():
raise MigrationError("Unable to back up database")
self.backup_path = backup_path
def _restore_db(
self,
) -> None:
"""Restores the database from a backup file, unless the database is a memory database."""
if self._is_memory:
return
with self._lock:
self._logger.info(f"Restoring database from {self.backup_path}")
self._conn.close()
assert isinstance(self.backup_path, Path)
shutil.copy2(self.backup_path, self._database)
def _unlink_lock_file(self) -> bool:
"""Unlinks the migration lock file, returning True if it existed."""
if self._is_memory or self._migration_lock_file_path is None:
return False
if self._migration_lock_file_path.is_file():
self._migration_lock_file_path.unlink()
return True
return False
def _write_migration_lock_file(self) -> None:
"""Writes a file to indicate that a migration is in progress."""
if self._is_memory or self._migration_lock_file_path is None:
return
assert isinstance(self._database, Path)
with open(self._migration_lock_file_path, "w") as f:
f.write("1")
def _finalize_migration(self, temp_db_conn: sqlite3.Connection, temp_db_path: Path, original_db_path: Path) -> None:
    """
    Closes connections, renames the original database as a backup and renames the migrated database to the
    original db path.

    If the migrated database cannot be moved into place, the backup is renamed back to the original
    path before the error is re-raised, so the original database stays where it was.

    :param temp_db_conn: The connection to the migrated temp database. It is closed here.
    :param temp_db_path: The path to the migrated temp database.
    :param original_db_path: The path to the original database.
    :raises OSError: If either rename fails.
    """
    # Close both connections before the files are renamed.
    self._conn.close()
    temp_db_conn.close()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_db_path = original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"
    original_db_path.rename(backup_db_path)
    try:
        temp_db_path.rename(original_db_path)
    except OSError:
        # The original was already moved aside, so nothing exists at its path. Put it back.
        backup_db_path.rename(original_db_path)
        raise
    self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")

View File

@ -21,7 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_migrator import (
def migrator() -> SQLiteMigrator:
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
return SQLiteMigrator(
conn=conn, database=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
conn=conn, db_path=sqlite_memory, lock=threading.RLock(), logger=Logger("test_sqlite_migrator")
)
@ -50,19 +50,19 @@ def test_register_invalid_migration_version(migrator: SQLiteMigrator):
def test_create_version_table(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
migrator._cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='version';")
assert migrator._cursor.fetchone() is not None
def test_get_current_version(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
migrator._conn.commit()
assert migrator._get_current_version() == 0 # initial version
def test_set_version(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
migrator._set_version(db_version=1, app_version="1.0.0")
migrator._cursor.execute("SELECT MAX(db_version) FROM version;")
assert migrator._cursor.fetchone()[0] == 1
@ -71,7 +71,7 @@ def test_set_version(migrator: SQLiteMigrator):
def test_run_migration(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
def migration_callback(cursor: sqlite3.Cursor) -> None:
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
@ -84,7 +84,7 @@ def test_run_migration(migrator: SQLiteMigrator):
def test_run_migrations(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None:
@ -109,8 +109,8 @@ def test_backup_and_restore_db():
cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
conn.commit()
migrator = SQLiteMigrator(conn=conn, database=database, lock=threading.RLock(), logger=Logger("test"))
backup_path = migrator._backup_db(migrator._database)
migrator = SQLiteMigrator(conn=conn, db_path=database, lock=threading.RLock(), logger=Logger("test"))
backup_path = migrator._backup_db(migrator._db_path)
# mangle the db
migrator._cursor.execute("DROP TABLE test;")
@ -135,21 +135,21 @@ def test_no_backup_and_restore_for_memory_db(migrator: SQLiteMigrator):
def test_failed_migration(migrator: SQLiteMigrator, failing_migration: Migration):
migrator._create_version_table()
migrator._create_migrations_table()
with pytest.raises(MigrationError, match="Error migrating database from 0 to 1"):
migrator._run_migration(failing_migration)
assert migrator._get_current_version() == 0
def test_duplicate_migration_versions(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_version_table()
migrator._create_migrations_table()
migrator.register_migration(good_migration)
with pytest.raises(MigrationVersionError, match="already registered"):
migrator.register_migration(deepcopy(good_migration))
def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
def create_migrate(i: int) -> Callable[[sqlite3.Cursor], None]:
def migrate(cursor: sqlite3.Cursor) -> None:
@ -167,7 +167,7 @@ def test_non_sequential_migration_registration(migrator: SQLiteMigrator):
def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration: Migration):
migrator._create_version_table()
migrator._create_migrations_table()
migrator.register_migration(good_migration)
migrator._set_version(db_version=2, app_version="2.0.0")
with pytest.raises(MigrationError, match="greater than the latest migration version"):
@ -176,7 +176,7 @@ def test_db_version_gt_last_migration(migrator: SQLiteMigrator, good_migration:
def test_idempotent_migrations(migrator: SQLiteMigrator):
migrator._create_version_table()
migrator._create_migrations_table()
def create_test_table(cursor: sqlite3.Cursor) -> None:
# This SQL throws if run twice