mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: metadata refactor
- Refactor how metadata is handled to support a user-defined metadata in graphs - Update workflow embed handling - Update UI to work with these changes - Update tests to support metadata/workflow changes
This commit is contained in:
@ -10,7 +10,12 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.image import ShowImageInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
|
||||
from invokeai.app.invocations.primitives import (
|
||||
FloatCollectionInvocation,
|
||||
FloatInvocation,
|
||||
IntegerInvocation,
|
||||
StringInvocation,
|
||||
)
|
||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||
from invokeai.app.services.shared.default_graphs import create_text_to_image
|
||||
from invokeai.app.services.shared.graph import (
|
||||
@ -27,8 +32,11 @@ from invokeai.app.services.shared.graph import (
|
||||
)
|
||||
|
||||
from .test_nodes import (
|
||||
AnyTypeTestInvocation,
|
||||
ImageToImageTestInvocation,
|
||||
ListPassThroughInvocation,
|
||||
PolymorphicStringTestInvocation,
|
||||
PromptCollectionTestInvocation,
|
||||
PromptTestInvocation,
|
||||
TextToImageTestInvocation,
|
||||
)
|
||||
@ -692,6 +700,144 @@ def test_ints_do_not_accept_floats():
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_polymorphic_accepts_single():
|
||||
g = Graph()
|
||||
n1 = StringInvocation(id="1", value="banana")
|
||||
n2 = PolymorphicStringTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e1 = create_edge(n1.id, "value", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e1)
|
||||
|
||||
|
||||
def test_polymorphic_accepts_collection_of_same_base_type():
|
||||
g = Graph()
|
||||
n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
|
||||
n2 = PolymorphicStringTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e1 = create_edge(n1.id, "collection", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e1)
|
||||
|
||||
|
||||
def test_polymorphic_does_not_accept_collection_of_different_base_type():
|
||||
g = Graph()
|
||||
n1 = FloatCollectionInvocation(id="1", collection=[1.0, 2.0, 3.0])
|
||||
n2 = PolymorphicStringTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e1 = create_edge(n1.id, "collection", n2.id, "value")
|
||||
with pytest.raises(InvalidEdgeError):
|
||||
g.add_edge(e1)
|
||||
|
||||
|
||||
def test_polymorphic_does_not_accept_generic_collection():
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = IntegerInvocation(id="2", value=2)
|
||||
n3 = CollectInvocation(id="3")
|
||||
n4 = PolymorphicStringTestInvocation(id="4")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
g.add_node(n4)
|
||||
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||
e3 = create_edge(n3.id, "collection", n4.id, "value")
|
||||
g.add_edge(e1)
|
||||
g.add_edge(e2)
|
||||
with pytest.raises(InvalidEdgeError):
|
||||
g.add_edge(e3)
|
||||
|
||||
|
||||
def test_any_accepts_integer():
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = AnyTypeTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "value", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_any_accepts_string():
|
||||
g = Graph()
|
||||
n1 = StringInvocation(id="1", value="banana sundae")
|
||||
n2 = AnyTypeTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "value", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_any_accepts_generic_collection():
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = IntegerInvocation(id="2", value=2)
|
||||
n3 = CollectInvocation(id="3")
|
||||
n4 = AnyTypeTestInvocation(id="4")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
g.add_node(n4)
|
||||
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||
e3 = create_edge(n3.id, "collection", n4.id, "value")
|
||||
g.add_edge(e1)
|
||||
g.add_edge(e2)
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e3)
|
||||
|
||||
|
||||
def test_any_accepts_prompt_collection():
|
||||
g = Graph()
|
||||
n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
|
||||
n2 = AnyTypeTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "collection", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_any_accepts_any():
|
||||
g = Graph()
|
||||
n1 = AnyTypeTestInvocation(id="1")
|
||||
n2 = AnyTypeTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "value", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_iterate_accepts_collection():
|
||||
"""We need to update the validation for Collect -> Iterate to traverse to the Iterate
|
||||
node's output and compare that against the item type of the Collect node's collection. Until
|
||||
then, Collect nodes may not output into Iterate nodes."""
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = IntegerInvocation(id="2", value=2)
|
||||
n3 = CollectInvocation(id="3")
|
||||
n4 = IterateInvocation(id="4")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
g.add_node(n4)
|
||||
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||
e3 = create_edge(n3.id, "collection", n4.id, "collection")
|
||||
g.add_edge(e1)
|
||||
g.add_edge(e2)
|
||||
# Once we fix the validation logic as described, this should should not raise an error
|
||||
with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"):
|
||||
g.add_edge(e3)
|
||||
|
||||
|
||||
def test_graph_can_generate_schema():
|
||||
# Not throwing on this line is sufficient
|
||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||
|
@ -81,6 +81,29 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||
|
||||
|
||||
@invocation_output("test_any_output")
|
||||
class AnyTypeTestInvocationOutput(BaseInvocationOutput):
|
||||
value: Any = Field()
|
||||
|
||||
|
||||
@invocation("test_any")
|
||||
class AnyTypeTestInvocation(BaseInvocation):
|
||||
value: Any = Field(default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
|
||||
return AnyTypeTestInvocationOutput(value=self.value)
|
||||
|
||||
|
||||
@invocation("test_polymorphic")
|
||||
class PolymorphicStringTestInvocation(BaseInvocation):
|
||||
value: Union[str, list[str]] = Field(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
if isinstance(self.value, str):
|
||||
return PromptCollectionTestInvocationOutput(collection=[self.value])
|
||||
return PromptCollectionTestInvocationOutput(collection=self.value)
|
||||
|
||||
|
||||
# Importing these must happen after test invocations are defined or they won't register
|
||||
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
|
||||
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
|
||||
|
Reference in New Issue
Block a user