reinstated failing deny_nodes validation test for Graph

This commit is contained in:
Lincoln Stein 2024-04-25 00:22:09 -04:00
parent 8144a263de
commit d24877561d

View File

@ -1,9 +1,11 @@
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any from typing import Any
import pytest import pytest
from packaging.version import Version from packaging.version import Version
from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.config.config_default import ( from invokeai.app.services.config.config_default import (
@ -12,6 +14,7 @@ from invokeai.app.services.config.config_default import (
get_config, get_config,
load_and_migrate_config, load_and_migrate_config,
) )
from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs from invokeai.frontend.cli.arg_parser import InvokeAIArgs
invalid_v4_0_1_config = """ invalid_v4_0_1_config = """
@ -285,21 +288,34 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
InvokeAIArgs.did_parse = False InvokeAIArgs.did_parse = False
@contextmanager
def clear_config():
try:
yield None
finally:
get_config.cache_clear()
def test_deny_nodes(): def test_deny_nodes():
config = get_config() with clear_config():
config.allow_nodes = ["integer", "string", "float"] config = get_config()
config.deny_nodes = ["float"] config.allow_nodes = ["integer", "string", "float"]
config.deny_nodes = ["float"]
# confirm invocations union will not have denied nodes # confirm graph validation fails when using denied node
all_invocations = BaseInvocation.get_invocations() Graph(nodes={"1": {"id": "1", "type": "integer"}})
Graph(nodes={"1": {"id": "1", "type": "string"}})
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1 with pytest.raises(ValidationError):
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1 Graph(nodes={"1": {"id": "1", "type": "float"}})
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
assert has_integer # confirm invocations union will not have denied nodes
assert has_string all_invocations = BaseInvocation.get_invocations()
assert not has_float
# may not be necessary has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
get_config.cache_clear() has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
assert has_integer
assert has_string
assert not has_float