feat(db): remove file backups

Instead of mucking about with the filesystem, we rely on SQLite transactions to handle failed migrations.
This commit is contained in:
psychedelicious 2023-12-12 11:12:46 +11:00
parent 3414437eea
commit c5ba4f2ea5
6 changed files with 35 additions and 170 deletions

View File

@ -81,13 +81,7 @@ class ApiDependencies:
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files)) migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
did_migrate = migrator.run_migrations() migrator.run_migrations()
# We need to reinitialize the database if we migrated, but only if we are using a file database.
# This closes the SqliteDatabase's connection and re-runs its `__init__` logic.
# If we do this with a memory database, we wipe the db.
if not db.db_path and did_migrate:
db.reinitialize()
configuration = config configuration = config
logger = logger logger = logger

View File

@ -10,25 +10,22 @@ class SqliteDatabase:
""" """
Manages a connection to an SQLite database. Manages a connection to an SQLite database.
:param db_path: Path to the database file. If None, an in-memory database is used.
:param logger: Logger to use for logging.
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
This is a light wrapper around the `sqlite3` module, providing a few conveniences: This is a light wrapper around the `sqlite3` module, providing a few conveniences:
- The database file is written to disk if it does not exist. - The database file is written to disk if it does not exist.
- Foreign key constraints are enabled by default. - Foreign key constraints are enabled by default.
- The connection is configured to use the `sqlite3.Row` row factory. - The connection is configured to use the `sqlite3.Row` row factory.
- A `conn` attribute is provided to access the connection.
- A `lock` attribute is provided to lock the database connection.
- A `clean` method to run the VACUUM command and report on the freed space.
- A `reinitialize` method to close the connection and re-run the init.
- A `close` method to close the connection.
:param db_path: Path to the database file. If None, an in-memory database is used. In addition to the constructor args, the instance provides the following attributes and methods:
:param logger: Logger to use for logging. - `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback. - `lock`: A shared re-entrant lock, used to approximate thread safety.
- `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
""" """
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None: def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
self.initialize(db_path=db_path, logger=logger, verbose=verbose)
def initialize(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor.""" """Initializes the database. This is used internally by the class constructor."""
self.logger = logger self.logger = logger
self.db_path = db_path self.db_path = db_path
@ -49,21 +46,6 @@ class SqliteDatabase:
self.conn.execute("PRAGMA foreign_keys = ON;") self.conn.execute("PRAGMA foreign_keys = ON;")
def reinitialize(self) -> None:
"""
Re-initializes the database by closing the connection and re-running the init.
Warning: This will wipe the database if it is an in-memory database.
"""
self.close()
self.initialize(db_path=self.db_path, logger=self.logger, verbose=self.verbose)
def close(self) -> None:
"""
Closes the connection to the database.
Warning: This will wipe the database if it is an in-memory database.
"""
self.conn.close()
def clean(self) -> None: def clean(self) -> None:
""" """
Cleans the database by running the VACUUM command, reporting on the freed space. Cleans the database by running the VACUUM command, reporting on the freed space.

View File

@ -1,6 +1,4 @@
import shutil
import sqlite3 import sqlite3
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -12,19 +10,11 @@ class SQLiteMigrator:
""" """
Manages migrations for a SQLite database. Manages migrations for a SQLite database.
:param db_path: The path to the database to migrate, or None if using an in-memory database. :param db: The instanceof :class:`SqliteDatabase` to migrate.
:param conn: The connection to the database.
:param lock: A lock to use when running migrations.
:param logger: A logger to use for logging.
:param log_sql: Whether to log SQL statements. Only used when the log level is set to debug.
Migrations should be registered with :meth:`register_migration`. Migrations should be registered with :meth:`register_migration`.
During migration, a copy of the current database is made and the migrations are run on the copy. If the migration Each migration is run in a transaction. If a migration fails, the transaction is rolled back.
is successful, the original database is backed up and the migrated database is moved to the original database's
path. If the migration fails, the original database is left untouched and the migrated database is deleted.
If the database is in-memory, no backup is made; the migration is run in-place.
""" """
backup_path: Optional[Path] = None backup_path: Optional[Path] = None
@ -34,11 +24,6 @@ class SQLiteMigrator:
self._logger = db.logger self._logger = db.logger
self._migration_set = MigrationSet() self._migration_set = MigrationSet()
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
if self._db.db_path and self._get_temp_db_path(self._db.db_path).is_file():
self._logger.warning("Previous migration failed! Trying again...")
self._get_temp_db_path(self._db.db_path).unlink()
def register_migration(self, migration: Migration) -> None: def register_migration(self, migration: Migration) -> None:
"""Registers a migration.""" """Registers a migration."""
self._migration_set.register(migration) self._migration_set.register(migration)
@ -61,44 +46,24 @@ class SQLiteMigrator:
return False return False
self._logger.info("Database update needed") self._logger.info("Database update needed")
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
if self._db.db_path: while next_migration is not None:
# We are using a file database. Create a copy of the database to run the migrations on. self._run_migration(next_migration)
temp_db_path = self._create_temp_db(self._db.db_path) next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info(f"Copied database to {temp_db_path} for migration")
temp_db = SqliteDatabase(db_path=temp_db_path, logger=self._logger, verbose=self._db.verbose)
temp_db_cursor = temp_db.conn.cursor()
self._run_migrations(temp_db_cursor)
# Close the connections, copy the original database as a backup, and move the temp database to the
# original database's path.
temp_db.close()
self._db.close()
backup_db_path = self._finalize_migration(
temp_db_path=temp_db_path,
original_db_path=self._db.db_path,
)
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
else:
# We are using a memory database. No special backup or special handling needed.
self._run_migrations(cursor)
self._logger.info("Database updated successfully") self._logger.info("Database updated successfully")
return True return True
def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None: def _run_migration(self, migration: Migration) -> None:
"""Runs all migrations in a loop."""
next_migration = self._migration_set.get(from_version=self._get_current_version(temp_db_cursor))
while next_migration is not None:
self._run_migration(next_migration, temp_db_cursor)
next_migration = self._migration_set.get(self._get_current_version(temp_db_cursor))
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
"""Runs a single migration.""" """Runs a single migration."""
with self._db.lock: # Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
try: # exception is raised. We want to commit the transaction if the migration is successful, or roll it back if
if self._get_current_version(temp_db_cursor) != migration.from_version: # there is an error.
try:
with self._db.lock, self._db.conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError( raise MigrationError(
f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}" f"Database is at version {self._get_current_version(cursor)}, expected {migration.from_version}"
) )
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}") self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
@ -106,28 +71,26 @@ class SQLiteMigrator:
if migration.pre_migrate: if migration.pre_migrate:
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks") self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
for callback in migration.pre_migrate: for callback in migration.pre_migrate:
callback(temp_db_cursor) callback(cursor)
# Run the actual migration # Run the actual migration
migration.migrate(temp_db_cursor) migration.migrate(cursor)
temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,)) cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
# Run post-migration callbacks # Run post-migration callbacks
if migration.post_migrate: if migration.post_migrate:
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks") self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
for callback in migration.post_migrate: for callback in migration.post_migrate:
callback(temp_db_cursor) callback(cursor)
# Migration callbacks only get a cursor. Commit this migration.
temp_db_cursor.connection.commit()
self._logger.debug( self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}" f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
) )
except Exception as e: # We want to catch *any* error, mirroring the behaviour of the sqlite3 module.
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}" except Exception as e:
temp_db_cursor.connection.rollback() # The connection context manager has already rolled back the migration, so we don't need to do anything.
self._logger.error(msg) msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
raise MigrationError(msg) from e self._logger.error(msg)
raise MigrationError(msg) from e
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None: def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist.""" """Creates the migrations table for the database, if one does not already exist."""
@ -166,33 +129,3 @@ class SQLiteMigrator:
if "no such table" in str(e): if "no such table" in str(e):
return 0 return 0
raise raise
@classmethod
def _create_temp_db(cls, original_db_path: Path) -> Path:
"""Copies the current database to a new file for migration."""
temp_db_path = cls._get_temp_db_path(original_db_path)
shutil.copy2(original_db_path, temp_db_path)
return temp_db_path
@classmethod
def _finalize_migration(
cls,
temp_db_path: Path,
original_db_path: Path,
) -> Path:
"""Renames the original database as a backup and renames the migrated database to the original name."""
backup_db_path = cls._get_backup_db_path(original_db_path)
original_db_path.rename(backup_db_path)
temp_db_path.rename(original_db_path)
return backup_db_path
@classmethod
def _get_temp_db_path(cls, original_db_path: Path) -> Path:
"""Gets the path to the temp database."""
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
@classmethod
def _get_backup_db_path(cls, original_db_path: Path) -> Path:
"""Gets the path to the final backup database."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"

View File

@ -48,8 +48,6 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()
# this test uses a file database, so we need to reinitialize it after migrations
db.reinitialize()
store: ModelRecordServiceBase = ModelRecordServiceSQL(db) store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store return store

View File

@ -39,8 +39,6 @@ def store(datadir: Any) -> ModelRecordServiceBase:
migrator.register_migration(migration_1) migrator.register_migration(migration_1)
migrator.register_migration(migration_2) migrator.register_migration(migration_2)
migrator.run_migrations() migrator.run_migrations()
# this test uses a file database, so we need to reinitialize it after migrations
db.reinitialize()
return ModelRecordServiceSQL(db) return ModelRecordServiceSQL(db)

View File

@ -198,7 +198,7 @@ def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op
def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None: def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
cursor = migrator._db.conn.cursor() cursor = migrator._db.conn.cursor()
migrator._create_migrations_table(cursor) migrator._create_migrations_table(cursor)
migrator._run_migration(migration_create_test_table, cursor) migrator._run_migration(migration_create_test_table)
assert migrator._get_current_version(cursor) == 1 assert migrator._get_current_version(cursor) == 1
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';") cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert cursor.fetchone() is not None assert cursor.fetchone() is not None
@ -228,46 +228,6 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
assert SQLiteMigrator._get_current_version(original_db_cursor) == 3 assert SQLiteMigrator._get_current_version(original_db_cursor) == 3
def test_migrator_creates_temp_db() -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
with closing(sqlite3.connect(original_db_path)):
# create the db file so _create_temp_db has something to copy
pass
temp_db_path = SQLiteMigrator._create_temp_db(original_db_path)
assert temp_db_path.is_file()
assert temp_db_path == SQLiteMigrator._get_temp_db_path(original_db_path)
def test_migrator_finalizes() -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path)
backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path)
with closing(sqlite3.connect(original_db_path)) as original_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
original_db_cursor = original_db_conn.cursor()
original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);")
original_db_conn.commit()
temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);")
temp_db_conn.commit()
SQLiteMigrator._finalize_migration(
original_db_path=original_db_path,
temp_db_path=temp_db_path,
)
with closing(sqlite3.connect(backup_db_path)) as backup_db_conn, closing(
sqlite3.connect(temp_db_path)
) as temp_db_conn:
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';")
assert backup_db_cursor.fetchone() is not None
temp_db_cursor = temp_db_conn.cursor()
temp_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='temp_db_test';")
assert temp_db_cursor.fetchone() is None
def test_migrator_makes_no_changes_on_failed_migration( def test_migrator_makes_no_changes_on_failed_migration(
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None: ) -> None: