From afe4e55bf9a1e7b7cf0163a672c4ea5e9378e48b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 12 Dec 2023 09:52:03 +1100 Subject: [PATCH] feat(db): simplify migration registration validation With the previous change to assert that the to_version == from_version + 1, this validation can be simpler. --- .../sqlite_migrator/sqlite_migrator_common.py | 8 ++++---- tests/test_sqlite_migrator.py | 18 ++++++++++-------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py index 8c71c1d969..0c395f54d6 100644 --- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py +++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py @@ -65,10 +65,10 @@ class MigrationSet: def register(self, migration: Migration) -> None: """Registers a migration.""" - if any(m.from_version == migration.from_version for m in self._migrations): - raise MigrationVersionError(f"Migration from {migration.from_version} already registered") - if any(m.to_version == migration.to_version for m in self._migrations): - raise MigrationVersionError(f"Migration to {migration.to_version} already registered") + migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations) + migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations) + if migration_from_already_registered or migration_to_already_registered: + raise MigrationVersionError("Migration with from_version or to_version already registered") self._migrations.add(migration) def get(self, from_version: int) -> Optional[Migration]: diff --git a/tests/test_sqlite_migrator.py b/tests/test_sqlite_migrator.py index 11a9b0ef78..630fb5dd3b 100644 --- a/tests/test_sqlite_migrator.py +++ b/tests/test_sqlite_migrator.py @@ -116,14 +116,16 @@ def test_migration_set_add_migration(migrator: SQLiteMigrator, migration_no_op: def test_migration_set_may_not_register_dupes( migrator: SQLiteMigrator, no_op_migrate_callback: MigrateCallback ) -> None: - migrate_1_to_2 = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback) - migrate_0_to_2 = Migration(from_version=0, to_version=2, migrate=no_op_migrate_callback) - migrate_1_to_3 = Migration(from_version=1, to_version=3, migrate=no_op_migrate_callback) - migrator._migration_set.register(migrate_1_to_2) - with pytest.raises(MigrationVersionError, match=r"Migration to 2 already registered"): - migrator._migration_set.register(migrate_0_to_2) - with pytest.raises(MigrationVersionError, match=r"Migration from 1 already registered"): - migrator._migration_set.register(migrate_1_to_3) + migrate_0_to_1_a = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback) + migrate_0_to_1_b = Migration(from_version=0, to_version=1, migrate=no_op_migrate_callback) + migrator._migration_set.register(migrate_0_to_1_a) + with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"): + migrator._migration_set.register(migrate_0_to_1_b) + migrate_1_to_2_a = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback) + migrate_1_to_2_b = Migration(from_version=1, to_version=2, migrate=no_op_migrate_callback) + migrator._migration_set.register(migrate_1_to_2_a) + with pytest.raises(MigrationVersionError, match=r"Migration with from_version or to_version already registered"): + migrator._migration_set.register(migrate_1_to_2_b) def test_migration_set_gets_migration(migration_no_op: Migration) -> None: