tests(config): update tests for config migration

This commit is contained in:
psychedelicious 2024-05-14 16:55:53 +10:00
parent 7d8b011f89
commit 964adb817c

View File

@ -8,11 +8,12 @@ from packaging.version import Version
from pydantic import ValidationError from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
from invokeai.app.services.config.config_default import ( from invokeai.app.services.config.config_default import (
DefaultInvokeAIAppConfig, DefaultInvokeAIAppConfig,
InvokeAIAppConfig, InvokeAIAppConfig,
) )
from invokeai.app.services.config.config_migrate import get_config, load_and_migrate_config from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
from invokeai.app.services.shared.graph import Graph from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -75,6 +76,96 @@ def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path)) monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
def test_config_migrator_registers_migrations() -> None:
"""Test that the config migrator registers migrations."""
migrator = ConfigMigrator()
def migration_func(config: AppConfigDict) -> AppConfigDict:
return config
migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_func)
migration_2 = ConfigMigration(from_version=Version("4.0.0"), to_version=Version("5.0.0"), function=migration_func)
migrator.register(migration_1)
assert migrator._migrations == {migration_1}
migrator.register(migration_2)
assert migrator._migrations == {migration_1, migration_2}
def test_config_migrator_rejects_duplicate_migrations() -> None:
"""Test that the config migrator rejects duplicate migrations."""
migrator = ConfigMigrator()
def migration_func(config: AppConfigDict) -> AppConfigDict:
return config
migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_func)
migrator.register(migration_1)
# Re-register the same migration
with pytest.raises(
ValueError,
match=f"A migration from {migration_1.from_version} or to {migration_1.to_version} has already been registered.",
):
migrator.register(migration_1)
# Register a migration with the same from_version
migration_2 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("5.0.0"), function=migration_func)
with pytest.raises(
ValueError,
match=f"A migration from {migration_2.from_version} or to {migration_2.to_version} has already been registered.",
):
migrator.register(migration_2)
# Register a migration with the same to_version
migration_3 = ConfigMigration(from_version=Version("3.0.1"), to_version=Version("4.0.0"), function=migration_func)
with pytest.raises(
ValueError,
match=f"A migration from {migration_3.from_version} or to {migration_3.to_version} has already been registered.",
):
migrator.register(migration_3)
def test_config_migrator_contiguous_migrations() -> None:
"""Test that the config migrator requires contiguous migrations."""
migrator = ConfigMigrator()
def migration_1_func(config: AppConfigDict) -> AppConfigDict:
return {"schema_version": "4.0.0"}
def migration_3_func(config: AppConfigDict) -> AppConfigDict:
return {"schema_version": "6.0.0"}
migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_1_func)
migration_3 = ConfigMigration(from_version=Version("5.0.0"), to_version=Version("6.0.0"), function=migration_3_func)
migrator.register(migration_1)
migrator.register(migration_3)
with pytest.raises(ValueError, match="Migration functions are not continuous"):
migrator._check_for_discontinuities(migrator._migrations)
def test_config_migrator_runs_migrations() -> None:
"""Test that the config migrator runs migrations."""
migrator = ConfigMigrator()
def migration_1_func(config: AppConfigDict) -> AppConfigDict:
return {"schema_version": "4.0.0"}
def migration_2_func(config: AppConfigDict) -> AppConfigDict:
return {"schema_version": "5.0.0"}
migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_1_func)
migration_2 = ConfigMigration(from_version=Version("4.0.0"), to_version=Version("5.0.0"), function=migration_2_func)
migrator.register(migration_1)
migrator.register(migration_2)
original_config = {"schema_version": "3.0.0"}
migrated_config = migrator.run_migrations(original_config)
assert migrated_config == {"schema_version": "5.0.0"}
def test_path_resolution_root_not_set(patch_rootdir: None): def test_path_resolution_root_not_set(patch_rootdir: None):
"""Test path resolutions when the root is not explicitly set.""" """Test path resolutions when the root is not explicitly set."""
config = InvokeAIAppConfig() config = InvokeAIAppConfig()