mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): remove file backups
Instead of mucking about with the filesystem, we rely on SQLite transactions to handle failed migrations.
This commit is contained in:
parent
3414437eea
commit
c5ba4f2ea5
@ -81,13 +81,7 @@ class ApiDependencies:
|
|||||||
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
|
migration_2.register_post_callback(partial(migrate_embedded_workflows, logger=logger, image_files=image_files))
|
||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
did_migrate = migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
|
||||||
# We need to reinitialize the database if we migrated, but only if we are using a file database.
|
|
||||||
# This closes the SqliteDatabase's connection and re-runs its `__init__` logic.
|
|
||||||
# If we do this with a memory database, we wipe the db.
|
|
||||||
if not db.db_path and did_migrate:
|
|
||||||
db.reinitialize()
|
|
||||||
|
|
||||||
configuration = config
|
configuration = config
|
||||||
logger = logger
|
logger = logger
|
||||||
|
@ -10,25 +10,22 @@ class SqliteDatabase:
|
|||||||
"""
|
"""
|
||||||
Manages a connection to an SQLite database.
|
Manages a connection to an SQLite database.
|
||||||
|
|
||||||
|
:param db_path: Path to the database file. If None, an in-memory database is used.
|
||||||
|
:param logger: Logger to use for logging.
|
||||||
|
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
|
||||||
|
|
||||||
This is a light wrapper around the `sqlite3` module, providing a few conveniences:
|
This is a light wrapper around the `sqlite3` module, providing a few conveniences:
|
||||||
- The database file is written to disk if it does not exist.
|
- The database file is written to disk if it does not exist.
|
||||||
- Foreign key constraints are enabled by default.
|
- Foreign key constraints are enabled by default.
|
||||||
- The connection is configured to use the `sqlite3.Row` row factory.
|
- The connection is configured to use the `sqlite3.Row` row factory.
|
||||||
- A `conn` attribute is provided to access the connection.
|
|
||||||
- A `lock` attribute is provided to lock the database connection.
|
|
||||||
- A `clean` method to run the VACUUM command and report on the freed space.
|
|
||||||
- A `reinitialize` method to close the connection and re-run the init.
|
|
||||||
- A `close` method to close the connection.
|
|
||||||
|
|
||||||
:param db_path: Path to the database file. If None, an in-memory database is used.
|
In addition to the constructor args, the instance provides the following attributes and methods:
|
||||||
:param logger: Logger to use for logging.
|
- `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
|
||||||
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
|
- `lock`: A shared re-entrant lock, used to approximate thread safety.
|
||||||
|
- `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
||||||
self.initialize(db_path=db_path, logger=logger, verbose=verbose)
|
|
||||||
|
|
||||||
def initialize(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
|
||||||
"""Initializes the database. This is used internally by the class constructor."""
|
"""Initializes the database. This is used internally by the class constructor."""
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
@ -49,21 +46,6 @@ class SqliteDatabase:
|
|||||||
|
|
||||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
|
||||||
def reinitialize(self) -> None:
|
|
||||||
"""
|
|
||||||
Re-initializes the database by closing the connection and re-running the init.
|
|
||||||
Warning: This will wipe the database if it is an in-memory database.
|
|
||||||
"""
|
|
||||||
self.close()
|
|
||||||
self.initialize(db_path=self.db_path, logger=self.logger, verbose=self.verbose)
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
"""
|
|
||||||
Closes the connection to the database.
|
|
||||||
Warning: This will wipe the database if it is an in-memory database.
|
|
||||||
"""
|
|
||||||
self.conn.close()
|
|
||||||
|
|
||||||
def clean(self) -> None:
|
def clean(self) -> None:
|
||||||
"""
|
"""
|
||||||
Cleans the database by running the VACUUM command, reporting on the freed space.
|
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import shutil
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -12,19 +10,11 @@ class SQLiteMigrator:
|
|||||||
"""
|
"""
|
||||||
Manages migrations for a SQLite database.
|
Manages migrations for a SQLite database.
|
||||||
|
|
||||||
:param db_path: The path to the database to migrate, or None if using an in-memory database.
|
:param db: The instanceof :class:`SqliteDatabase` to migrate.
|
||||||
:param conn: The connection to the database.
|
|
||||||
:param lock: A lock to use when running migrations.
|
|
||||||
:param logger: A logger to use for logging.
|
|
||||||
:param log_sql: Whether to log SQL statements. Only used when the log level is set to debug.
|
|
||||||
|
|
||||||
Migrations should be registered with :meth:`register_migration`.
|
Migrations should be registered with :meth:`register_migration`.
|
||||||
|
|
||||||
During migration, a copy of the current database is made and the migrations are run on the copy. If the migration
|
Each migration is run in a transaction. If a migration fails, the transaction is rolled back.
|
||||||
is successful, the original database is backed up and the migrated database is moved to the original database's
|
|
||||||
path. If the migration fails, the original database is left untouched and the migrated database is deleted.
|
|
||||||
|
|
||||||
If the database is in-memory, no backup is made; the migration is run in-place.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
backup_path: Optional[Path] = None
|
backup_path: Optional[Path] = None
|
||||||
@ -34,11 +24,6 @@ class SQLiteMigrator:
|
|||||||
self._logger = db.logger
|
self._logger = db.logger
|
||||||
self._migration_set = MigrationSet()
|
self._migration_set = MigrationSet()
|
||||||
|
|
||||||
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
|
|
||||||
if self._db.db_path and self._get_temp_db_path(self._db.db_path).is_file():
|
|
||||||
self._logger.warning("Previous migration failed! Trying again...")
|
|
||||||
self._get_temp_db_path(self._db.db_path).unlink()
|
|
||||||
|
|
||||||
def register_migration(self, migration: Migration) -> None:
|
def register_migration(self, migration: Migration) -> None:
|
||||||
"""Registers a migration."""
|
"""Registers a migration."""
|
||||||
self._migration_set.register(migration)
|
self._migration_set.register(migration)
|
||||||
@ -61,44 +46,24 @@ class SQLiteMigrator:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
self._logger.info("Database update needed")
|
self._logger.info("Database update needed")
|
||||||
|
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
|
||||||
if self._db.db_path:
|
while next_migration is not None:
|
||||||
# We are using a file database. Create a copy of the database to run the migrations on.
|
self._run_migration(next_migration)
|
||||||
temp_db_path = self._create_temp_db(self._db.db_path)
|
next_migration = self._migration_set.get(self._get_current_version(cursor))
|
||||||
self._logger.info(f"Copied database to {temp_db_path} for migration")
|
|
||||||
temp_db = SqliteDatabase(db_path=temp_db_path, logger=self._logger, verbose=self._db.verbose)
|
|
||||||
temp_db_cursor = temp_db.conn.cursor()
|
|
||||||
self._run_migrations(temp_db_cursor)
|
|
||||||
# Close the connections, copy the original database as a backup, and move the temp database to the
|
|
||||||
# original database's path.
|
|
||||||
temp_db.close()
|
|
||||||
self._db.close()
|
|
||||||
backup_db_path = self._finalize_migration(
|
|
||||||
temp_db_path=temp_db_path,
|
|
||||||
original_db_path=self._db.db_path,
|
|
||||||
)
|
|
||||||
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
|
|
||||||
else:
|
|
||||||
# We are using a memory database. No special backup or special handling needed.
|
|
||||||
self._run_migrations(cursor)
|
|
||||||
|
|
||||||
self._logger.info("Database updated successfully")
|
self._logger.info("Database updated successfully")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _run_migrations(self, temp_db_cursor: sqlite3.Cursor) -> None:
|
def _run_migration(self, migration: Migration) -> None:
|
||||||
"""Runs all migrations in a loop."""
|
|
||||||
next_migration = self._migration_set.get(from_version=self._get_current_version(temp_db_cursor))
|
|
||||||
while next_migration is not None:
|
|
||||||
self._run_migration(next_migration, temp_db_cursor)
|
|
||||||
next_migration = self._migration_set.get(self._get_current_version(temp_db_cursor))
|
|
||||||
|
|
||||||
def _run_migration(self, migration: Migration, temp_db_cursor: sqlite3.Cursor) -> None:
|
|
||||||
"""Runs a single migration."""
|
"""Runs a single migration."""
|
||||||
with self._db.lock:
|
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
|
||||||
try:
|
# exception is raised. We want to commit the transaction if the migration is successful, or roll it back if
|
||||||
if self._get_current_version(temp_db_cursor) != migration.from_version:
|
# there is an error.
|
||||||
|
try:
|
||||||
|
with self._db.lock, self._db.conn as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
if self._get_current_version(cursor) != migration.from_version:
|
||||||
raise MigrationError(
|
raise MigrationError(
|
||||||
f"Database is at version {self._get_current_version(temp_db_cursor)}, expected {migration.from_version}"
|
f"Database is at version {self._get_current_version(cursor)}, expected {migration.from_version}"
|
||||||
)
|
)
|
||||||
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
||||||
|
|
||||||
@ -106,28 +71,26 @@ class SQLiteMigrator:
|
|||||||
if migration.pre_migrate:
|
if migration.pre_migrate:
|
||||||
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
self._logger.debug(f"Running {len(migration.pre_migrate)} pre-migration callbacks")
|
||||||
for callback in migration.pre_migrate:
|
for callback in migration.pre_migrate:
|
||||||
callback(temp_db_cursor)
|
callback(cursor)
|
||||||
|
|
||||||
# Run the actual migration
|
# Run the actual migration
|
||||||
migration.migrate(temp_db_cursor)
|
migration.migrate(cursor)
|
||||||
temp_db_cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||||
|
|
||||||
# Run post-migration callbacks
|
# Run post-migration callbacks
|
||||||
if migration.post_migrate:
|
if migration.post_migrate:
|
||||||
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
self._logger.debug(f"Running {len(migration.post_migrate)} post-migration callbacks")
|
||||||
for callback in migration.post_migrate:
|
for callback in migration.post_migrate:
|
||||||
callback(temp_db_cursor)
|
callback(cursor)
|
||||||
|
|
||||||
# Migration callbacks only get a cursor. Commit this migration.
|
|
||||||
temp_db_cursor.connection.commit()
|
|
||||||
self._logger.debug(
|
self._logger.debug(
|
||||||
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
# We want to catch *any* error, mirroring the behaviour of the sqlite3 module.
|
||||||
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
except Exception as e:
|
||||||
temp_db_cursor.connection.rollback()
|
# The connection context manager has already rolled back the migration, so we don't need to do anything.
|
||||||
self._logger.error(msg)
|
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
||||||
raise MigrationError(msg) from e
|
self._logger.error(msg)
|
||||||
|
raise MigrationError(msg) from e
|
||||||
|
|
||||||
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
||||||
"""Creates the migrations table for the database, if one does not already exist."""
|
"""Creates the migrations table for the database, if one does not already exist."""
|
||||||
@ -166,33 +129,3 @@ class SQLiteMigrator:
|
|||||||
if "no such table" in str(e):
|
if "no such table" in str(e):
|
||||||
return 0
|
return 0
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _create_temp_db(cls, original_db_path: Path) -> Path:
|
|
||||||
"""Copies the current database to a new file for migration."""
|
|
||||||
temp_db_path = cls._get_temp_db_path(original_db_path)
|
|
||||||
shutil.copy2(original_db_path, temp_db_path)
|
|
||||||
return temp_db_path
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _finalize_migration(
|
|
||||||
cls,
|
|
||||||
temp_db_path: Path,
|
|
||||||
original_db_path: Path,
|
|
||||||
) -> Path:
|
|
||||||
"""Renames the original database as a backup and renames the migrated database to the original name."""
|
|
||||||
backup_db_path = cls._get_backup_db_path(original_db_path)
|
|
||||||
original_db_path.rename(backup_db_path)
|
|
||||||
temp_db_path.rename(original_db_path)
|
|
||||||
return backup_db_path
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_temp_db_path(cls, original_db_path: Path) -> Path:
|
|
||||||
"""Gets the path to the temp database."""
|
|
||||||
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_backup_db_path(cls, original_db_path: Path) -> Path:
|
|
||||||
"""Gets the path to the final backup database."""
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
return original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"
|
|
||||||
|
@ -48,8 +48,6 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
|||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
# this test uses a file database, so we need to reinitialize it after migrations
|
|
||||||
db.reinitialize()
|
|
||||||
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||||
return store
|
return store
|
||||||
|
|
||||||
|
@ -39,8 +39,6 @@ def store(datadir: Any) -> ModelRecordServiceBase:
|
|||||||
migrator.register_migration(migration_1)
|
migrator.register_migration(migration_1)
|
||||||
migrator.register_migration(migration_2)
|
migrator.register_migration(migration_2)
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
# this test uses a file database, so we need to reinitialize it after migrations
|
|
||||||
db.reinitialize()
|
|
||||||
return ModelRecordServiceSQL(db)
|
return ModelRecordServiceSQL(db)
|
||||||
|
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ def test_migrator_gets_current_version(migrator: SQLiteMigrator, migration_no_op
|
|||||||
def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
|
def test_migrator_runs_single_migration(migrator: SQLiteMigrator, migration_create_test_table: Migration) -> None:
|
||||||
cursor = migrator._db.conn.cursor()
|
cursor = migrator._db.conn.cursor()
|
||||||
migrator._create_migrations_table(cursor)
|
migrator._create_migrations_table(cursor)
|
||||||
migrator._run_migration(migration_create_test_table, cursor)
|
migrator._run_migration(migration_create_test_table)
|
||||||
assert migrator._get_current_version(cursor) == 1
|
assert migrator._get_current_version(cursor) == 1
|
||||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
||||||
assert cursor.fetchone() is not None
|
assert cursor.fetchone() is not None
|
||||||
@ -228,46 +228,6 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
|
|||||||
assert SQLiteMigrator._get_current_version(original_db_cursor) == 3
|
assert SQLiteMigrator._get_current_version(original_db_cursor) == 3
|
||||||
|
|
||||||
|
|
||||||
def test_migrator_creates_temp_db() -> None:
|
|
||||||
with TemporaryDirectory() as tempdir:
|
|
||||||
original_db_path = Path(tempdir) / "invokeai.db"
|
|
||||||
with closing(sqlite3.connect(original_db_path)):
|
|
||||||
# create the db file so _create_temp_db has something to copy
|
|
||||||
pass
|
|
||||||
temp_db_path = SQLiteMigrator._create_temp_db(original_db_path)
|
|
||||||
assert temp_db_path.is_file()
|
|
||||||
assert temp_db_path == SQLiteMigrator._get_temp_db_path(original_db_path)
|
|
||||||
|
|
||||||
|
|
||||||
def test_migrator_finalizes() -> None:
|
|
||||||
with TemporaryDirectory() as tempdir:
|
|
||||||
original_db_path = Path(tempdir) / "invokeai.db"
|
|
||||||
temp_db_path = SQLiteMigrator._get_temp_db_path(original_db_path)
|
|
||||||
backup_db_path = SQLiteMigrator._get_backup_db_path(original_db_path)
|
|
||||||
with closing(sqlite3.connect(original_db_path)) as original_db_conn, closing(
|
|
||||||
sqlite3.connect(temp_db_path)
|
|
||||||
) as temp_db_conn:
|
|
||||||
original_db_cursor = original_db_conn.cursor()
|
|
||||||
original_db_cursor.execute("CREATE TABLE original_db_test (id INTEGER PRIMARY KEY);")
|
|
||||||
original_db_conn.commit()
|
|
||||||
temp_db_cursor = temp_db_conn.cursor()
|
|
||||||
temp_db_cursor.execute("CREATE TABLE temp_db_test (id INTEGER PRIMARY KEY);")
|
|
||||||
temp_db_conn.commit()
|
|
||||||
SQLiteMigrator._finalize_migration(
|
|
||||||
original_db_path=original_db_path,
|
|
||||||
temp_db_path=temp_db_path,
|
|
||||||
)
|
|
||||||
with closing(sqlite3.connect(backup_db_path)) as backup_db_conn, closing(
|
|
||||||
sqlite3.connect(temp_db_path)
|
|
||||||
) as temp_db_conn:
|
|
||||||
backup_db_cursor = backup_db_conn.cursor()
|
|
||||||
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='original_db_test';")
|
|
||||||
assert backup_db_cursor.fetchone() is not None
|
|
||||||
temp_db_cursor = temp_db_conn.cursor()
|
|
||||||
temp_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='temp_db_test';")
|
|
||||||
assert temp_db_cursor.fetchone() is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_migrator_makes_no_changes_on_failed_migration(
|
def test_migrator_makes_no_changes_on_failed_migration(
|
||||||
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
|
migrator: SQLiteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user