diff --git a/tests/test_config.py b/tests/test_config.py index f1b4b4a6ce..5858bfa47a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,9 +1,11 @@ +from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory from typing import Any import pytest from packaging.version import Version +from pydantic import ValidationError from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.config.config_default import ( @@ -12,6 +14,7 @@ from invokeai.app.services.config.config_default import ( get_config, load_and_migrate_config, ) +from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs invalid_v4_0_1_config = """ @@ -285,21 +288,34 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch InvokeAIArgs.did_parse = False +@contextmanager +def clear_config(): + try: + yield None + finally: + get_config.cache_clear() + + def test_deny_nodes(): - config = get_config() - config.allow_nodes = ["integer", "string", "float"] - config.deny_nodes = ["float"] + with clear_config(): + config = get_config() + config.allow_nodes = ["integer", "string", "float"] + config.deny_nodes = ["float"] - # confirm invocations union will not have denied nodes - all_invocations = BaseInvocation.get_invocations() + # confirm graph validation fails when using denied node + Graph(nodes={"1": {"id": "1", "type": "integer"}}) + Graph(nodes={"1": {"id": "1", "type": "string"}}) - has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1 - has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1 - has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1 + with pytest.raises(ValidationError): + Graph(nodes={"1": {"id": "1", "type": "float"}}) - assert has_integer - assert has_string - assert not has_float + # confirm invocations union will not have denied nodes + all_invocations = BaseInvocation.get_invocations() - # may not be necessary - get_config.cache_clear() + has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1 + has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1 + has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1 + + assert has_integer + assert has_string + assert not has_float