feat(tests): add tests for decorator and int -> float

This commit is contained in:
psychedelicious 2023-09-04 19:07:41 +10:00
parent 920fc0e751
commit 59cb6305b9

View File

@ -1,3 +1,4 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .test_nodes import ( from .test_nodes import (
ImageToImageTestInvocation, ImageToImageTestInvocation,
TextToImageTestInvocation, TextToImageTestInvocation,
@ -20,7 +21,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.invocations.image import ShowImageInvocation from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import IntegerInvocation from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
from invokeai.app.services.default_graphs import create_text_to_image from invokeai.app.services.default_graphs import create_text_to_image
import pytest import pytest
@ -610,6 +611,59 @@ def test_graph_can_deserialize():
assert g2.edges[0].destination.field == "image" assert g2.edges[0].destination.field == "image"
def test_invocation_decorator():
invocation_type = "test_invocation"
title = "Test Invocation"
tags = ["first", "second", "third"]
category = "category"
@invocation(invocation_type, title=title, tags=tags, category=category)
class Test(BaseInvocation):
def invoke(self):
pass
schema = Test.schema()
assert schema.get("title") == title
assert schema.get("tags") == tags
assert schema.get("category") == category
assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added)
def test_invocation_output_decorator():
output_type = "test_output"
@invocation_output(output_type)
class TestOutput(BaseInvocationOutput):
pass
assert TestOutput().type == output_type # type: ignore (type is dynamically added)
def test_floats_accept_ints():
g = Graph()
n1 = IntegerInvocation(id="1", value=1)
n2 = FloatInvocation(id="2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id, "value", n2.id, "value")
# Not throwing on this line is sufficient
g.add_edge(e)
def test_ints_do_not_accept_floats():
g = Graph()
n1 = FloatInvocation(id="1", value=1.0)
n2 = IntegerInvocation(id="2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id, "value", n2.id, "value")
with pytest.raises(InvalidEdgeError):
g.add_edge(e)
def test_graph_can_generate_schema(): def test_graph_can_generate_schema():
# Not throwing on this line is sufficient # Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation