mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
34e3aa1f88
author Kyle Schouviller <kyle0654@hotmail.com> 1669872800 -0800 committer Kyle Schouviller <kyle0654@hotmail.com> 1676240900 -0800 Adding base node architecture Fix type annotation errors Runs and generates, but breaks in saving session Fix default model value setting. Fix deprecation warning. Fixed node api Adding markdown docs Simplifying Generate construction in apps [nodes] A few minor changes (#2510) * Pin api-related requirements * Remove confusing extra CORS origins list * Adds response models for HTTP 200 [nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest. Minor typing fixes [nodes] Fix some small output query hookups [node] Fixing some additional typing issues [nodes] Move and expand graph code. Add base item storage and sqlite implementation. Update startup to match new code [nodes] Add callbacks to item storage [nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility [nodes] New execution model that handles iteration [nodes] Fixing the CLI [nodes] Adding a note to the CLI [nodes] Split processing thread into separate service [node] Add error message on node processing failure Removing old files and duplicated packages Adding python-multipart
502 lines
15 KiB
Python
502 lines
15 KiB
Python
from ldm.invoke.app.invocations.image import *
|
|
|
|
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
|
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
|
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
|
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
|
|
import pytest
|
|
|
|
|
|
# Helpers
|
|
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
    """Build a ``(source, destination)`` pair of ``EdgeConnection`` objects.

    The first element names the node/field the value comes from and the second
    names the node/field it is delivered to; tests pass the pair to
    ``Graph.add_edge`` or append it to ``Graph.edges`` directly.
    """
    source = EdgeConnection(node_id=from_id, field=from_field)
    destination = EdgeConnection(node_id=to_id, field=to_field)
    return source, destination
|
|
|
|
# Tests
|
|
def test_connections_are_compatible():
    """A text-to-image ``image`` output may be wired to an upscaler's ``image`` input."""
    from_node = TextToImageInvocation(id="1", prompt="Banana sushi")
    from_field = "image"
    to_node = UpscaleInvocation(id="2")
    to_field = "image"

    result = are_connections_compatible(from_node, from_field, to_node, to_field)

    # Truthiness check rather than ``== True`` (PEP 8 / ruff E712).
    assert result
|
|
|
|
def test_connections_are_incompatible():
    """An ``image`` output may not be wired to the upscaler's ``strength`` input."""
    from_node = TextToImageInvocation(id="1", prompt="Banana sushi")
    from_field = "image"
    to_node = UpscaleInvocation(id="2")
    to_field = "strength"

    result = are_connections_compatible(from_node, from_field, to_node, to_field)

    # Truthiness check rather than ``== False`` (PEP 8 / ruff E712).
    assert not result
|
|
|
|
def test_connections_incompatible_with_invalid_fields():
    """A connection naming a field that does not exist on either end is incompatible."""
    from_node = TextToImageInvocation(id="1", prompt="Banana sushi")
    from_field = "invalid_field"
    to_node = UpscaleInvocation(id="2")
    to_field = "image"

    # From field is invalid
    result = are_connections_compatible(from_node, from_field, to_node, to_field)
    assert not result

    # To field is invalid
    from_field = "image"
    to_field = "invalid_field"

    result = are_connections_compatible(from_node, from_field, to_node, to_field)
    assert not result
|
|
|
|
def test_graph_can_add_node():
    """``add_node`` registers the node in ``Graph.nodes`` under its id."""
    graph = Graph()
    node = TextToImageInvocation(id="1", prompt="Banana sushi")

    graph.add_node(node)

    assert node.id in graph.nodes
|
|
|
|
def test_graph_fails_to_add_node_with_duplicate_id():
    """Adding a second node that reuses an existing id raises ``NodeAlreadyInGraphError``."""
    graph = Graph()
    graph.add_node(TextToImageInvocation(id="1", prompt="Banana sushi"))
    duplicate = TextToImageInvocation(id="1", prompt="Banana sushi the second")

    with pytest.raises(NodeAlreadyInGraphError):
        graph.add_node(duplicate)
|
|
|
|
def test_graph_updates_node():
    """``update_node`` replaces the node stored under an existing id."""
    graph = Graph()
    for node_id, prompt in (("1", "Banana sushi"), ("2", "Banana sushi the second")):
        graph.add_node(TextToImageInvocation(id=node_id, prompt=prompt))

    replacement = TextToImageInvocation(id="1", prompt="Banana sushi updated")
    graph.update_node("1", replacement)

    assert graph.nodes["1"].prompt == "Banana sushi updated"
|
|
|
|
def test_graph_fails_to_update_node_if_type_changes():
    """Replacing a node with an invocation of a different class raises ``TypeError``."""
    graph = Graph()
    graph.add_node(TextToImageInvocation(id="1", prompt="Banana sushi"))
    graph.add_node(UpscaleInvocation(id="2"))

    # Node "1" is a text-to-image node; the replacement is an upscaler.
    different_type = UpscaleInvocation(id="1")
    with pytest.raises(TypeError):
        graph.update_node("1", different_type)
|
|
|
|
def test_graph_allows_non_conflicting_id_change():
    """Updating a node with an unused id moves it there and re-points its edges."""
    graph = Graph()
    source = TextToImageInvocation(id="1", prompt="Banana sushi")
    target = UpscaleInvocation(id="2")
    graph.add_node(source)
    graph.add_node(target)
    graph.add_edge(create_edge(source.id, "image", target.id, "image"))

    renamed = TextToImageInvocation(id="3", prompt="Banana sushi")
    graph.update_node("1", renamed)

    # The old id no longer resolves...
    with pytest.raises(NodeNotFoundError):
        graph.get_node("1")

    # ...the new id resolves to the replacement...
    assert graph.get_node("3").prompt == "Banana sushi"

    # ...and the single edge now starts at the renamed node.
    assert len(graph.edges) == 1
    assert create_edge("3", "image", "2", "image") in graph.edges
|
|
|
|
def test_graph_fails_to_update_node_id_if_conflict():
    """Updating a node onto an id already used by another node raises ``NodeAlreadyInGraphError``."""
    graph = Graph()
    for node_id, prompt in (("1", "Banana sushi"), ("2", "Banana sushi the second")):
        graph.add_node(TextToImageInvocation(id=node_id, prompt=prompt))

    clashing = TextToImageInvocation(id="2", prompt="Banana sushi")
    with pytest.raises(NodeAlreadyInGraphError):
        graph.update_node("1", clashing)
|
|
|
|
def test_graph_adds_edge():
    """A compatible edge passed to ``add_edge`` is recorded in ``Graph.edges``."""
    graph = Graph()
    source = TextToImageInvocation(id="1", prompt="Banana sushi")
    target = UpscaleInvocation(id="2")
    for node in (source, target):
        graph.add_node(node)

    edge = create_edge(source.id, "image", target.id, "image")
    graph.add_edge(edge)

    assert edge in graph.edges
|
|
|
|
def test_graph_fails_to_add_edge_with_cycle():
    """An edge from a node's output back into its own input is rejected."""
    graph = Graph()
    upscaler = UpscaleInvocation(id="1")
    graph.add_node(upscaler)

    self_loop = create_edge(upscaler.id, "image", upscaler.id, "image")
    with pytest.raises(InvalidEdgeError):
        graph.add_edge(self_loop)
|
|
|
|
def test_graph_fails_to_add_edge_with_long_cycle():
    """An edge that closes a cycle through several nodes is rejected.

    All three nodes are upscalers, so each has an ``image`` input. The closing
    edge targets node "1", whose ``image`` input has no incoming edge yet.
    That keeps the rejection from being explained by the one-edge-per-input
    rule exercised in ``test_graph_fails_to_add_edge_when_destination_exists``.
    """
    g = Graph()
    n1 = UpscaleInvocation(id="1")
    n2 = UpscaleInvocation(id="2")
    n3 = UpscaleInvocation(id="3")
    g.add_node(n1)
    g.add_node(n2)
    g.add_node(n3)

    e1 = create_edge(n1.id, "image", n2.id, "image")
    e2 = create_edge(n2.id, "image", n3.id, "image")
    # Closes the cycle 1 -> 2 -> 3 -> 1.
    e3 = create_edge(n3.id, "image", n1.id, "image")
    g.add_edge(e1)
    g.add_edge(e2)

    with pytest.raises(InvalidEdgeError):
        g.add_edge(e3)
|
|
|
|
def test_graph_fails_to_add_edge_with_missing_node_id():
    """Edges whose source or destination id is absent from the graph are rejected."""
    graph = Graph()
    graph.add_node(TextToImageInvocation(id="1", prompt="Banana sushi"))
    graph.add_node(UpscaleInvocation(id="2"))

    dangling_edges = (
        create_edge("1", "image", "3", "image"),  # destination "3" missing
        create_edge("3", "image", "1", "image"),  # source "3" missing
    )
    for edge in dangling_edges:
        with pytest.raises(InvalidEdgeError):
            graph.add_edge(edge)
|
|
|
|
def test_graph_fails_to_add_edge_when_destination_exists():
    """An input that already has an incoming edge cannot accept a second one."""
    graph = Graph()
    nodes = (
        TextToImageInvocation(id="1", prompt="Banana sushi"),
        UpscaleInvocation(id="2"),
        UpscaleInvocation(id="3"),
    )
    for node in nodes:
        graph.add_node(node)

    fan_out_a = create_edge("1", "image", "2", "image")
    fan_out_b = create_edge("1", "image", "3", "image")
    second_into_3 = create_edge("2", "image", "3", "image")

    # One output feeding two different inputs is accepted.
    graph.add_edge(fan_out_a)
    graph.add_edge(fan_out_b)

    # Node 3's image input is already fed by node 1.
    with pytest.raises(InvalidEdgeError):
        graph.add_edge(second_into_3)
|
|
|
|
|
|
def test_graph_fails_to_add_edge_with_mismatched_types():
    """Wiring an ``image`` output into the upscaler's ``strength`` input is rejected."""
    graph = Graph()
    graph.add_node(TextToImageInvocation(id="1", prompt="Banana sushi"))
    graph.add_node(UpscaleInvocation(id="2"))

    bad_edge = create_edge("1", "image", "2", "strength")
    with pytest.raises(InvalidEdgeError):
        graph.add_edge(bad_edge)
|
|
|
|
def test_graph_connects_collector():
    """Two image outputs can feed a collector whose collection feeds a list input."""
    g = Graph()
    n1 = TextToImageInvocation(id="1", prompt="Banana sushi")
    n2 = TextToImageInvocation(id="2", prompt="Banana sushi 2")
    n3 = CollectInvocation(id="3")
    n4 = ListPassThroughInvocation(id="4")
    g.add_node(n1)
    g.add_node(n2)
    g.add_node(n3)
    g.add_node(n4)

    e1 = create_edge("1", "image", "3", "item")
    e2 = create_edge("2", "image", "3", "item")
    e3 = create_edge("3", "collection", "4", "collection")
    g.add_edge(e1)
    g.add_edge(e2)
    g.add_edge(e3)

    # Not raising is the main check; also confirm every edge was recorded
    # (same membership check as test_graph_adds_edge).
    assert e1 in g.edges
    assert e2 in g.edges
    assert e3 in g.edges
|
|
|
|
# TODO: test that derived types mixed with base types are compatible
|
|
|
|
def test_graph_collector_invalid_with_varying_input_types():
    """A collector already fed an image rejects a prompt as its next item."""
    graph = Graph()
    for node in (
        TextToImageInvocation(id="1", prompt="Banana sushi"),
        PromptTestInvocation(id="2", prompt="banana sushi 2"),
        CollectInvocation(id="3"),
    ):
        graph.add_node(node)

    image_item = create_edge("1", "image", "3", "item")
    prompt_item = create_edge("2", "prompt", "3", "item")
    graph.add_edge(image_item)

    with pytest.raises(InvalidEdgeError):
        graph.add_edge(prompt_item)
|
|
|
|
def test_graph_collector_invalid_with_varying_input_output():
    """A collector of ``prompt`` outputs is rejected as input to ``ListPassThroughInvocation``."""
    graph = Graph()
    for node in (
        PromptTestInvocation(id="1", prompt="Banana sushi"),
        PromptTestInvocation(id="2", prompt="Banana sushi 2"),
        CollectInvocation(id="3"),
        ListPassThroughInvocation(id="4"),
    ):
        graph.add_node(node)

    prompt_items = (
        create_edge("1", "prompt", "3", "item"),
        create_edge("2", "prompt", "3", "item"),
    )
    # NOTE(review): presumably ListPassThroughInvocation's collection holds
    # images (it accepts an image collector in test_graph_connects_collector);
    # its definition lives in test_nodes.
    collection_out = create_edge("3", "collection", "4", "collection")
    for edge in prompt_items:
        graph.add_edge(edge)

    with pytest.raises(InvalidEdgeError):
        graph.add_edge(collection_out)
|
|
|
|
def test_graph_collector_invalid_with_non_list_output():
    """A collector's ``collection`` output cannot be wired into a single ``prompt`` input."""
    graph = Graph()
    for node in (
        PromptTestInvocation(id="1", prompt="Banana sushi"),
        PromptTestInvocation(id="2", prompt="Banana sushi 2"),
        CollectInvocation(id="3"),
        PromptTestInvocation(id="4"),
    ):
        graph.add_node(node)

    prompt_items = (
        create_edge("1", "prompt", "3", "item"),
        create_edge("2", "prompt", "3", "item"),
    )
    collection_into_prompt = create_edge("3", "collection", "4", "prompt")
    for edge in prompt_items:
        graph.add_edge(edge)

    with pytest.raises(InvalidEdgeError):
        graph.add_edge(collection_into_prompt)
|
|
|
|
def test_graph_connects_iterator():
    """A list output can feed an iterator whose ``item`` feeds an image input."""
    g = Graph()
    n1 = ListPassThroughInvocation(id="1")
    n2 = IterateInvocation(id="2")
    n3 = ImageToImageInvocation(id="3", prompt="Banana sushi")
    g.add_node(n1)
    g.add_node(n2)
    g.add_node(n3)

    e1 = create_edge("1", "collection", "2", "collection")
    e2 = create_edge("2", "item", "3", "image")
    g.add_edge(e1)
    g.add_edge(e2)

    # Not raising is the main check; also confirm both edges were recorded
    # (same membership check as test_graph_adds_edge).
    assert e1 in g.edges
    assert e2 in g.edges
|
|
|
|
# TODO: TEST INVALID ITERATOR SCENARIOS
|
|
|
|
def test_graph_iterator_invalid_if_multiple_inputs():
    """An iterator's ``collection`` input rejects a second incoming edge."""
    graph = Graph()
    for node in (
        ListPassThroughInvocation(id="1"),
        IterateInvocation(id="2"),
        ImageToImageInvocation(id="3", prompt="Banana sushi"),
        ListPassThroughInvocation(id="4"),
    ):
        graph.add_node(node)

    first_source = create_edge("1", "collection", "2", "collection")
    item_out = create_edge("2", "item", "3", "image")
    second_source = create_edge("4", "collection", "2", "collection")
    graph.add_edge(first_source)
    graph.add_edge(item_out)

    with pytest.raises(InvalidEdgeError):
        graph.add_edge(second_source)
|
|
|
|
def test_graph_iterator_invalid_if_input_not_list():
    """An iterator's ``collection`` input rejects a source output that is not a list."""
    g = Graph()
    # Keyword was misspelled ``promopt``.
    n1 = TextToImageInvocation(id="1", prompt="Banana sushi")
    n2 = IterateInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)

    # Use the node's real single-image output (the one wired in
    # test_graph_adds_edge). The text-to-image node presumably has no
    # ``collection`` output at all, so that field would only exercise a
    # missing-field lookup, not the not-a-list rule.
    e1 = create_edge("1", "image", "2", "collection")

    with pytest.raises(InvalidEdgeError):
        g.add_edge(e1)
|
|
|
|
def test_graph_iterator_invalid_if_output_and_input_types_different():
    """An iterator's ``item`` cannot feed a ``prompt`` input when iterating ``ListPassThroughInvocation``'s collection."""
    graph = Graph()
    for node in (
        ListPassThroughInvocation(id="1"),
        IterateInvocation(id="2"),
        PromptTestInvocation(id="3", prompt="Banana sushi"),
    ):
        graph.add_node(node)

    collection_in = create_edge("1", "collection", "2", "collection")
    # NOTE(review): presumably the pass-through list holds images (its item
    # feeds an image input in test_graph_connects_iterator), so a prompt
    # input does not match.
    item_into_prompt = create_edge("2", "item", "3", "prompt")
    graph.add_edge(collection_in)

    with pytest.raises(InvalidEdgeError):
        graph.add_edge(item_into_prompt)
|
|
|
|
def test_graph_validates():
    """A graph built through ``add_node``/``add_edge`` with a compatible edge is valid."""
    g = Graph()
    n1 = TextToImageInvocation(id="1", prompt="Banana sushi")
    n2 = UpscaleInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)
    e1 = create_edge("1", "image", "2", "image")
    g.add_edge(e1)

    # Truthiness check rather than ``== True`` (PEP 8 / ruff E712).
    assert g.is_valid()
|
|
|
|
def test_graph_invalid_if_edges_reference_missing_nodes():
    """``is_valid`` reports an edge whose destination node is not in the graph."""
    g = Graph()
    n1 = TextToImageInvocation(id="1", prompt="Banana sushi")
    # Mutate ``nodes``/``edges`` directly: ``add_edge`` would refuse this edge
    # (see test_graph_fails_to_add_edge_with_missing_node_id).
    g.nodes[n1.id] = n1
    e1 = create_edge("1", "image", "2", "image")
    g.edges.append(e1)

    # Truthiness check rather than ``== False`` (PEP 8 / ruff E712).
    assert not g.is_valid()
|
|
|
|
def test_graph_invalid_if_subgraph_invalid():
    """An outer graph is invalid when a ``GraphInvocation``'s inner graph is invalid."""
    g = Graph()
    n1 = GraphInvocation(id="1")
    n1.graph = Graph()

    # The inner graph holds only node "2", but its edge starts at node "1",
    # which exists only in the outer graph.
    n1_1 = TextToImageInvocation(id="2", prompt="Banana sushi")
    n1.graph.nodes[n1_1.id] = n1_1
    e1 = create_edge("1", "image", "2", "image")
    n1.graph.edges.append(e1)

    g.nodes[n1.id] = n1

    # Truthiness check rather than ``== False`` (PEP 8 / ruff E712).
    assert not g.is_valid()
|
|
|
|
def test_graph_invalid_if_has_cycle():
    """``is_valid`` reports a two-node cycle inserted by bypassing ``add_edge``."""
    g = Graph()
    n1 = UpscaleInvocation(id="1")
    n2 = UpscaleInvocation(id="2")
    # Mutate ``nodes``/``edges`` directly so ``add_edge`` cannot reject the cycle.
    g.nodes[n1.id] = n1
    g.nodes[n2.id] = n2
    e1 = create_edge("1", "image", "2", "image")
    e2 = create_edge("2", "image", "1", "image")
    g.edges.append(e1)
    g.edges.append(e2)

    # Truthiness check rather than ``== False`` (PEP 8 / ruff E712).
    assert not g.is_valid()
|
|
|
|
def test_graph_invalid_with_invalid_connection():
    """``is_valid`` reports an ``image`` -> ``strength`` edge inserted by bypassing ``add_edge``."""
    g = Graph()
    n1 = TextToImageInvocation(id="1", prompt="Banana sushi")
    n2 = UpscaleInvocation(id="2")
    # Mutate ``nodes``/``edges`` directly: ``add_edge`` would refuse this edge
    # (see test_graph_fails_to_add_edge_with_mismatched_types).
    g.nodes[n1.id] = n1
    g.nodes[n2.id] = n2
    e1 = create_edge("1", "image", "2", "strength")
    g.edges.append(e1)

    # Truthiness check rather than ``== False`` (PEP 8 / ruff E712).
    assert not g.is_valid()
