feat: parse config before importing anything else

We need to parse the config before doing anything related to invocations to ensure that the invocations union picks up on denied nodes.

- Move that to the top of api_app and cli_app
- Wrap subsequent imports in `if True:`, as a hack to satisfy flake8 and not have to noqa every line or the whole file
- Add tests to ensure graph validation fails when using a denied node, and that the invocations union does not have denied nodes (this indirectly provides confidence that the generated OpenAPI schema will not include denied nodes)
This commit is contained in:
psychedelicious
2023-09-08 10:41:00 +10:00
committed by Kent Keirsey
parent 1d2636aa90
commit 4395ee3c03
3 changed files with 135 additions and 95 deletions

View File

@ -1,5 +1,6 @@
import os
from typing import Any
from pydantic import ValidationError
import pytest
from omegaconf import OmegaConf
@ -147,3 +148,43 @@ def test_type_coercion(patch_rootdir):
conf.parse_args(argv=["--root=/tmp/foobar"])
assert conf.root == Path("/tmp/different")
assert isinstance(conf.root, Path)
def test_deny_nodes(patch_rootdir):
# Allow integer, string and float, but explicitly deny float
allow_deny_nodes_conf = OmegaConf.create(
"""
InvokeAI:
Nodes:
allow_nodes:
- integer
- string
- float
deny_nodes:
- float
"""
)
# must parse config before importing Graph, so its nodes union uses the config
conf = InvokeAIAppConfig().get_config()
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
from invokeai.app.services.graph import Graph
# confirm graph validation fails when using denied node
Graph(nodes={"1": {"id": "1", "type": "integer"}})
Graph(nodes={"1": {"id": "1", "type": "string"}})
with pytest.raises(ValidationError):
Graph(nodes={"1": {"id": "1", "type": "float"}})
from invokeai.app.invocations.baseinvocation import BaseInvocation
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1
has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1
has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1
assert has_integer
assert has_string
assert not has_float