diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index fe6709827f..56bf823d14 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,3 +1,4 @@ +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from .test_nodes import ( ImageToImageTestInvocation, TextToImageTestInvocation, @@ -20,7 +21,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation from invokeai.app.invocations.image import ShowImageInvocation from invokeai.app.invocations.math import AddInvocation, SubtractInvocation -from invokeai.app.invocations.primitives import IntegerInvocation +from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation from invokeai.app.services.default_graphs import create_text_to_image import pytest @@ -610,6 +611,59 @@ def test_graph_can_deserialize(): assert g2.edges[0].destination.field == "image" +def test_invocation_decorator(): + invocation_type = "test_invocation" + title = "Test Invocation" + tags = ["first", "second", "third"] + category = "category" + + @invocation(invocation_type, title=title, tags=tags, category=category) + class Test(BaseInvocation): + def invoke(self): + pass + + schema = Test.schema() + + assert schema.get("title") == title + assert schema.get("tags") == tags + assert schema.get("category") == category + assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added) + + +def test_invocation_output_decorator(): + output_type = "test_output" + + @invocation_output(output_type) + class TestOutput(BaseInvocationOutput): + pass + + assert TestOutput().type == output_type # type: ignore (type is dynamically added) + + +def test_floats_accept_ints(): + g = Graph() + n1 = IntegerInvocation(id="1", value=1) + n2 = FloatInvocation(id="2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id, "value", n2.id, "value") + + # Not throwing on this line is sufficient + g.add_edge(e) + + +def test_ints_do_not_accept_floats(): + g = Graph() + n1 = FloatInvocation(id="1", value=1.0) + n2 = IntegerInvocation(id="2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id, "value", n2.id, "value") + + with pytest.raises(InvalidEdgeError): + g.add_edge(e) + + def test_graph_can_generate_schema(): # Not throwing on this line is sufficient # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation