diff --git a/tests/test_config.py b/tests/test_config.py index 5858bfa47a..80c7ccc950 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any +from typing import Any, Generator import pytest from packaging.version import Version @@ -9,11 +9,13 @@ from pydantic import ValidationError from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.config.config_default import ( + CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, get_config, load_and_migrate_config, ) +from invokeai.app.services.config.config_migrate import ConfigMigrator from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs @@ -288,14 +290,49 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch InvokeAIArgs.did_parse = False +def test_migration_check() -> None: + new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) + assert new_config is not None + assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + + # Does this execute at compile time or run time? + @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1") + def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) + assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1" + + @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3") + def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + # Because there is no version for "*.1" => "*.2", this should fail. + with pytest.raises(ValueError): + ConfigMigrator.migrate({"schema_version": "4.0.0"}) + + @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2") + def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + # should work now, because there is a continuous path to *.3 + new_config = ConfigMigrator.migrate(new_config) + assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3" + + @contextmanager -def clear_config(): +def clear_config() -> Generator[None, None, None]: try: yield None finally: get_config.cache_clear() +@pytest.mark.xfail( + reason=""" + Currently this test is failing due to an issue described in issue #5983. +""" +) def test_deny_nodes(): with clear_config(): config = get_config()