added test for non-contiguous migration routines

This commit is contained in:
Lincoln Stein 2024-04-28 14:31:38 -04:00
parent d24877561d
commit d852ca7a8d

View File

@ -1,7 +1,7 @@
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
from typing import Any, Generator
import pytest
from packaging.version import Version
@ -9,11 +9,13 @@ from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.config.config_default import (
CONFIG_SCHEMA_VERSION,
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
get_config,
load_and_migrate_config,
)
from invokeai.app.services.config.config_migrate import ConfigMigrator
from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -288,14 +290,49 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
InvokeAIArgs.did_parse = False
def test_migration_check() -> None:
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
assert new_config is not None
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
# Does this execute at compile time or run time?
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1")
def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1"
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3")
def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
# Because there is no version for "*.1" => "*.2", this should fail.
with pytest.raises(ValueError):
ConfigMigrator.migrate({"schema_version": "4.0.0"})
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2")
def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
# should work now, because there is a continuous path to *.3
new_config = ConfigMigrator.migrate(new_config)
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3"
@contextmanager
def clear_config():
def clear_config() -> Generator[None, None, None]:
try:
yield None
finally:
get_config.cache_clear()
@pytest.mark.xfail(
reason="""
Currently this test is failing due to an issue described in issue #5983.
"""
)
def test_deny_nodes():
with clear_config():
config = get_config()