tidy(nodes): remove GraphInvocation

`GraphInvocation` is a node that can contain a whole graph. It is removed for a number of reasons:

1. This feature was unused (the UI doesn't support it) and there is no plan for it to be used.

The use-case it served is known in other node execution engines as "node groups" or "blocks" - a self-contained group of nodes, which has group inputs and outputs. This is a planned feature that will be handled client-side.

2. It adds substantial complexity to the graph processing logic. It's probably not enough to have a measurable performance impact but it does make it harder to work in the graph logic.

3. It allows for graphs to be recursive, and the improved invocations union handling does not play well with it. Actually, it works fine within `graph.py` but not in the tests for some reason. I do not understand why. There's probably a workaround, but I took this as encouragement to remove `GraphInvocation` from the app since we don't use it.
This commit is contained in:
psychedelicious
2024-02-17 19:56:13 +11:00
parent 47b5a90177
commit 5fc745653a
4 changed files with 178 additions and 336 deletions

View File

@ -23,7 +23,7 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
@pytest.fixture
@ -35,17 +35,6 @@ def simple_graph():
return g
@pytest.fixture
def graph_with_subgraph():
sub_g = Graph()
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
sub_g.add_node(TextToImageTestInvocation(id="2"))
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
g = Graph()
g.add_node(GraphInvocation(id="1", graph=sub_g))
return g
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.

View File

@ -8,8 +8,6 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import (
FloatCollectionInvocation,
FloatInvocation,
@ -17,13 +15,11 @@ from invokeai.app.invocations.primitives import (
StringInvocation,
)
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.services.shared.default_graphs import create_text_to_image
from invokeai.app.services.shared.graph import (
CollectInvocation,
Edge,
EdgeConnection,
Graph,
GraphInvocation,
InvalidEdgeError,
IterateInvocation,
NodeAlreadyInGraphError,
@ -425,19 +421,19 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
assert g.is_valid() is False
def test_graph_invalid_if_subgraph_invalid():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# def test_graph_invalid_if_subgraph_invalid():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
n1.graph.nodes[n1_1.id] = n1_1
e1 = create_edge("1", "image", "2", "image")
n1.graph.edges.append(e1)
# n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
# n1.graph.nodes[n1_1.id] = n1_1
# e1 = create_edge("1", "image", "2", "image")
# n1.graph.edges.append(e1)
g.nodes[n1.id] = n1
# g.nodes[n1.id] = n1
assert g.is_valid() is False
# assert g.is_valid() is False
def test_graph_invalid_if_has_cycle():
@ -466,108 +462,108 @@ def test_graph_invalid_with_invalid_connection():
assert g.is_valid() is False
# TODO: Subgraph operations
def test_graph_gets_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# # TODO: Subgraph operations
# def test_graph_gets_subgraph_node():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
# n1.graph.add_node(n1_1)
g.add_node(n1)
# g.add_node(n1)
result = g.get_node("1.1")
# result = g.get_node("1.1")
assert result is not None
assert result.id == "1"
assert result == n1_1
# assert result is not None
# assert result.id == "1"
# assert result == n1_1
def test_graph_expands_subgraph():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# def test_graph_expands_subgraph():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = AddInvocation(id="1", a=1, b=2)
n1_2 = SubtractInvocation(id="2", b=3)
n1.graph.add_node(n1_1)
n1.graph.add_node(n1_2)
n1.graph.add_edge(create_edge("1", "value", "2", "a"))
# n1_1 = AddInvocation(id="1", a=1, b=2)
# n1_2 = SubtractInvocation(id="2", b=3)
# n1.graph.add_node(n1_1)
# n1.graph.add_node(n1_2)
# n1.graph.add_edge(create_edge("1", "value", "2", "a"))
g.add_node(n1)
# g.add_node(n1)
n2 = AddInvocation(id="2", b=5)
g.add_node(n2)
g.add_edge(create_edge("1.2", "value", "2", "a"))
# n2 = AddInvocation(id="2", b=5)
# g.add_node(n2)
# g.add_edge(create_edge("1.2", "value", "2", "a"))
dg = g.nx_graph_flat()
assert set(dg.nodes) == {"1.1", "1.2", "2"}
assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
# dg = g.nx_graph_flat()
# assert set(dg.nodes) == {"1.1", "1.2", "2"}
# assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
def test_graph_subgraph_t2i():
g = Graph()
n1 = GraphInvocation(id="1")
# def test_graph_subgraph_t2i():
# g = Graph()
# n1 = GraphInvocation(id="1")
# Get text to image default graph
lg = create_text_to_image()
n1.graph = lg.graph
# # Get text to image default graph
# lg = create_text_to_image()
# n1.graph = lg.graph
g.add_node(n1)
# g.add_node(n1)
n2 = IntegerInvocation(id="2", value=512)
n3 = IntegerInvocation(id="3", value=256)
# n2 = IntegerInvocation(id="2", value=512)
# n3 = IntegerInvocation(id="3", value=256)
g.add_node(n2)
g.add_node(n3)
# g.add_node(n2)
# g.add_node(n3)
g.add_edge(create_edge("2", "value", "1.width", "value"))
g.add_edge(create_edge("3", "value", "1.height", "value"))
# g.add_edge(create_edge("2", "value", "1.width", "value"))
# g.add_edge(create_edge("3", "value", "1.height", "value"))
n4 = ShowImageInvocation(id="4")
g.add_node(n4)
g.add_edge(create_edge("1.8", "image", "4", "image"))
# n4 = ShowImageInvocation(id="4")
# g.add_node(n4)
# g.add_edge(create_edge("1.8", "image", "4", "image"))
# Validate
dg = g.nx_graph_flat()
assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
print(expected_edges)
print(list(dg.edges))
assert set(dg.edges) == set(expected_edges)
# # Validate
# dg = g.nx_graph_flat()
# assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
# expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
# expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
# print(expected_edges)
# print(list(dg.edges))
# assert set(dg.edges) == set(expected_edges)
def test_graph_fails_to_get_missing_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# def test_graph_fails_to_get_missing_subgraph_node():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
# n1.graph.add_node(n1_1)
g.add_node(n1)
# g.add_node(n1)
with pytest.raises(NodeNotFoundError):
_ = g.get_node("1.2")
# with pytest.raises(NodeNotFoundError):
# _ = g.get_node("1.2")
def test_graph_fails_to_enumerate_non_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# def test_graph_fails_to_enumerate_non_subgraph_node():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
# n1.graph.add_node(n1_1)
g.add_node(n1)
# g.add_node(n1)
n2 = ESRGANInvocation(id="2")
g.add_node(n2)
# n2 = ESRGANInvocation(id="2")
# g.add_node(n2)
with pytest.raises(NodeNotFoundError):
_ = g.get_node("2.1")
# with pytest.raises(NodeNotFoundError):
# _ = g.get_node("2.1")
def test_graph_gets_networkx_graph():

View File

@ -8,10 +8,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
NodeFieldValue,
calc_session_count,
create_session_nfv_tuples,
populate_graph,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
from tests.aa_nodes.test_nodes import PromptTestInvocation
@ -39,28 +38,28 @@ def batch_graph() -> Graph:
return g
def test_populate_graph_with_subgraph():
g1 = Graph()
g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
n1 = PromptTestInvocation(id="1", prompt="Banana snake")
subgraph = Graph()
subgraph.add_node(n1)
g1.add_node(GraphInvocation(id="3", graph=subgraph))
# def test_populate_graph_with_subgraph():
# g1 = Graph()
# g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
# g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
# n1 = PromptTestInvocation(id="1", prompt="Banana snake")
# subgraph = Graph()
# subgraph.add_node(n1)
# g1.add_node(GraphInvocation(id="3", graph=subgraph))
nfvs = [
NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
]
# nfvs = [
# NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
# NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
# NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
# ]
g2 = populate_graph(g1, nfvs)
# g2 = populate_graph(g1, nfvs)
# do not mutate g1
assert g1 is not g2
assert g2.get_node("1").prompt == "Strawberry sushi"
assert g2.get_node("2").prompt == "Strawberry sunday"
assert g2.get_node("3.1").prompt == "Strawberry snake"
# # do not mutate g1
# assert g1 is not g2
# assert g2.get_node("1").prompt == "Strawberry sushi"
# assert g2.get_node("2").prompt == "Strawberry sunday"
# assert g2.get_node("3.1").prompt == "Strawberry snake"
def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):