From 3227b30430aca0ad01e7d909b4f3286120279ee0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 11 Dec 2023 12:43:22 +1100 Subject: [PATCH] feat(db): extract non-stateful logic to class methods --- .../services/shared/sqlite/sqlite_migrator.py | 73 ++++++++++--------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/invokeai/app/services/shared/sqlite/sqlite_migrator.py b/invokeai/app/services/shared/sqlite/sqlite_migrator.py index 5672cdbceb..9a6984dbd3 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_migrator.py +++ b/invokeai/app/services/shared/sqlite/sqlite_migrator.py @@ -112,17 +112,6 @@ class MigrationSet: return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version -def get_temp_db_path(original_db_path: Path) -> Path: - """Gets the path to the temp database.""" - return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp") - - -def get_backup_db_path(original_db_path: Path) -> Path: - """Gets the path to the final backup database.""" - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - return original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db" - - class SQLiteMigrator: """ Manages migrations for a SQLite database. @@ -161,9 +150,9 @@ class SQLiteMigrator: self._migration_set = MigrationSet() # The presence of an temp database file indicates a catastrophic failure of a previous migration. - if self._db_path and get_temp_db_path(self._db_path).is_file(): + if self._db_path and self._get_temp_db_path(self._db_path).is_file(): self._logger.warning("Previous migration failed! Trying again...") - get_temp_db_path(self._db_path).unlink() + self._get_temp_db_path(self._db_path).unlink() def register_migration(self, migration: Migration) -> None: """Registers a migration.""" @@ -200,10 +189,11 @@ class SQLiteMigrator: # Close the connections, copy the original database as a backup, and move the temp database to the # original database's path. backup_db_path = self._finalize_migration( - temp_db_conn=temp_db_conn, temp_db_path=temp_db_path, original_db_path=self._db_path, ) + temp_db_conn.close() + self._conn.close() self._logger.info(f"Migration successful. Original DB backed up to {backup_db_path}") else: # We are using a memory database. No special backup or special handling needed. @@ -280,31 +270,46 @@ class SQLiteMigrator: cursor.connection.rollback() raise MigrationError(msg) from e - def _get_current_version(self, cursor: sqlite3.Cursor) -> int: + @classmethod + def _get_current_version(cls, cursor: sqlite3.Cursor) -> int: """Gets the current version of the database, or 0 if the migrations table does not exist.""" - with self._lock: - try: - cursor.execute("SELECT MAX(version) FROM migrations;") - version = cursor.fetchone()[0] - if version is None: - return 0 - return version - except sqlite3.OperationalError as e: - if "no such table" in str(e): - return 0 - raise + try: + cursor.execute("SELECT MAX(version) FROM migrations;") + version = cursor.fetchone()[0] + if version is None: + return 0 + return version + except sqlite3.OperationalError as e: + if "no such table" in str(e): + return 0 + raise - def _create_temp_db(self, current_db_path: Path) -> Path: + @classmethod + def _create_temp_db(cls, original_db_path: Path) -> Path: """Copies the current database to a new file for migration.""" - temp_db_path = get_temp_db_path(current_db_path) - shutil.copy2(current_db_path, temp_db_path) + temp_db_path = cls._get_temp_db_path(original_db_path) + shutil.copy2(original_db_path, temp_db_path) return temp_db_path - def _finalize_migration(self, temp_db_conn: sqlite3.Connection, temp_db_path: Path, original_db_path: Path) -> Path: - """Closes connections, renames the original database as a backup and renames the migrated database to the original name.""" - self._conn.close() - temp_db_conn.close() - backup_db_path = get_backup_db_path(original_db_path) + @classmethod + def _finalize_migration( + cls, + temp_db_path: Path, + original_db_path: Path, + ) -> Path: + """Renames the original database as a backup and renames the migrated database to the original name.""" + backup_db_path = cls._get_backup_db_path(original_db_path) original_db_path.rename(backup_db_path) temp_db_path.rename(original_db_path) return backup_db_path + + @classmethod + def _get_temp_db_path(cls, original_db_path: Path) -> Path: + """Gets the path to the temp database.""" + return original_db_path.parent / original_db_path.name.replace(".db", ".db.temp") + + @classmethod + def _get_backup_db_path(cls, original_db_path: Path) -> Path: + """Gets the path to the final backup database.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return original_db_path.parent / f"{original_db_path.stem}_backup_{timestamp}.db"