mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: parse config before importing anything else
We need to parse the config before doing anything related to invocations to ensure that the invocations union picks up on denied nodes. - Move that to the top of api_app and cli_app - Wrap subsequent imports in `if True:`, as a hack to satisfy flake8 and not have to noqa every line or the whole file - Add tests to ensure graph validation fails when using a denied node, and that the invocations union does not have denied nodes (this indirectly provides confidence that the generated OpenAPI schema will not include denied nodes)
This commit is contained in:
committed by
Kent Keirsey
parent
1d2636aa90
commit
4395ee3c03
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from typing import Any
|
||||
from pydantic import ValidationError
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
@ -147,3 +148,43 @@ def test_type_coercion(patch_rootdir):
|
||||
conf.parse_args(argv=["--root=/tmp/foobar"])
|
||||
assert conf.root == Path("/tmp/different")
|
||||
assert isinstance(conf.root, Path)
|
||||
|
||||
|
||||
def test_deny_nodes(patch_rootdir):
|
||||
# Allow integer, string and float, but explicitly deny float
|
||||
allow_deny_nodes_conf = OmegaConf.create(
|
||||
"""
|
||||
InvokeAI:
|
||||
Nodes:
|
||||
allow_nodes:
|
||||
- integer
|
||||
- string
|
||||
- float
|
||||
deny_nodes:
|
||||
- float
|
||||
"""
|
||||
)
|
||||
# must parse config before importing Graph, so its nodes union uses the config
|
||||
conf = InvokeAIAppConfig().get_config()
|
||||
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
|
||||
from invokeai.app.services.graph import Graph
|
||||
|
||||
# confirm graph validation fails when using denied node
|
||||
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
|
||||
has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1
|
||||
has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1
|
||||
|
||||
assert has_integer
|
||||
assert has_string
|
||||
assert not has_float
|
||||
|
Reference in New Issue
Block a user