mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(nodes): remove GraphInvocation
`GraphInvocation` is a node that can contain a whole graph. It is removed for a number of reasons: 1. This feature was unused (the UI doesn't support it) and there is no plan for it to be used. The use-case it served is known in other node execution engines as "node groups" or "blocks" - a self-contained group of nodes, which has group inputs and outputs. This is a planned feature that will be handled client-side. 2. It adds substantial complexity to the graph processing logic. It's probably not enough to have a measurable performance impact but it does make it harder to work in the graph logic. 3. It allows for graphs to be recursive, and the improved invocations union handling does not play well with it. Actually, it works fine within `graph.py` but not in the tests for some reason. I do not understand why. There's probably a workaround, but I took this as encouragement to remove `GraphInvocation` from the app since we don't use it.
This commit is contained in:
@ -23,7 +23,7 @@ from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -35,17 +35,6 @@ def simple_graph():
|
||||
return g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_with_subgraph():
|
||||
sub_g = Graph()
|
||||
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
sub_g.add_node(TextToImageTestInvocation(id="2"))
|
||||
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||
g = Graph()
|
||||
g.add_node(GraphInvocation(id="1", graph=sub_g))
|
||||
return g
|
||||
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
|
@ -8,8 +8,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.image import ShowImageInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.primitives import (
|
||||
FloatCollectionInvocation,
|
||||
FloatInvocation,
|
||||
@ -17,13 +15,11 @@ from invokeai.app.invocations.primitives import (
|
||||
StringInvocation,
|
||||
)
|
||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||
from invokeai.app.services.shared.default_graphs import create_text_to_image
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphInvocation,
|
||||
InvalidEdgeError,
|
||||
IterateInvocation,
|
||||
NodeAlreadyInGraphError,
|
||||
@ -425,19 +421,19 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
def test_graph_invalid_if_subgraph_invalid():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# def test_graph_invalid_if_subgraph_invalid():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
|
||||
n1.graph.nodes[n1_1.id] = n1_1
|
||||
e1 = create_edge("1", "image", "2", "image")
|
||||
n1.graph.edges.append(e1)
|
||||
# n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
|
||||
# n1.graph.nodes[n1_1.id] = n1_1
|
||||
# e1 = create_edge("1", "image", "2", "image")
|
||||
# n1.graph.edges.append(e1)
|
||||
|
||||
g.nodes[n1.id] = n1
|
||||
# g.nodes[n1.id] = n1
|
||||
|
||||
assert g.is_valid() is False
|
||||
# assert g.is_valid() is False
|
||||
|
||||
|
||||
def test_graph_invalid_if_has_cycle():
|
||||
@ -466,108 +462,108 @@ def test_graph_invalid_with_invalid_connection():
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
# TODO: Subgraph operations
|
||||
def test_graph_gets_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# # TODO: Subgraph operations
|
||||
# def test_graph_gets_subgraph_node():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
# n1.graph.add_node(n1_1)
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
result = g.get_node("1.1")
|
||||
# result = g.get_node("1.1")
|
||||
|
||||
assert result is not None
|
||||
assert result.id == "1"
|
||||
assert result == n1_1
|
||||
# assert result is not None
|
||||
# assert result.id == "1"
|
||||
# assert result == n1_1
|
||||
|
||||
|
||||
def test_graph_expands_subgraph():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# def test_graph_expands_subgraph():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = AddInvocation(id="1", a=1, b=2)
|
||||
n1_2 = SubtractInvocation(id="2", b=3)
|
||||
n1.graph.add_node(n1_1)
|
||||
n1.graph.add_node(n1_2)
|
||||
n1.graph.add_edge(create_edge("1", "value", "2", "a"))
|
||||
# n1_1 = AddInvocation(id="1", a=1, b=2)
|
||||
# n1_2 = SubtractInvocation(id="2", b=3)
|
||||
# n1.graph.add_node(n1_1)
|
||||
# n1.graph.add_node(n1_2)
|
||||
# n1.graph.add_edge(create_edge("1", "value", "2", "a"))
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
n2 = AddInvocation(id="2", b=5)
|
||||
g.add_node(n2)
|
||||
g.add_edge(create_edge("1.2", "value", "2", "a"))
|
||||
# n2 = AddInvocation(id="2", b=5)
|
||||
# g.add_node(n2)
|
||||
# g.add_edge(create_edge("1.2", "value", "2", "a"))
|
||||
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == {"1.1", "1.2", "2"}
|
||||
assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
|
||||
# dg = g.nx_graph_flat()
|
||||
# assert set(dg.nodes) == {"1.1", "1.2", "2"}
|
||||
# assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
|
||||
|
||||
|
||||
def test_graph_subgraph_t2i():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
# def test_graph_subgraph_t2i():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
|
||||
# Get text to image default graph
|
||||
lg = create_text_to_image()
|
||||
n1.graph = lg.graph
|
||||
# # Get text to image default graph
|
||||
# lg = create_text_to_image()
|
||||
# n1.graph = lg.graph
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
n2 = IntegerInvocation(id="2", value=512)
|
||||
n3 = IntegerInvocation(id="3", value=256)
|
||||
# n2 = IntegerInvocation(id="2", value=512)
|
||||
# n3 = IntegerInvocation(id="3", value=256)
|
||||
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
# g.add_node(n2)
|
||||
# g.add_node(n3)
|
||||
|
||||
g.add_edge(create_edge("2", "value", "1.width", "value"))
|
||||
g.add_edge(create_edge("3", "value", "1.height", "value"))
|
||||
# g.add_edge(create_edge("2", "value", "1.width", "value"))
|
||||
# g.add_edge(create_edge("3", "value", "1.height", "value"))
|
||||
|
||||
n4 = ShowImageInvocation(id="4")
|
||||
g.add_node(n4)
|
||||
g.add_edge(create_edge("1.8", "image", "4", "image"))
|
||||
# n4 = ShowImageInvocation(id="4")
|
||||
# g.add_node(n4)
|
||||
# g.add_edge(create_edge("1.8", "image", "4", "image"))
|
||||
|
||||
# Validate
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
|
||||
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
|
||||
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
|
||||
print(expected_edges)
|
||||
print(list(dg.edges))
|
||||
assert set(dg.edges) == set(expected_edges)
|
||||
# # Validate
|
||||
# dg = g.nx_graph_flat()
|
||||
# assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
|
||||
# expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
|
||||
# expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
|
||||
# print(expected_edges)
|
||||
# print(list(dg.edges))
|
||||
# assert set(dg.edges) == set(expected_edges)
|
||||
|
||||
|
||||
def test_graph_fails_to_get_missing_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# def test_graph_fails_to_get_missing_subgraph_node():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
# n1.graph.add_node(n1_1)
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
with pytest.raises(NodeNotFoundError):
|
||||
_ = g.get_node("1.2")
|
||||
# with pytest.raises(NodeNotFoundError):
|
||||
# _ = g.get_node("1.2")
|
||||
|
||||
|
||||
def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
# n1.graph.add_node(n1_1)
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
n2 = ESRGANInvocation(id="2")
|
||||
g.add_node(n2)
|
||||
# n2 = ESRGANInvocation(id="2")
|
||||
# g.add_node(n2)
|
||||
|
||||
with pytest.raises(NodeNotFoundError):
|
||||
_ = g.get_node("2.1")
|
||||
# with pytest.raises(NodeNotFoundError):
|
||||
# _ = g.get_node("2.1")
|
||||
|
||||
|
||||
def test_graph_gets_networkx_graph():
|
||||
|
@ -8,10 +8,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
NodeFieldValue,
|
||||
calc_session_count,
|
||||
create_session_nfv_tuples,
|
||||
populate_graph,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
from tests.aa_nodes.test_nodes import PromptTestInvocation
|
||||
|
||||
|
||||
@ -39,28 +38,28 @@ def batch_graph() -> Graph:
|
||||
return g
|
||||
|
||||
|
||||
def test_populate_graph_with_subgraph():
|
||||
g1 = Graph()
|
||||
g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
|
||||
n1 = PromptTestInvocation(id="1", prompt="Banana snake")
|
||||
subgraph = Graph()
|
||||
subgraph.add_node(n1)
|
||||
g1.add_node(GraphInvocation(id="3", graph=subgraph))
|
||||
# def test_populate_graph_with_subgraph():
|
||||
# g1 = Graph()
|
||||
# g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
# g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
|
||||
# n1 = PromptTestInvocation(id="1", prompt="Banana snake")
|
||||
# subgraph = Graph()
|
||||
# subgraph.add_node(n1)
|
||||
# g1.add_node(GraphInvocation(id="3", graph=subgraph))
|
||||
|
||||
nfvs = [
|
||||
NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
|
||||
NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
|
||||
NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
|
||||
]
|
||||
# nfvs = [
|
||||
# NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
|
||||
# NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
|
||||
# NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
|
||||
# ]
|
||||
|
||||
g2 = populate_graph(g1, nfvs)
|
||||
# g2 = populate_graph(g1, nfvs)
|
||||
|
||||
# do not mutate g1
|
||||
assert g1 is not g2
|
||||
assert g2.get_node("1").prompt == "Strawberry sushi"
|
||||
assert g2.get_node("2").prompt == "Strawberry sunday"
|
||||
assert g2.get_node("3.1").prompt == "Strawberry snake"
|
||||
# # do not mutate g1
|
||||
# assert g1 is not g2
|
||||
# assert g2.get_node("1").prompt == "Strawberry sushi"
|
||||
# assert g2.get_node("2").prompt == "Strawberry sunday"
|
||||
# assert g2.get_node("3.1").prompt == "Strawberry snake"
|
||||
|
||||
|
||||
def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):
|
||||
|
Reference in New Issue
Block a user