mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): respect nodes denylist
In #5838 graph validation was updated to resolve the issue with the nodes union and import order. That broke the nodes denylist functionality. However, because the corresponding test was marked as `xfail`, we didn't catch the issue. - Fix the nodes denylist handling - Update the tests
This commit is contained in:
parent
5118160282
commit
f07a46f195
@ -183,7 +183,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||||
if not cls._typeadapter:
|
if not cls._typeadapter:
|
||||||
InvocationsUnion = TypeAliasType(
|
InvocationsUnion = TypeAliasType(
|
||||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
"InvocationsUnion", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
|
||||||
)
|
)
|
||||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||||
return cls._typeadapter
|
return cls._typeadapter
|
||||||
|
@ -6,6 +6,10 @@ import pytest
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation, StringInvocation
|
||||||
|
from invokeai.app.services.shared.graph import Graph
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
||||||
@ -158,58 +162,44 @@ def test_type_coercion(patch_rootdir):
|
|||||||
assert isinstance(conf.root, Path)
|
assert isinstance(conf.root, Path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(
|
|
||||||
reason="""
|
|
||||||
This test fails when run as part of the full test suite.
|
|
||||||
|
|
||||||
This test needs to deny nodes from being included in the InvocationsUnion by providing
|
|
||||||
an app configuration as a test fixture. Pytest executes all test files before running
|
|
||||||
tests, so the app configuration is already initialized by the time this test runs, and
|
|
||||||
the InvocationUnion is already created and the denied nodes are not omitted from it.
|
|
||||||
|
|
||||||
This test passes when `test_config.py` is tested in isolation.
|
|
||||||
|
|
||||||
Perhaps a solution would be to call `InvokeAIAppConfig.get_config().parse_args()` in
|
|
||||||
other test files?
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
def test_deny_nodes(patch_rootdir):
|
def test_deny_nodes(patch_rootdir):
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
# Allow integer, string and float, but explicitly deny float
|
# Allow integer, string and float, but explicitly deny float
|
||||||
allow_deny_nodes_conf = OmegaConf.create(
|
allow_deny_nodes_conf = OmegaConf.create(
|
||||||
"""
|
f"""
|
||||||
InvokeAI:
|
InvokeAI:
|
||||||
Nodes:
|
Nodes:
|
||||||
allow_nodes:
|
allow_nodes:
|
||||||
- integer
|
- {IntegerInvocation.get_type()}
|
||||||
- string
|
- {StringInvocation.get_type()}
|
||||||
- float
|
- {FloatInvocation.get_type()}
|
||||||
deny_nodes:
|
deny_nodes:
|
||||||
- float
|
- {FloatInvocation.get_type()}
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
# must parse config before importing Graph, so its nodes union uses the config
|
# must parse config before importing Graph, so its nodes union uses the config
|
||||||
conf = InvokeAIAppConfig().get_config()
|
conf = InvokeAIAppConfig().get_config()
|
||||||
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
|
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
|
||||||
from invokeai.app.services.shared.graph import Graph
|
|
||||||
|
|
||||||
# confirm graph validation fails when using denied node
|
# confirm graph validation fails when using denied node
|
||||||
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
Graph(nodes={"1": IntegerInvocation(value=1)})
|
||||||
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
Graph(nodes={"1": StringInvocation(value="asdf")})
|
||||||
|
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
Graph(nodes={"1": FloatInvocation(value=1.0)})
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
# Also test with a dict input
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||||
|
|
||||||
# confirm invocations union will not have denied nodes
|
# confirm invocations union will not have denied nodes
|
||||||
all_invocations = BaseInvocation.get_invocations()
|
all_invocations = BaseInvocation.get_invocations()
|
||||||
|
|
||||||
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1
|
||||||
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1
|
||||||
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
does_not_have_float = len([i for i in all_invocations if i.get_type() == "float"]) == 0
|
||||||
|
|
||||||
assert has_integer
|
assert has_integer
|
||||||
assert has_string
|
assert has_string
|
||||||
assert not has_float
|
assert does_not_have_float
|
||||||
|
Loading…
Reference in New Issue
Block a user