|
|
|
|
|
|
# TODO: Subgraph operations
|
|
def test_graph_gets_subgraph_node():
    """A dotted path ``"<outer id>.<inner id>"`` resolves to a node inside a subgraph."""
    g = Graph()
    n1 = GraphInvocation(id="1")
    n1.graph = Graph()

    n1_1 = TextToImageInvocation(id="1", prompt="Banana sushi")
    n1.graph.add_node(n1_1)

    g.add_node(n1)

    result = g.get_node('1.1')

    assert result is not None
    assert result.id == '1'
    assert result == n1_1
|
|
|
|
def test_graph_fails_to_get_missing_subgraph_node():
    """A dotted path naming an id absent from the subgraph raises ``NodeNotFoundError``."""
    g = Graph()
    n1 = GraphInvocation(id="1")
    n1.graph = Graph()

    n1_1 = TextToImageInvocation(id="1", prompt="Banana sushi")
    n1.graph.add_node(n1_1)

    g.add_node(n1)

    with pytest.raises(NodeNotFoundError):
        g.get_node('1.2')
|
|
|
|
def test_graph_fails_to_enumerate_non_subgraph_node():
    """A dotted path through a node that is not a ``GraphInvocation`` raises ``NodeNotFoundError``."""
    g = Graph()
    n1 = GraphInvocation(id="1")
    n1.graph = Graph()

    n1_1 = TextToImageInvocation(id="1", prompt="Banana sushi")
    n1.graph.add_node(n1_1)

    g.add_node(n1)

    # Node "2" is a plain upscaler, so "2.1" has no subgraph to descend into.
    n2 = UpscaleInvocation(id="2")
    g.add_node(n2)

    with pytest.raises(NodeNotFoundError):
        g.get_node('2.1')
|
|
|
|
def test_graph_gets_networkx_graph():
    """``nx_graph()`` exposes the graph's nodes and edges keyed by node id."""
    graph = Graph()
    graph.add_node(TextToImageInvocation(id="1", prompt="Banana sushi"))
    graph.add_node(UpscaleInvocation(id="2"))
    graph.add_edge(create_edge("1", "image", "2", "image"))

    nx_view = graph.nx_graph()

    for node_id in ('1', '2'):
        assert node_id in nx_view.nodes
    assert ('1', '2') in nx_view.edges
|
|
|
|
|
|
# TODO: Graph serializes and deserializes
|
|
def test_graph_can_serialize():
    """A graph with nodes and an edge serializes to a JSON string without raising."""
    g = Graph()
    n1 = TextToImageInvocation(id="1", prompt="Banana sushi")
    n2 = UpscaleInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)
    e = create_edge(n1.id, "image", n2.id, "image")
    g.add_edge(e)

    # Not raising is the main check. ``serialized`` avoids shadowing the
    # stdlib ``json`` module name.
    serialized = g.json()
    assert isinstance(serialized, str)
|
|
|
|
def test_graph_can_deserialize():
    """A graph parsed back from its own JSON keeps both nodes and its single edge."""
    original = Graph()
    original.add_node(TextToImageInvocation(id="1", prompt="Banana sushi"))
    original.add_node(UpscaleInvocation(id="2"))
    original.add_edge(create_edge("1", "image", "2", "image"))

    serialized = original.json()
    restored = Graph.parse_raw(serialized)

    assert restored is not None
    assert restored.nodes['1'] is not None
    assert restored.nodes['2'] is not None
    assert len(restored.edges) == 1

    only_edge = restored.edges[0]
    source = only_edge[0]
    destination = only_edge[1]
    assert source.node_id == '1'
    assert source.field == 'image'
    assert destination.node_id == '2'
    assert destination.field == 'image'
|
|
|
|
def test_graph_can_generate_schema():
    """Generating the JSON schema for ``Graph`` succeeds and yields a string."""
    # Not throwing on this line is the main check.
    # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
    schema = Graph.schema_json(indent=2)
    assert isinstance(schema, str)
|