updated and reinstated the test_deny_nodes() unit test

This commit is contained in:
Lincoln Stein 2024-04-24 22:06:16 -04:00
parent ab086a7069
commit 8144a263de

View File

@ -3,10 +3,9 @@ from tempfile import TemporaryDirectory
from typing import Any
import pytest
from omegaconf import OmegaConf
from packaging.version import Version
from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.config.config_default import (
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
@ -286,50 +285,10 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
InvokeAIArgs.did_parse = False
@pytest.mark.xfail(
reason="""
This test fails when run as part of the full test suite.
This test needs to deny nodes from being included in the InvocationsUnion by providing
an app configuration as a test fixture. Pytest executes all test files before running
tests, so the app configuration is already initialized by the time this test runs, and
the InvocationUnion is already created and the denied nodes are not omitted from it.
This test passes when `test_config.py` is tested in isolation.
Perhaps a solution would be to call `get_app_config().parse_args()` in
other test files?
"""
)
def test_deny_nodes(patch_rootdir):
# Allow integer, string and float, but explicitly deny float
allow_deny_nodes_conf = OmegaConf.create(
"""
InvokeAI:
Nodes:
allow_nodes:
- integer
- string
- float
deny_nodes:
- float
"""
)
# must parse config before importing Graph, so its nodes union uses the config
get_config.cache_clear()
conf = get_config()
get_config.cache_clear()
conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[])
from invokeai.app.services.shared.graph import Graph
# confirm graph validation fails when using denied node
Graph(nodes={"1": {"id": "1", "type": "integer"}})
Graph(nodes={"1": {"id": "1", "type": "string"}})
with pytest.raises(ValidationError):
Graph(nodes={"1": {"id": "1", "type": "float"}})
from invokeai.app.invocations.baseinvocation import BaseInvocation
def test_deny_nodes():
config = get_config()
config.allow_nodes = ["integer", "string", "float"]
config.deny_nodes = ["float"]
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
@ -341,3 +300,6 @@ def test_deny_nodes(patch_rootdir):
assert has_integer
assert has_string
assert not has_float
# may not be necessary
get_config.cache_clear()