fix(nodes): respect nodes denylist

In #5838 graph validation was updated to resolve the issue with the nodes union and import order. That broke the nodes denylist functionality.

However, because the corresponding test was marked as `xfail`, we didn't catch the issue.

- Fix the nodes denylist handling
- Update the tests
This commit is contained in:
psychedelicious 2024-03-08 12:52:41 +11:00
parent 5118160282
commit f07a46f195
2 changed files with 20 additions and 30 deletions

View File

@ -183,7 +183,7 @@ class BaseInvocation(ABC, BaseModel):
"""Gets a pydantc TypeAdapter for the union of all invocation types.""" """Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter: if not cls._typeadapter:
InvocationsUnion = TypeAliasType( InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] "InvocationsUnion", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
) )
cls._typeadapter = TypeAdapter(InvocationsUnion) cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter return cls._typeadapter

View File

@ -6,6 +6,10 @@ import pytest
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pydantic import ValidationError from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation, StringInvocation
from invokeai.app.services.shared.graph import Graph
@pytest.fixture @pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
@ -158,58 +162,44 @@ def test_type_coercion(patch_rootdir):
assert isinstance(conf.root, Path) assert isinstance(conf.root, Path)
@pytest.mark.xfail(
reason="""
This test fails when run as part of the full test suite.
This test needs to deny nodes from being included in the InvocationsUnion by providing
an app configuration as a test fixture. Pytest executes all test files before running
tests, so the app configuration is already initialized by the time this test runs, and
the InvocationUnion is already created and the denied nodes are not omitted from it.
This test passes when `test_config.py` is tested in isolation.
Perhaps a solution would be to call `InvokeAIAppConfig.get_config().parse_args()` in
other test files?
"""
)
def test_deny_nodes(patch_rootdir): def test_deny_nodes(patch_rootdir):
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
# Allow integer, string and float, but explicitly deny float # Allow integer, string and float, but explicitly deny float
allow_deny_nodes_conf = OmegaConf.create( allow_deny_nodes_conf = OmegaConf.create(
""" f"""
InvokeAI: InvokeAI:
Nodes: Nodes:
allow_nodes: allow_nodes:
- integer - {IntegerInvocation.get_type()}
- string - {StringInvocation.get_type()}
- float - {FloatInvocation.get_type()}
deny_nodes: deny_nodes:
- float - {FloatInvocation.get_type()}
""" """
) )
# must parse config before importing Graph, so its nodes union uses the config # must parse config before importing Graph, so its nodes union uses the config
conf = InvokeAIAppConfig().get_config() conf = InvokeAIAppConfig().get_config()
conf.parse_args(conf=allow_deny_nodes_conf, argv=[]) conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
from invokeai.app.services.shared.graph import Graph
# confirm graph validation fails when using denied node # confirm graph validation fails when using denied node
Graph(nodes={"1": {"id": "1", "type": "integer"}}) Graph(nodes={"1": IntegerInvocation(value=1)})
Graph(nodes={"1": {"id": "1", "type": "string"}}) Graph(nodes={"1": StringInvocation(value="asdf")})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
Graph(nodes={"1": {"id": "1", "type": "float"}}) Graph(nodes={"1": FloatInvocation(value=1.0)})
from invokeai.app.invocations.baseinvocation import BaseInvocation # Also test with a dict input
with pytest.raises(ValidationError):
Graph(nodes={"1": {"id": "1", "type": "float"}})
# confirm invocations union will not have denied nodes # confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations() all_invocations = BaseInvocation.get_invocations()
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1 has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1 has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1 does_not_have_float = len([i for i in all_invocations if i.get_type() == "float"]) == 0
assert has_integer assert has_integer
assert has_string assert has_string
assert not has_float assert does_not_have_float