mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
reinstated failing deny_nodes validation test for Graph
This commit is contained in:
parent
8144a263de
commit
d24877561d
@ -1,9 +1,11 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
from invokeai.app.services.config.config_default import (
|
from invokeai.app.services.config.config_default import (
|
||||||
@ -12,6 +14,7 @@ from invokeai.app.services.config.config_default import (
|
|||||||
get_config,
|
get_config,
|
||||||
load_and_migrate_config,
|
load_and_migrate_config,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.shared.graph import Graph
|
||||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||||
|
|
||||||
invalid_v4_0_1_config = """
|
invalid_v4_0_1_config = """
|
||||||
@ -285,21 +288,34 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
|
|||||||
InvokeAIArgs.did_parse = False
|
InvokeAIArgs.did_parse = False
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def clear_config():
|
||||||
|
try:
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
get_config.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
def test_deny_nodes():
|
def test_deny_nodes():
|
||||||
config = get_config()
|
with clear_config():
|
||||||
config.allow_nodes = ["integer", "string", "float"]
|
config = get_config()
|
||||||
config.deny_nodes = ["float"]
|
config.allow_nodes = ["integer", "string", "float"]
|
||||||
|
config.deny_nodes = ["float"]
|
||||||
|
|
||||||
# confirm invocations union will not have denied nodes
|
# confirm graph validation fails when using denied node
|
||||||
all_invocations = BaseInvocation.get_invocations()
|
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||||
|
|
||||||
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
with pytest.raises(ValidationError):
|
||||||
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||||
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
|
||||||
|
|
||||||
assert has_integer
|
# confirm invocations union will not have denied nodes
|
||||||
assert has_string
|
all_invocations = BaseInvocation.get_invocations()
|
||||||
assert not has_float
|
|
||||||
|
|
||||||
# may not be necessary
|
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
||||||
get_config.cache_clear()
|
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
||||||
|
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
||||||
|
|
||||||
|
assert has_integer
|
||||||
|
assert has_string
|
||||||
|
assert not has_float
|
||||||
|
Loading…
Reference in New Issue
Block a user