From f07a46f195f89a9c1d80cedf8437bb5b26cfca8c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 8 Mar 2024 12:52:41 +1100 Subject: [PATCH] fix(nodes): respect nodes denylist In #5838 graph validation was updated to resolve the issue with the nodes union and import order. That broke the nodes denylist functionality. However, because the corresponding test was marked as `xfail`, we didn't catch the issue. - Fix the nodes denylist handling - Update the tests --- invokeai/app/invocations/baseinvocation.py | 2 +- tests/test_config.py | 48 +++++++++------------- 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 5edae5342d..a1303d38ef 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -183,7 +183,7 @@ class BaseInvocation(ABC, BaseModel): """Gets a pydantc TypeAdapter for the union of all invocation types.""" if not cls._typeadapter: InvocationsUnion = TypeAliasType( - "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] + "InvocationsUnion", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")] ) cls._typeadapter = TypeAdapter(InvocationsUnion) return cls._typeadapter diff --git a/tests/test_config.py b/tests/test_config.py index 740a4866dd..c33a344eee 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,6 +6,10 @@ import pytest from omegaconf import OmegaConf from pydantic import ValidationError +from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation, StringInvocation +from invokeai.app.services.shared.graph import Graph + @pytest.fixture def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: @@ -158,58 +162,44 @@ def test_type_coercion(patch_rootdir): assert isinstance(conf.root, Path) -@pytest.mark.xfail( - reason=""" - This test fails when run as part of the full test suite. - - This test needs to deny nodes from being included in the InvocationsUnion by providing - an app configuration as a test fixture. Pytest executes all test files before running - tests, so the app configuration is already initialized by the time this test runs, and - the InvocationUnion is already created and the denied nodes are not omitted from it. - - This test passes when `test_config.py` is tested in isolation. - - Perhaps a solution would be to call `InvokeAIAppConfig.get_config().parse_args()` in - other test files? - """ -) def test_deny_nodes(patch_rootdir): from invokeai.app.services.config import InvokeAIAppConfig # Allow integer, string and float, but explicitly deny float allow_deny_nodes_conf = OmegaConf.create( - """ + f""" InvokeAI: Nodes: allow_nodes: - - integer - - string - - float + - {IntegerInvocation.get_type()} + - {StringInvocation.get_type()} + - {FloatInvocation.get_type()} deny_nodes: - - float + - {FloatInvocation.get_type()} """ ) # must parse config before importing Graph, so its nodes union uses the config conf = InvokeAIAppConfig().get_config() conf.parse_args(conf=allow_deny_nodes_conf, argv=[]) - from invokeai.app.services.shared.graph import Graph # confirm graph validation fails when using denied node - Graph(nodes={"1": {"id": "1", "type": "integer"}}) - Graph(nodes={"1": {"id": "1", "type": "string"}}) + Graph(nodes={"1": IntegerInvocation(value=1)}) + Graph(nodes={"1": StringInvocation(value="asdf")}) with pytest.raises(ValidationError): - Graph(nodes={"1": {"id": "1", "type": "float"}}) + Graph(nodes={"1": FloatInvocation(value=1.0)}) - from invokeai.app.invocations.baseinvocation import BaseInvocation + # Also test with a dict input + with pytest.raises(ValidationError): + Graph(nodes={"1": {"id": "1", "type": "float"}}) # confirm invocations union will not have denied nodes all_invocations = BaseInvocation.get_invocations() - has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1 - has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1 - has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1 + has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1 + has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1 + does_not_have_float = len([i for i in all_invocations if i.get_type() == "float"]) == 0 assert has_integer assert has_string - assert not has_float + assert does_not_have_float