mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): respect nodes denylist
In #5838 graph validation was updated to resolve the issue with the nodes union and import order. That broke the nodes denylist functionality. However, because the corresponding test was marked as `xfail`, we didn't catch the issue. - Fix the nodes denylist handling - Update the tests
This commit is contained in:
parent
5118160282
commit
f07a46f195
@ -183,7 +183,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationsUnion = TypeAliasType(
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||
return cls._typeadapter
|
||||
|
@ -6,6 +6,10 @@ import pytest
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation, StringInvocation
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
||||
@ -158,58 +162,44 @@ def test_type_coercion(patch_rootdir):
|
||||
assert isinstance(conf.root, Path)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="""
|
||||
This test fails when run as part of the full test suite.
|
||||
|
||||
This test needs to deny nodes from being included in the InvocationsUnion by providing
|
||||
an app configuration as a test fixture. Pytest executes all test files before running
|
||||
tests, so the app configuration is already initialized by the time this test runs, and
|
||||
the InvocationUnion is already created and the denied nodes are not omitted from it.
|
||||
|
||||
This test passes when `test_config.py` is tested in isolation.
|
||||
|
||||
Perhaps a solution would be to call `InvokeAIAppConfig.get_config().parse_args()` in
|
||||
other test files?
|
||||
"""
|
||||
)
|
||||
def test_deny_nodes(patch_rootdir):
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# Allow integer, string and float, but explicitly deny float
|
||||
allow_deny_nodes_conf = OmegaConf.create(
|
||||
"""
|
||||
f"""
|
||||
InvokeAI:
|
||||
Nodes:
|
||||
allow_nodes:
|
||||
- integer
|
||||
- string
|
||||
- float
|
||||
- {IntegerInvocation.get_type()}
|
||||
- {StringInvocation.get_type()}
|
||||
- {FloatInvocation.get_type()}
|
||||
deny_nodes:
|
||||
- float
|
||||
- {FloatInvocation.get_type()}
|
||||
"""
|
||||
)
|
||||
# must parse config before importing Graph, so its nodes union uses the config
|
||||
conf = InvokeAIAppConfig().get_config()
|
||||
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
|
||||
# confirm graph validation fails when using denied node
|
||||
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||
Graph(nodes={"1": IntegerInvocation(value=1)})
|
||||
Graph(nodes={"1": StringInvocation(value="asdf")})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||
Graph(nodes={"1": FloatInvocation(value=1.0)})
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
# Also test with a dict input
|
||||
with pytest.raises(ValidationError):
|
||||
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
|
||||
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
||||
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
||||
has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1
|
||||
does_not_have_float = len([i for i in all_invocations if i.get_type() == "float"]) == 0
|
||||
|
||||
assert has_integer
|
||||
assert has_string
|
||||
assert not has_float
|
||||
assert does_not_have_float
|
||||
|
Loading…
Reference in New Issue
Block a user