feat(db): remove file backups

Instead of mucking about with the filesystem, we rely on SQLite transactions to handle failed migrations.
This commit is contained in:
psychedelicious 2023-12-12 11:12:46 +11:00
parent 3414437eea
commit c5ba4f2ea5
6 changed files with 35 additions and 170 deletions

View File

@ -81,13 +81,7 @@ class ApiDependencies:
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
did_migrate = migrator.run_migrations()
# We need to reinitialize the database if we migrated, but only if we are using a file database.
# This closes the SqliteDatabase's connection and re-runs its `__init__` logic.
# If we do this with a memory database, we wipe the db.
if not db.db_path and did_migrate:
db.reinitialize()
migrator.run_migrations()
configuration = config
logger = logger

View File

@ -10,25 +10,22 @@ class SqliteDatabase:
"""
Manages a connection to an SQLite database.
:param db_path: Path to the database file. If None, an in-memory database is used.
:param logger: Logger to use for logging.
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
This is a light wrapper around the `sqlite3` module, providing a few conveniences:
- The database file is written to disk if it does not exist.
- Foreign key constraints are enabled by default.
- The connection is configured to use the `sqlite3.Row` row factory.
- A `conn` attribute is provided to access the connection.
- A `lock` attribute is provided to lock the database connection.
- A `clean` method to run the VACUUM command and report on the freed space.
- A `reinitialize` method to close the connection and re-run the init.
- A `close` method to close the connection.
:param db_path: Path to the database file. If None, an in-memory database is used.
:param logger: Logger to use for logging.
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
In addition to the constructor args, the instance provides the following attributes and methods:
- `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
- `lock`: A shared re-entrant lock, used to approximate thread safety.
- `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
"""
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
self.initialize(db_path=db_path, logger=logger, verbose=verbose)
def initialize(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self.logger = logger
self.db_path = db_path
@ -49,21 +46,6 @@ class SqliteDatabase:
self.conn.execute("PRAGMA foreign_keys = ON;")
def reinitialize(self) -> None:
"""
Re-initializes the database by closing the connection and re-running the init.
Warning: This will wipe the database if it is an in-memory database.
"""
self.close()
self.initialize(db_path=self.db_path, logger=self.logger, verbose=self.verbose)
def close(self) -> None:
"""
Closes the connection to the database.
Warning: This will wipe the database if it is an in-memory database.
"""
self.conn.close()
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.

View File

@ -1,6 +1,4 @@
import shutil
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Optional
@ -12,19 +10,11 @@ class SQLiteMigrator:
"""
Manages migrations for a SQLite database.
:param db_path: The path to the database to migrate, or None if using an in-memory database.
:param conn: The connection to the database.
:param lock: A lock to use when running migrations.
:param logger: A logger to use for logging.
:param log_sql: Whether to log SQL statements. Only used when the log level is set to debug.
:param db: The instanceof :class:`SqliteDatabase` to migrate.
Migrations should be registered with :meth:`register_migration`.
During migration, a copy of the current database is made and the migrations are run on the copy. If the migration
is successful, the original database is backed up and the migrated database is moved to the original database's
path. If the migration fails, the original database is left untouched and the migrated database is deleted.
If the database is in-memory, no backup is made; the migration is run in-place.
Each migration is run in a transaction. If a migration fails, the transaction is rolled back.
"""
backup_path: Optional[Path] = None
@ -34,11 +24,6 @@ class SQLiteMigrator:
self._logger = db.logger
self._migration_set = MigrationSet()
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
if self._db.db_path and self._get_temp_db_path(self._db.db_path).is_file():
self._logger.warning("Previous migration failed! Trying again...")
self._get_temp_db_path(self._db.db_path).unlink()
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
self._migration_set.register(migration)
@ -61,44 +46,24 @@ class SQLiteMigrator:
return False
self._logger.info("Database update needed")
if self._db.db_path:
# We are using a file database. Create a copy of the database to run the migrations on.
temp_db_path = self._create_temp_db(self._db.db_path)
self._logger.info(f"Copied database to {temp_db_path} for migration")
temp_db = SqliteDatabase(db_path=temp_db_path, logger=self._logger, verbose=self._db.verbose)
temp_db_cursor = temp_db.conn.cursor()
self._run_migrations(temp_db_cursor)
# Close the connections, copy the original database as a backup, and move the temp database to the
# original database's path.
temp_db.close()
self._db.close()
backup_db_path = self._finalize_migration(
temp_db_path=temp_db_path,
original_db_path=self._db.db_path,
)
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
else:
# We are using a memory database. No special backup or special handling needed.
self._run_migrations(cursor)
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)
next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info("Database updated successfully")
return True
def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None:
"""Runs all migrations in a loop."""
next_migration = self._migration_set.get(from_version=self._get_current_version(temp_db_cursor))
while next_migration is not None:
self._run_migration(next_migration, temp_db_cursor)
next_migration = self._migration_set.get(self._get_current_version(temp_db_cursor))
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration."""
with self._db.lock:
try:
if self._get_current_version(temp_db_cursor) != migration.from_version:
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
# exception is raised. We want to commit the transaction if the migration is successful, or roll it back if
# there is an error.
try:
with self._db.lock, self._db.conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(
f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}"
f"Database is at version {self._get_current_version(cursor)}, expected {migration.from_version}"
)
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
@ -106,28 +71,26 @@ class SQLiteMigrator:
if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate:
callback(temp_db_cursor)
callback(cursor)
# Run the actual migration
migration.migrate(temp_db_cursor)
temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
migration.migrate(cursor)
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
# Run post-migration callbacks
if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate:
callback(temp_db_cursor)
# Migration callbacks only get a cursor. Commit this migration.
temp_db_cursor.connection.commit()
callback(cursor)
self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
)
except Exception as e:
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
temp_db_cursor.connection.rollback()
self._logger.error(msg)
raise MigrationError(msg) from e
# We want to catch *any* error, mirroring the behaviour of the sqlite3 module.
except Exception as e:
# The connection context manager has already rolled back the migration, so we don't need to do anything.
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
self._logger.error(msg)
raise MigrationError(msg) from e
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist."""
@ -166,33 +129,3 @@ class SQLiteMigrator:
if "no such table" in str(e):
return 0
raise
@classmethod
def _create_temp_db(cls, original_db_path: Path) -> Path:
"""Copies the current database to a new file for migration."""
temp_db_path = cls._get_temp_db_path(original_db_path)
shutil.copy2(original_db_path, temp_db_path)
return temp_db_path
@classmethod
def _finalize_migration(
cls,
temp_db_path: Path,
original_db_path: Path,
) -> Path:
"""Renames the original database as a backup and renames the migrated database to the original name."""
backup_db_path = cls._get_backup_db_path(original_db_path)
original_db_path.rename(backup_db_path)
temp_db_path.rename(original_db_path)
return backup_db_path
@classmethod
def _get_temp_db_path(cls, original_db_path: Path) -> Path:
"""Gets the path to the temp database."""
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
@classmethod
def _get_backup_db_path(cls, original_db_path: Path) -> Path:
"""Gets the path to the final backup database."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"

View File

@ -48,8 +48,6 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
# this test uses a file database, so we need to reinitialize it after migrations
db.reinitialize()
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store

View File

@ -39,8 +39,6 @@ def store(datadir: Any) -> ModelRecordServiceBase:
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
# this test uses a file database, so we need to reinitialize it after migrations
db.reinitialize()
return ModelRecordServiceSQL(db)

View File

@ -198,7 +198,7 @@ def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op
def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
cursor = migrator._db.conn.cursor()
migrator._create_migrations_table(cursor)
migrator._run_migration(migration_create_test_table, cursor)
migrator._run_migration(migration_create_test_table)
assert migrator._get_current_version(cursor) == 1
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert cursor.fetchone() is not None
@ -228,46 +228,6 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
assert SQLiteMigrator._get_current_version(original_db_cursor) == 3
def test_migrator_creates_temp_db() -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
with closing(sqlite3.connect(original_db_path)):
# create the db file so _create_temp_db has something to copy
pass
temp_db_path = SQLiteMigrator._create_temp_db(original_db_path)
assert temp_db_path.is_file()
assert temp_db_path == SQLiteMigrator._get_temp_db_path(original_db_path)
def test_migrator_finalizes() -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path)
backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path)
with closing(sqlite3.connect(original_db_path)) as original_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
original_db_cursor = original_db_conn.cursor()
original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);")
original_db_conn.commit()
temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);")
temp_db_conn.commit()
SQLiteMigrator._finalize_migration(
original_db_path=original_db_path,
temp_db_path=temp_db_path,
)
with closing(sqlite3.connect(backup_db_path)) as backup_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';")
assert backup_db_cursor.fetchone() is not None
temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='temp_db_test';")
assert temp_db_cursor.fetchone() is None
def test_migrator_makes_no_changes_on_failed_migration(
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None: