mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): extract non-stateful logic to class methods
This commit is contained in:
parent
567f107a81
commit
3227b30430
@ -112,17 +112,6 @@ class MigrationSet:
|
|||||||
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
||||||
|
|
||||||
|
|
||||||
def get_temp_db_path(original_db_path: Path) -> Path:
|
|
||||||
"""Gets the path to the temp database."""
|
|
||||||
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
|
|
||||||
|
|
||||||
|
|
||||||
def get_backup_db_path(original_db_path: Path) -> Path:
|
|
||||||
"""Gets the path to the final backup database."""
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
return original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"
|
|
||||||
|
|
||||||
|
|
||||||
class SQLiteMigrator:
|
class SQLiteMigrator:
|
||||||
"""
|
"""
|
||||||
Manages migrations for a SQLite database.
|
Manages migrations for a SQLite database.
|
||||||
@ -161,9 +150,9 @@ class SQLiteMigrator:
|
|||||||
self._migration_set = MigrationSet()
|
self._migration_set = MigrationSet()
|
||||||
|
|
||||||
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
|
# The presence of an temp database file indicates a catastrophic failure of a previous migration.
|
||||||
if self._db_path and get_temp_db_path(self._db_path).is_file():
|
if self._db_path and self._get_temp_db_path(self._db_path).is_file():
|
||||||
self._logger.warning("Previous migration failed! Trying again...")
|
self._logger.warning("Previous migration failed! Trying again...")
|
||||||
get_temp_db_path(self._db_path).unlink()
|
self._get_temp_db_path(self._db_path).unlink()
|
||||||
|
|
||||||
def register_migration(self, migration: Migration) -> None:
|
def register_migration(self, migration: Migration) -> None:
|
||||||
"""Registers a migration."""
|
"""Registers a migration."""
|
||||||
@ -200,10 +189,11 @@ class SQLiteMigrator:
|
|||||||
# Close the connections, copy the original database as a backup, and move the temp database to the
|
# Close the connections, copy the original database as a backup, and move the temp database to the
|
||||||
# original database's path.
|
# original database's path.
|
||||||
backup_db_path = self._finalize_migration(
|
backup_db_path = self._finalize_migration(
|
||||||
temp_db_conn=temp_db_conn,
|
|
||||||
temp_db_path=temp_db_path,
|
temp_db_path=temp_db_path,
|
||||||
original_db_path=self._db_path,
|
original_db_path=self._db_path,
|
||||||
)
|
)
|
||||||
|
temp_db_conn.close()
|
||||||
|
self._conn.close()
|
||||||
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
|
self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}")
|
||||||
else:
|
else:
|
||||||
# We are using a memory database. No special backup or special handling needed.
|
# We are using a memory database. No special backup or special handling needed.
|
||||||
@ -280,31 +270,46 @@ class SQLiteMigrator:
|
|||||||
cursor.connection.rollback()
|
cursor.connection.rollback()
|
||||||
raise MigrationError(msg) from e
|
raise MigrationError(msg) from e
|
||||||
|
|
||||||
def _get_current_version(self, cursor: sqlite3.Cursor) -> int:
|
@classmethod
|
||||||
|
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:
|
||||||
"""Gets the current version of the database, or 0 if the migrations table does not exist."""
|
"""Gets the current version of the database, or 0 if the migrations table does not exist."""
|
||||||
with self._lock:
|
try:
|
||||||
try:
|
cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||||
cursor.execute("SELECT MAX(version) FROM migrations;")
|
version = cursor.fetchone()[0]
|
||||||
version = cursor.fetchone()[0]
|
if version is None:
|
||||||
if version is None:
|
return 0
|
||||||
return 0
|
return version
|
||||||
return version
|
except sqlite3.OperationalError as e:
|
||||||
except sqlite3.OperationalError as e:
|
if "no such table" in str(e):
|
||||||
if "no such table" in str(e):
|
return 0
|
||||||
return 0
|
raise
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_temp_db(self, current_db_path: Path) -> Path:
|
@classmethod
|
||||||
|
def _create_temp_db(cls, original_db_path: Path) -> Path:
|
||||||
"""Copies the current database to a new file for migration."""
|
"""Copies the current database to a new file for migration."""
|
||||||
temp_db_path = get_temp_db_path(current_db_path)
|
temp_db_path = cls._get_temp_db_path(original_db_path)
|
||||||
shutil.copy2(current_db_path, temp_db_path)
|
shutil.copy2(original_db_path, temp_db_path)
|
||||||
return temp_db_path
|
return temp_db_path
|
||||||
|
|
||||||
def _finalize_migration(self, temp_db_conn: sqlite3.Connection, temp_db_path: Path, original_db_path: Path) -> Path:
|
@classmethod
|
||||||
"""Closes connections, renames the original database as a backup and renames the migrated database to the original name."""
|
def _finalize_migration(
|
||||||
self._conn.close()
|
cls,
|
||||||
temp_db_conn.close()
|
temp_db_path: Path,
|
||||||
backup_db_path = get_backup_db_path(original_db_path)
|
original_db_path: Path,
|
||||||
|
) -> Path:
|
||||||
|
"""Renames the original database as a backup and renames the migrated database to the original name."""
|
||||||
|
backup_db_path = cls._get_backup_db_path(original_db_path)
|
||||||
original_db_path.rename(backup_db_path)
|
original_db_path.rename(backup_db_path)
|
||||||
temp_db_path.rename(original_db_path)
|
temp_db_path.rename(original_db_path)
|
||||||
return backup_db_path
|
return backup_db_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_temp_db_path(cls, original_db_path: Path) -> Path:
|
||||||
|
"""Gets the path to the temp database."""
|
||||||
|
return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_backup_db_path(cls, original_db_path: Path) -> Path:
|
||||||
|
"""Gets the path to the final backup database."""
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
return original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"
|
||||||
|
Loading…
Reference in New Issue
Block a user