feat(tests): add tests for decorator and int -> float

This commit is contained in:
psychedelicious 2023-09-04 19:07:41 +10:00
parent 920fc0e751
commit 59cb6305b9

View File

@ -1,3 +1,4 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .test_nodes import (
ImageToImageTestInvocation,
TextToImageTestInvocation,
@ -20,7 +21,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import IntegerInvocation
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
from invokeai.app.services.default_graphs import create_text_to_image
import pytest
@ -610,6 +611,59 @@ def test_graph_can_deserialize():
assert g2.edges[0].destination.field == "image"
def test_invocation_decorator():
invocation_type = "test_invocation"
title = "Test Invocation"
tags = ["first", "second", "third"]
category = "category"
@invocation(invocation_type, title=title, tags=tags, category=category)
class Test(BaseInvocation):
def invoke(self):
pass
schema = Test.schema()
assert schema.get("title") == title
assert schema.get("tags") == tags
assert schema.get("category") == category
assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added)
def test_invocation_output_decorator():
output_type = "test_output"
@invocation_output(output_type)
class TestOutput(BaseInvocationOutput):
pass
assert TestOutput().type == output_type # type: ignore (type is dynamically added)
def test_floats_accept_ints():
g = Graph()
n1 = IntegerInvocation(id="1", value=1)
n2 = FloatInvocation(id="2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id, "value", n2.id, "value")
# Not throwing on this line is sufficient
g.add_edge(e)
def test_ints_do_not_accept_floats():
g = Graph()
n1 = FloatInvocation(id="1", value=1.0)
n2 = IntegerInvocation(id="2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id, "value", n2.id, "value")
with pytest.raises(InvalidEdgeError):
g.add_edge(e)
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation