mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
reinstated failing deny_nodes validation test for Graph
This commit is contained in:
parent
8144a263de
commit
d24877561d
@ -1,9 +1,11 @@
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.config.config_default import (
|
||||
@ -12,6 +14,7 @@ from invokeai.app.services.config.config_default import (
|
||||
get_config,
|
||||
load_and_migrate_config,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
invalid_v4_0_1_config = """
|
||||
@ -285,21 +288,34 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
|
||||
InvokeAIArgs.did_parse = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def clear_config():
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
get_config.cache_clear()
|
||||
|
||||
|
||||
def test_deny_nodes():
|
||||
config = get_config()
|
||||
config.allow_nodes = ["integer", "string", "float"]
|
||||
config.deny_nodes = ["float"]
|
||||
with clear_config():
|
||||
config = get_config()
|
||||
config.allow_nodes = ["integer", "string", "float"]
|
||||
config.deny_nodes = ["float"]
|
||||
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
# confirm graph validation fails when using denied node
|
||||
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||
|
||||
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
||||
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
||||
with pytest.raises(ValidationError):
|
||||
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||
|
||||
assert has_integer
|
||||
assert has_string
|
||||
assert not has_float
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
|
||||
# may not be necessary
|
||||
get_config.cache_clear()
|
||||
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
||||
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
||||
|
||||
assert has_integer
|
||||
assert has_string
|
||||
assert not has_float
|
||||
|
Loading…
Reference in New Issue
Block a user