mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(tests): add tests for decorator and int -> float
This commit is contained in:
parent
920fc0e751
commit
59cb6305b9
@ -1,3 +1,4 @@
|
|||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||||
from .test_nodes import (
|
from .test_nodes import (
|
||||||
ImageToImageTestInvocation,
|
ImageToImageTestInvocation,
|
||||||
TextToImageTestInvocation,
|
TextToImageTestInvocation,
|
||||||
@ -20,7 +21,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation
|
|||||||
|
|
||||||
from invokeai.app.invocations.image import ShowImageInvocation
|
from invokeai.app.invocations.image import ShowImageInvocation
|
||||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||||
from invokeai.app.invocations.primitives import IntegerInvocation
|
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
|
||||||
from invokeai.app.services.default_graphs import create_text_to_image
|
from invokeai.app.services.default_graphs import create_text_to_image
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -610,6 +611,59 @@ def test_graph_can_deserialize():
|
|||||||
assert g2.edges[0].destination.field == "image"
|
assert g2.edges[0].destination.field == "image"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invocation_decorator():
|
||||||
|
invocation_type = "test_invocation"
|
||||||
|
title = "Test Invocation"
|
||||||
|
tags = ["first", "second", "third"]
|
||||||
|
category = "category"
|
||||||
|
|
||||||
|
@invocation(invocation_type, title=title, tags=tags, category=category)
|
||||||
|
class Test(BaseInvocation):
|
||||||
|
def invoke(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
schema = Test.schema()
|
||||||
|
|
||||||
|
assert schema.get("title") == title
|
||||||
|
assert schema.get("tags") == tags
|
||||||
|
assert schema.get("category") == category
|
||||||
|
assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invocation_output_decorator():
|
||||||
|
output_type = "test_output"
|
||||||
|
|
||||||
|
@invocation_output(output_type)
|
||||||
|
class TestOutput(BaseInvocationOutput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert TestOutput().type == output_type # type: ignore (type is dynamically added)
|
||||||
|
|
||||||
|
|
||||||
|
def test_floats_accept_ints():
|
||||||
|
g = Graph()
|
||||||
|
n1 = IntegerInvocation(id="1", value=1)
|
||||||
|
n2 = FloatInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e = create_edge(n1.id, "value", n2.id, "value")
|
||||||
|
|
||||||
|
# Not throwing on this line is sufficient
|
||||||
|
g.add_edge(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ints_do_not_accept_floats():
|
||||||
|
g = Graph()
|
||||||
|
n1 = FloatInvocation(id="1", value=1.0)
|
||||||
|
n2 = IntegerInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e = create_edge(n1.id, "value", n2.id, "value")
|
||||||
|
|
||||||
|
with pytest.raises(InvalidEdgeError):
|
||||||
|
g.add_edge(e)
|
||||||
|
|
||||||
|
|
||||||
def test_graph_can_generate_schema():
|
def test_graph_can_generate_schema():
|
||||||
# Not throwing on this line is sufficient
|
# Not throwing on this line is sufficient
|
||||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||||
|
Loading…
x
Reference in New Issue
Block a user