diff --git a/invokeai/app/services/shared/sqlite/sqlite_migrator.py b/invokeai/app/services/shared/sqlite/sqlite_migrator.py index b8a7941a95..9cbdf028d5 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_migrator.py +++ b/invokeai/app/services/shared/sqlite/sqlite_migrator.py @@ -103,6 +103,14 @@ class SQLiteMigrator: self._cursor = self._db.conn.cursor() self._migrations = MigrationSet() + # Use a lock file to indicate that a migration is in progress. Should only exist in the event of catastrophic failure. + self._migration_lock_file_path = ( + self._db.database.parent / ".migration_in_progress" if isinstance(self._db.database, Path) else None + ) + + if self._unlink_lock_file(): + self._logger.warning("Previous migration failed! Trying again...") + def register_migration(self, migration: Migration) -> None: """ Registers a migration. @@ -214,8 +222,7 @@ class SQLiteMigrator: if self._db.is_memory: self._logger.debug("Using memory database, skipping backup") # Sanity check! - if not isinstance(self._db.database, Path): - raise MigrationError(f"Database path must be a Path, got {self._db.database} ({type(self._db.database)})") + assert isinstance(self._db.database, Path) with self._db.lock: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_path = self._db.database.parent / f"{self._db.database.stem}_{timestamp}.db" @@ -236,15 +243,28 @@ class SQLiteMigrator: self, ) -> None: """Restores the database from a backup file, unless the database is a memory database.""" - # We don't need to restore a memory database. if self._db.is_memory: return with self._db.lock: self._logger.info(f"Restoring database from {self.backup_path}") self._db.conn.close() - if self.backup_path is None: - raise FileNotFoundError("No backup path set") - if not Path(self.backup_path).is_file(): - raise FileNotFoundError(f"Backup file {self.backup_path} does not exist") + assert isinstance(self.backup_path, Path) shutil.copy2(self.backup_path, self._db.database) + + def _unlink_lock_file(self) -> bool: + """Unlinks the migration lock file, returning True if it existed.""" + if self._db.is_memory or self._migration_lock_file_path is None: + return False + if self._migration_lock_file_path.is_file(): + self._migration_lock_file_path.unlink() + return True + return False + + def _write_migration_lock_file(self) -> None: + """Writes a file to indicate that a migration is in progress.""" + if self._db.is_memory or self._migration_lock_file_path is None: + return + assert isinstance(self._db.database, Path) + with open(self._migration_lock_file_path, "w") as f: + f.write("1")