|
|
|
@ -8,16 +8,14 @@ from typing import Callable, Optional, TypeAlias
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field, model_validator
|
|
|
|
|
|
|
|
|
|
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
|
|
|
|
|
|
|
|
|
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MigrationError(Exception):
|
|
|
|
|
class MigrationError(RuntimeError):
|
|
|
|
|
"""Raised when a migration fails."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MigrationVersionError(ValueError, MigrationError):
|
|
|
|
|
class MigrationVersionError(ValueError):
|
|
|
|
|
"""Raised when a migration version is invalid."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -25,8 +23,14 @@ class Migration(BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
Represents a migration for a SQLite database.
|
|
|
|
|
|
|
|
|
|
Migration callbacks will be provided an instance of SqliteDatabase.
|
|
|
|
|
Migration callbacks should not commit; the migrator will commit the transaction.
|
|
|
|
|
Migration callbacks will be provided an open cursor to the database. They should not commit their
|
|
|
|
|
transaction; this is handled by the migrator.
|
|
|
|
|
|
|
|
|
|
Pre- and post-migration callback may be registered with :meth:`register_pre_callback` or
|
|
|
|
|
:meth:`register_post_callback`.
|
|
|
|
|
|
|
|
|
|
If a migration has additional dependencies, it is recommended to use functools.partial to provide
|
|
|
|
|
the dependencies and register the partial as the migration callback.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
|
|
|
@ -77,6 +81,28 @@ class MigrationSet:
|
|
|
|
|
# register() ensures that there is only one migration with a given from_version, so this is safe.
|
|
|
|
|
return next((m for m in self._migrations if m.from_version == from_version), None)
|
|
|
|
|
|
|
|
|
|
def validate_migration_path(self) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Validates that the migrations form a single path of migrations from version 0 to the latest version.
|
|
|
|
|
Raises a MigrationError if there is a problem.
|
|
|
|
|
"""
|
|
|
|
|
if self.count == 0:
|
|
|
|
|
return
|
|
|
|
|
if self.latest_version == 0:
|
|
|
|
|
return
|
|
|
|
|
current_version = 0
|
|
|
|
|
touched_count = 0
|
|
|
|
|
while current_version < self.latest_version:
|
|
|
|
|
migration = self.get(current_version)
|
|
|
|
|
if migration is None:
|
|
|
|
|
raise MigrationError(f"Missing migration from {current_version}")
|
|
|
|
|
current_version = migration.to_version
|
|
|
|
|
touched_count += 1
|
|
|
|
|
if current_version != self.latest_version:
|
|
|
|
|
raise MigrationError(f"Missing migration to {self.latest_version}")
|
|
|
|
|
if touched_count != self.count:
|
|
|
|
|
raise MigrationError("Migration path is not contiguous")
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def count(self) -> int:
|
|
|
|
|
"""The count of registered migrations."""
|
|
|
|
@ -90,87 +116,112 @@ class MigrationSet:
|
|
|
|
|
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_temp_db_path(original_db_path: Path) -> Path:
|
|
|
|
|
"""Gets the path to the migrated database."""
|
|
|
|
|
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SQLiteMigrator:
|
|
|
|
|
"""
|
|
|
|
|
Manages migrations for a SQLite database.
|
|
|
|
|
|
|
|
|
|
:param db: The SqliteDatabase, representing the database on which to run migrations.
|
|
|
|
|
:param image_files: An instance of ImageFileStorageBase. Migrations may need to access image files.
|
|
|
|
|
:param db_path: The path to the database to migrate, or None if using an in-memory database.
|
|
|
|
|
:param conn: The connection to the database.
|
|
|
|
|
:param lock: A lock to use when running migrations.
|
|
|
|
|
:param logger: A logger to use for logging.
|
|
|
|
|
:param log_sql: Whether to log SQL statements. Only used when the log level is set to debug.
|
|
|
|
|
|
|
|
|
|
Migrations should be registered with :meth:`register_migration`. Migrations will be run in
|
|
|
|
|
order of their version number. If the database is already at the latest version, no migrations
|
|
|
|
|
will be run.
|
|
|
|
|
Migrations should be registered with :meth:`register_migration`.
|
|
|
|
|
|
|
|
|
|
During migration, a copy of the current database is made and the migrations are run on the copy. If the migration
|
|
|
|
|
is successful, the original database is backed up and the migrated database is moved to the original database's
|
|
|
|
|
path. If the migration fails, the original database is left untouched and the migrated database is deleted.
|
|
|
|
|
|
|
|
|
|
If the database is in-memory, no backup is made; the migration is run in-place.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
backup_path: Optional[Path] = None
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
database: Path | str,
|
|
|
|
|
db_path: Path | None,
|
|
|
|
|
conn: sqlite3.Connection,
|
|
|
|
|
lock: threading.RLock,
|
|
|
|
|
logger: Logger,
|
|
|
|
|
log_sql: bool = False,
|
|
|
|
|
) -> None:
|
|
|
|
|
self._lock = lock
|
|
|
|
|
self._database = database
|
|
|
|
|
self._is_memory = database == sqlite_memory
|
|
|
|
|
self._db_path = db_path
|
|
|
|
|
self._logger = logger
|
|
|
|
|
self._conn = sqlite3.connect(database)
|
|
|
|
|
self._conn = conn
|
|
|
|
|
self._log_sql = log_sql
|
|
|
|
|
self._cursor = self._conn.cursor()
|
|
|
|
|
self._migrations = MigrationSet()
|
|
|
|
|
|
|
|
|
|
# Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure.
|
|
|
|
|
self._migration_lock_file_path = (
|
|
|
|
|
self._database.parent / ".migration_in_progress" if isinstance(self._database, Path) else None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self._unlink_lock_file():
|
|
|
|
|
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
|
|
|
|
|
if self._db_path and get_temp_db_path(self._db_path).is_file():
|
|
|
|
|
self._logger.warning("Previous migration failed! Trying again...")
|
|
|
|
|
get_temp_db_path(self._db_path).unlink()
|
|
|
|
|
|
|
|
|
|
def register_migration(self, migration: Migration) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Registers a migration.
|
|
|
|
|
Migration callbacks should not commit any changes to the database; the migrator will commit the transaction.
|
|
|
|
|
"""
|
|
|
|
|
"""Registers a migration."""
|
|
|
|
|
self._migrations.register(migration)
|
|
|
|
|
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
|
|
|
|
|
|
|
|
|
|
def run_migrations(self) -> None:
|
|
|
|
|
"""Migrates the database to the latest version."""
|
|
|
|
|
with self._lock:
|
|
|
|
|
self._create_version_table()
|
|
|
|
|
current_version = self._get_current_version()
|
|
|
|
|
# This throws if there is a problem.
|
|
|
|
|
self._migrations.validate_migration_path()
|
|
|
|
|
self._create_migrations_table(cursor=self._cursor)
|
|
|
|
|
|
|
|
|
|
if self._migrations.count == 0:
|
|
|
|
|
self._logger.debug("No migrations registered")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
latest_version = self._migrations.latest_version
|
|
|
|
|
if current_version == latest_version:
|
|
|
|
|
if self._get_current_version(self._cursor) == self._migrations.latest_version:
|
|
|
|
|
self._logger.debug("Database is up to date, no migrations to run")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self._logger.info("Database update needed")
|
|
|
|
|
|
|
|
|
|
# Only make a backup if using a file database (not memory)
|
|
|
|
|
self._backup_db()
|
|
|
|
|
if self._db_path:
|
|
|
|
|
# We are using a file database. Create a copy of the database to run the migrations on.
|
|
|
|
|
temp_db_path = self._create_temp_db(self._db_path)
|
|
|
|
|
temp_db_conn = sqlite3.connect(temp_db_path)
|
|
|
|
|
# We have to re-set this because we just created a new connection.
|
|
|
|
|
if self._log_sql:
|
|
|
|
|
temp_db_conn.set_trace_callback(self._logger.debug)
|
|
|
|
|
temp_db_cursor = temp_db_conn.cursor()
|
|
|
|
|
self._run_migrations(temp_db_cursor)
|
|
|
|
|
# Close the connections, copy the original database as a backup, and move the temp database to the
|
|
|
|
|
# original database's path.
|
|
|
|
|
self._finalize_migration(
|
|
|
|
|
temp_db_conn=temp_db_conn,
|
|
|
|
|
temp_db_path=temp_db_path,
|
|
|
|
|
original_db_path=self._db_path,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
# We are using a memory database. No special backup or special handling needed.
|
|
|
|
|
self._run_migrations(self._cursor)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
next_migration = self._migrations.get(from_version=current_version)
|
|
|
|
|
while next_migration is not None:
|
|
|
|
|
try:
|
|
|
|
|
self._run_migration(next_migration)
|
|
|
|
|
next_migration = self._migrations.get(self._get_current_version())
|
|
|
|
|
except MigrationError:
|
|
|
|
|
self._restore_db()
|
|
|
|
|
raise
|
|
|
|
|
self._logger.info("Database updated successfully")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
def _run_migration(self, migration: Migration) -> None:
|
|
|
|
|
def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None:
|
|
|
|
|
next_migration = self._migrations.get(from_version=self._get_current_version(temp_db_cursor))
|
|
|
|
|
while next_migration is not None:
|
|
|
|
|
self._run_migration(next_migration, temp_db_cursor)
|
|
|
|
|
next_migration = self._migrations.get(self._get_current_version(temp_db_cursor))
|
|
|
|
|
|
|
|
|
|
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
|
|
|
|
|
"""Runs a single migration."""
|
|
|
|
|
with self._lock:
|
|
|
|
|
try:
|
|
|
|
|
if self._get_current_version() != migration.from_version:
|
|
|
|
|
if self._get_current_version(temp_db_cursor) != migration.from_version:
|
|
|
|
|
raise MigrationError(
|
|
|
|
|
f"Database is at version {self._get_current_version()}, expected {migration.from_version}"
|
|
|
|
|
f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}"
|
|
|
|
|
)
|
|
|
|
|
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
|
|
|
|
|
|
|
|
@ -178,37 +229,37 @@ class SQLiteMigrator:
|
|
|
|
|
if migration.pre_migrate:
|
|
|
|
|
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
|
|
|
|
for callback in migration.pre_migrate:
|
|
|
|
|
callback(self._cursor)
|
|
|
|
|
callback(temp_db_cursor)
|
|
|
|
|
|
|
|
|
|
# Run the actual migration
|
|
|
|
|
migration.migrate(self._cursor)
|
|
|
|
|
self._cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
|
|
|
|
migration.migrate(temp_db_cursor)
|
|
|
|
|
temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
|
|
|
|
|
|
|
|
|
# Run post-migration callbacks
|
|
|
|
|
if migration.post_migrate:
|
|
|
|
|
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
|
|
|
|
for callback in migration.post_migrate:
|
|
|
|
|
callback(self._cursor)
|
|
|
|
|
callback(temp_db_cursor)
|
|
|
|
|
|
|
|
|
|
# Migration callbacks only get a cursor. Commit this migration.
|
|
|
|
|
self._conn.commit()
|
|
|
|
|
temp_db_cursor.connection.commit()
|
|
|
|
|
self._logger.debug(
|
|
|
|
|
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
|
|
|
|
self._conn.rollback()
|
|
|
|
|
temp_db_cursor.connection.rollback()
|
|
|
|
|
self._logger.error(msg)
|
|
|
|
|
raise MigrationError(msg) from e
|
|
|
|
|
|
|
|
|
|
def _create_version_table(self) -> None:
|
|
|
|
|
"""Creates a version table for the database, if one does not already exist."""
|
|
|
|
|
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
|
|
|
|
"""Creates the migrations table for the database, if one does not already exist."""
|
|
|
|
|
with self._lock:
|
|
|
|
|
try:
|
|
|
|
|
self._cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
|
|
|
|
if self._cursor.fetchone() is not None:
|
|
|
|
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
|
|
|
|
if cursor.fetchone() is not None:
|
|
|
|
|
return
|
|
|
|
|
self._cursor.execute(
|
|
|
|
|
cursor.execute(
|
|
|
|
|
"""--sql
|
|
|
|
|
CREATE TABLE migrations (
|
|
|
|
|
version INTEGER PRIMARY KEY,
|
|
|
|
@ -216,21 +267,21 @@ class SQLiteMigrator:
|
|
|
|
|
);
|
|
|
|
|
"""
|
|
|
|
|
)
|
|
|
|
|
self._cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
|
|
|
|
self._conn.commit()
|
|
|
|
|
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
|
|
|
|
cursor.connection.commit()
|
|
|
|
|
self._logger.debug("Created migrations table")
|
|
|
|
|
except sqlite3.Error as e:
|
|
|
|
|
msg = f"Problem creating migrations table: {e}"
|
|
|
|
|
self._logger.error(msg)
|
|
|
|
|
self._conn.rollback()
|
|
|
|
|
cursor.connection.rollback()
|
|
|
|
|
raise MigrationError(msg) from e
|
|
|
|
|
|
|
|
|
|
def _get_current_version(self) -> int:
|
|
|
|
|
def _get_current_version(self, cursor: sqlite3.Cursor) -> int:
|
|
|
|
|
"""Gets the current version of the database, or 0 if the version table does not exist."""
|
|
|
|
|
with self._lock:
|
|
|
|
|
try:
|
|
|
|
|
self._cursor.execute("SELECT MAX(version) FROM migrations;")
|
|
|
|
|
version = self._cursor.fetchone()[0]
|
|
|
|
|
cursor.execute("SELECT MAX(version) FROM migrations;")
|
|
|
|
|
version = cursor.fetchone()[0]
|
|
|
|
|
if version is None:
|
|
|
|
|
return 0
|
|
|
|
|
return version
|
|
|
|
@ -239,54 +290,19 @@ class SQLiteMigrator:
|
|
|
|
|
return 0
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def _backup_db(self) -> None:
|
|
|
|
|
"""Backs up the databse, returning the path to the backup file."""
|
|
|
|
|
if self._is_memory:
|
|
|
|
|
self._logger.debug("Using memory database, skipping backup")
|
|
|
|
|
# Sanity check!
|
|
|
|
|
assert isinstance(self._database, Path)
|
|
|
|
|
with self._lock:
|
|
|
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
|
|
backup_path = self._database.parent / f"{self._database.stem}_{timestamp}.db"
|
|
|
|
|
self._logger.info(f"Backing up database to {backup_path}")
|
|
|
|
|
def _create_temp_db(self, current_db_path: Path) -> Path:
|
|
|
|
|
"""Copies the current database to a new file for migration."""
|
|
|
|
|
temp_db_path = get_temp_db_path(current_db_path)
|
|
|
|
|
shutil.copy2(current_db_path, temp_db_path)
|
|
|
|
|
self._logger.info(f"Copied database to {temp_db_path} for migration")
|
|
|
|
|
return temp_db_path
|
|
|
|
|
|
|
|
|
|
# Use SQLite's built in backup capabilities so we don't need to worry about locking and such.
|
|
|
|
|
backup_conn = sqlite3.connect(backup_path)
|
|
|
|
|
with backup_conn:
|
|
|
|
|
self._conn.backup(backup_conn)
|
|
|
|
|
backup_conn.close()
|
|
|
|
|
|
|
|
|
|
# Sanity check!
|
|
|
|
|
if not backup_path.is_file():
|
|
|
|
|
raise MigrationError("Unable to back up database")
|
|
|
|
|
self.backup_path = backup_path
|
|
|
|
|
|
|
|
|
|
def _restore_db(
|
|
|
|
|
self,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Restores the database from a backup file, unless the database is a memory database."""
|
|
|
|
|
if self._is_memory:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
with self._lock:
|
|
|
|
|
self._logger.info(f"Restoring database from {self.backup_path}")
|
|
|
|
|
self._conn.close()
|
|
|
|
|
assert isinstance(self.backup_path, Path)
|
|
|
|
|
shutil.copy2(self.backup_path, self._database)
|
|
|
|
|
|
|
|
|
|
def _unlink_lock_file(self) -> bool:
|
|
|
|
|
"""Unlinks the migration lock file, returning True if it existed."""
|
|
|
|
|
if self._is_memory or self._migration_lock_file_path is None:
|
|
|
|
|
return False
|
|
|
|
|
if self._migration_lock_file_path.is_file():
|
|
|
|
|
self._migration_lock_file_path.unlink()
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def _write_migration_lock_file(self) -> None:
|
|
|
|
|
"""Writes a file to indicate that a migration is in progress."""
|
|
|
|
|
if self._is_memory or self._migration_lock_file_path is None:
|
|
|
|
|
return
|
|
|
|
|
assert isinstance(self._database, Path)
|
|
|
|
|
with open(self._migration_lock_file_path, "w") as f:
|
|
|
|
|
f.write("1")
|
|
|
|
|
def _finalize_migration(self, temp_db_conn: sqlite3.Connection, temp_db_path: Path, original_db_path: Path) -> None:
|
|
|
|
|
"""Closes connections, renames the original database as a backup and renames the migrated database to the original db path."""
|
|
|
|
|
self._conn.close()
|
|
|
|
|
temp_db_conn.close()
|
|
|
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
|
|
backup_db_path = original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"
|
|
|
|
|
original_db_path.rename(backup_db_path)
|
|
|
|
|
temp_db_path.rename(original_db_path)
|
|
|
|
|
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
|
|
|
|
|