reinstated failing deny_nodes validation test for Graph

This commit is contained in:
Lincoln Stein 2024-04-25 00:22:09 -04:00
parent 8144a263de
commit d24877561d

View File

@ -1,9 +1,11 @@
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
import pytest
from packaging.version import Version
from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.config.config_default import (
@ -12,6 +14,7 @@ from invokeai.app.services.config.config_default import (
get_config,
load_and_migrate_config,
)
from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
invalid_v4_0_1_config = """
@ -285,21 +288,34 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
InvokeAIArgs.did_parse = False
@contextmanager
def clear_config():
try:
yield None
finally:
get_config.cache_clear()
def test_deny_nodes():
config = get_config()
config.allow_nodes = ["integer", "string", "float"]
config.deny_nodes = ["float"]
with clear_config():
config = get_config()
config.allow_nodes = ["integer", "string", "float"]
config.deny_nodes = ["float"]
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
# confirm graph validation fails when using denied node
Graph(nodes={"1": {"id": "1", "type": "integer"}})
Graph(nodes={"1": {"id": "1", "type": "string"}})
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
with pytest.raises(ValidationError):
Graph(nodes={"1": {"id": "1", "type": "float"}})
assert has_integer
assert has_string
assert not has_float
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
# may not be necessary
get_config.cache_clear()
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
assert has_integer
assert has_string
assert not has_float