fix(db): fix migration chain validation

This commit is contained in:
psychedelicious
2023-12-11 11:15:09 +11:00
parent b3f92e0547
commit 8726b203d4

View File

@ -81,27 +81,23 @@ class MigrationSet:
# register() ensures that there is only one migration with a given from_version, so this is safe. # register() ensures that there is only one migration with a given from_version, so this is safe.
return next((m for m in self._migrations if m.from_version == from_version), None) return next((m for m in self._migrations if m.from_version == from_version), None)
def validate_migration_path(self) -> None: def validate_migration_chain(self) -> None:
""" """
Validates that the migrations form a single path of migrations from version 0 to the latest version. Validates that the migrations form a single chain of migrations from version 0 to the latest version.
Raises a MigrationError if there is a problem. Raises a MigrationError if there is a problem.
""" """
if self.count == 0: if self.count == 0:
return return
if self.latest_version == 0: if self.latest_version == 0:
return return
current_version = 0 next_migration = self.get(from_version=0)
touched_count = 0 touched_count = 1
while current_version < self.latest_version: while next_migration is not None:
migration = self.get(current_version) next_migration = self.get(next_migration.to_version)
if migration is None: if next_migration is not None:
raise MigrationError(f"Missing migration from {current_version}") touched_count += 1
current_version = migration.to_version
touched_count += 1
if current_version != self.latest_version:
raise MigrationError(f"Missing migration to {self.latest_version}")
if touched_count != self.count: if touched_count != self.count:
raise MigrationError("Migration path is not contiguous") raise MigrationError("Migration chain is fragmented")
@property @property
def count(self) -> int: def count(self) -> int:
@ -178,7 +174,7 @@ class SQLiteMigrator:
"""Migrates the database to the latest version.""" """Migrates the database to the latest version."""
with self._lock: with self._lock:
# This throws if there is a problem. # This throws if there is a problem.
self._migrations.validate_migration_path() self._migration_set.validate_migration_chain()
self._create_migrations_table(cursor=self._cursor) self._create_migrations_table(cursor=self._cursor)
if self._migrations.count == 0: if self._migrations.count == 0: