tidy(nodes): remove GraphInvocation

`GraphInvocation` is a node that can contain a whole graph. It is removed for a number of reasons:

1. This feature was unused (the UI doesn't support it) and there is no plan for it to be used.

The use-case it served is known in other node execution engines as "node groups" or "blocks" - a self-contained group of nodes, which has group inputs and outputs. This is a planned feature that will be handled client-side.

2. It adds substantial complexity to the graph processing logic. It's probably not enough to have a measurable performance impact but it does make it harder to work in the graph logic.

3. It allows for graphs to be recursive, and the improved invocations union handling does not play well with it. Actually, it works fine within `graph.py` but not in the tests for some reason. I do not understand why. There's probably a workaround, but I took this as encouragement to remove `GraphInvocation` from the app since we don't use it.
This commit is contained in:
psychedelicious 2024-02-17 19:56:13 +11:00 committed by Brandon Rising
parent d7adab89bd
commit ab83fb2cea
4 changed files with 178 additions and 336 deletions

View File

@ -184,10 +184,6 @@ class NodeIdMismatchError(ValueError):
pass
class InvalidSubGraphError(ValueError):
pass
class CyclicalGraphError(ValueError):
pass
@ -196,25 +192,6 @@ class UnknownGraphValidationError(ValueError):
pass
# TODO: Create and use an Empty output?
@invocation_output("graph_output")
class GraphInvocationOutput(BaseInvocationOutput):
pass
# TODO: Fill this out and move to invocations
@invocation("graph", version="1.0.0")
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
# TODO: figure out how to create a default here
graph: "Graph" = InputField(description="The graph to run", default=None)
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
"""Invoke with provided services and return outputs."""
return GraphInvocationOutput()
@invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
@ -346,41 +323,21 @@ class Graph(BaseModel):
self.nodes[node.id] = node
def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]:
"""Returns the graph and node id for a node path."""
# Materialized graphs may have nodes at the top level
if node_path in self.nodes:
return (self, node_path)
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
if node_id not in self.nodes:
raise NodeNotFoundError(f"Node {node_path} not found in graph")
node = self.nodes[node_id]
if not isinstance(node, GraphInvocation):
# There's more node path left but this isn't a graph - failure
raise NodeNotFoundError("Node path terminated early at a non-graph node")
return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :])
def delete_node(self, node_path: str) -> None:
def delete_node(self, node_id: str) -> None:
"""Deletes a node from a graph"""
try:
graph, node_id = self._get_graph_and_node(node_path)
# Delete edges for this node
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
input_edges = self._get_input_edges(node_id)
output_edges = self._get_output_edges(node_id)
for edge_graph, _, edge in input_edges:
edge_graph.delete_edge(edge)
for edge in input_edges:
self.delete_edge(edge)
for edge_graph, _, edge in output_edges:
edge_graph.delete_edge(edge)
for edge in output_edges:
self.delete_edge(edge)
del graph.nodes[node_id]
del self.nodes[node_id]
except NodeNotFoundError:
pass # Ignore, not doesn't exist (should this throw?)
@ -430,13 +387,6 @@ class Graph(BaseModel):
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
# Validate all subgraphs
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
try:
gn.graph.validate_self()
except Exception as e:
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
# Validate that all edges match nodes and fields in the graph
for edge in self.edges:
source_node = self.nodes.get(edge.source.node_id, None)
@ -498,7 +448,6 @@ class Graph(BaseModel):
except (
DuplicateNodeIdError,
NodeIdMismatchError,
InvalidSubGraphError,
NodeNotFoundError,
NodeFieldNotFoundError,
CyclicalGraphError,
@ -519,7 +468,7 @@ class Graph(BaseModel):
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
# Validate that the nodes exist
try:
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
@ -586,171 +535,90 @@ class Graph(BaseModel):
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
def has_node(self, node_path: str) -> bool:
def has_node(self, node_id: str) -> bool:
"""Determines whether or not a node exists in the graph."""
try:
n = self.get_node(node_path)
if n is not None:
return True
else:
return False
_ = self.get_node(node_id)
return True
except NodeNotFoundError:
return False
def get_node(self, node_path: str) -> BaseInvocation:
"""Gets a node from the graph using a node path."""
# Materialized graphs may have nodes at the top level
graph, node_id = self._get_graph_and_node(node_path)
return graph.nodes[node_id]
def get_node(self, node_id: str) -> BaseInvocation:
"""Gets a node from the graph."""
try:
return self.nodes[node_id]
except KeyError as e:
raise NodeNotFoundError(f"Node {node_id} not found in graph") from e
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}"
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
"""Updates a node in the graph."""
graph, node_id = self._get_graph_and_node(node_path)
node = graph.nodes[node_id]
node = self.nodes[node_id]
# Ensure the node type matches the new node
if type(node) is not type(new_node):
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}")
# Ensure the new id is either the same or is not in the graph
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
new_path = self._get_node_path(new_node.id, prefix=prefix)
if new_node.id != node.id and self.has_node(new_path):
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
if new_node.id != node.id and self.has_node(new_node.id):
raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")
# Set the new node in the graph
graph.nodes[new_node.id] = new_node
self.nodes[new_node.id] = new_node
if new_node.id != node.id:
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
input_edges = self._get_input_edges(node_id)
output_edges = self._get_output_edges(node_id)
# Delete node and all edges
graph.delete_node(node_path)
self.delete_node(node_id)
# Create new edges for each input and output
for graph, _, edge in input_edges:
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge.destination.node_id
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
for edge in input_edges:
self.add_edge(
Edge(
source=edge.source,
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
)
)
for graph, _, edge in output_edges:
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge.source.node_id
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
for edge in output_edges:
self.add_edge(
Edge(
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
source=EdgeConnection(node_id=new_node.id, field=edge.source.field),
destination=edge.destination,
)
)
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path)
def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all input edges for a node. If field is provided, only edges to that field are returned."""
# Filter to edges that match the field
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
edges = [e for e in self.edges if e.destination.node_id == node_id]
# Create full node paths for each edge
return [
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
for _, prefix, e in filtered_edges
]
if field is None:
return edges
def _get_input_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = []
filtered_edges = [e for e in edges if e.destination.field == field]
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
return filtered_edges
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node = self.nodes[node_id]
def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all output edges for a node. If field is provided, only edges from that field are returned."""
edges = [e for e in self.edges if e.source.node_id == node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
edges.extend(graph_edges)
if field is None:
return edges
return edges
filtered_edges = [e for e in edges if e.source.field == field]
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if e[2].source.field == field)
# Create full node paths for each edge
return [
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
for _, prefix, e in filtered_edges
]
def _get_output_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
edges.extend(graph_edges)
return edges
return filtered_edges
def _is_iterator_connection_valid(
self,
node_path: str,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
if new_input is not None:
inputs.append(new_input)
@ -778,12 +646,12 @@ class Graph(BaseModel):
def _is_collector_connection_valid(
self,
node_path: str,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
if new_input is not None:
inputs.append(new_input)
@ -839,27 +707,17 @@ class Graph(BaseModel):
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
g = nx_graph or nx.DiGraph()
# Add all nodes from this graph except graph/iteration nodes
g.add_nodes_from(
[
self._get_node_path(n.id, prefix)
for n in self.nodes.values()
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
]
)
# Expand graph nodes
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
# TODO: figure out if iteration nodes need to be expanded
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
g.add_edges_from([(e[0], e[1]) for e in unique_edges])
return g
@ -1017,17 +875,17 @@ class GraphExecutionState(BaseModel):
"""Returns true if the graph has any errors"""
return len(self.errors) > 0
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
"""Prepares an iteration node and connects all edges, returning the new node id"""
node = self.graph.get_node(node_path)
node = self.graph.get_node(node_id)
self_iteration_count = -1
# If this is an iterator node, we must create a copy for each iteration
if isinstance(node, IterateInvocation):
# Get input collection edge (should error if there are no inputs)
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
input_collection_prepared_node_id = next(
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
)
@ -1041,7 +899,7 @@ class GraphExecutionState(BaseModel):
return new_nodes
# Get all input edges
input_edges = self.graph._get_input_edges(node_path)
input_edges = self.graph._get_input_edges(node_id)
# Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field
@ -1068,10 +926,10 @@ class GraphExecutionState(BaseModel):
# Add to execution graph
self.execution_graph.add_node(new_node)
self.prepared_source_mapping[new_node.id] = node_path
if node_path not in self.source_prepared_mapping:
self.source_prepared_mapping[node_path] = set()
self.source_prepared_mapping[node_path].add(new_node.id)
self.prepared_source_mapping[new_node.id] = node_id
if node_id not in self.source_prepared_mapping:
self.source_prepared_mapping[node_id] = set()
self.source_prepared_mapping[node_id].add(new_node.id)
# Add new edges to execution graph
for edge in new_edges:
@ -1175,13 +1033,13 @@ class GraphExecutionState(BaseModel):
def _get_iteration_node(
self,
source_node_path: str,
source_node_id: str,
graph: nx.DiGraph,
execution_graph: nx.DiGraph,
prepared_iterator_nodes: list[str],
) -> Optional[str]:
"""Gets the prepared version of the specified source node that matches every iteration specified"""
prepared_nodes = self.source_prepared_mapping[source_node_path]
prepared_nodes = self.source_prepared_mapping[source_node_id]
if len(prepared_nodes) == 1:
return next(iter(prepared_nodes))
@ -1192,7 +1050,7 @@ class GraphExecutionState(BaseModel):
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
return next(
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
@ -1261,19 +1119,19 @@ class GraphExecutionState(BaseModel):
def add_node(self, node: BaseInvocation) -> None:
self.graph.add_node(node)
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
if not self._is_node_updatable(node_path):
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
if not self._is_node_updatable(node_id):
raise NodeAlreadyExecutedError(
f"Node {node_path} has already been prepared or executed and cannot be updated"
f"Node {node_id} has already been prepared or executed and cannot be updated"
)
self.graph.update_node(node_path, new_node)
self.graph.update_node(node_id, new_node)
def delete_node(self, node_path: str) -> None:
if not self._is_node_updatable(node_path):
def delete_node(self, node_id: str) -> None:
if not self._is_node_updatable(node_id):
raise NodeAlreadyExecutedError(
f"Node {node_path} has already been prepared or executed and cannot be deleted"
f"Node {node_id} has already been prepared or executed and cannot be deleted"
)
self.graph.delete_node(node_path)
self.graph.delete_node(node_id)
def add_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id):

View File

@ -23,7 +23,7 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
@pytest.fixture
@ -35,17 +35,6 @@ def simple_graph():
return g
@pytest.fixture
def graph_with_subgraph():
sub_g = Graph()
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
sub_g.add_node(TextToImageTestInvocation(id="2"))
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
g = Graph()
g.add_node(GraphInvocation(id="1", graph=sub_g))
return g
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.

View File

@ -8,8 +8,6 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import (
FloatCollectionInvocation,
FloatInvocation,
@ -17,13 +15,11 @@ from invokeai.app.invocations.primitives import (
StringInvocation,
)
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.services.shared.default_graphs import create_text_to_image
from invokeai.app.services.shared.graph import (
CollectInvocation,
Edge,
EdgeConnection,
Graph,
GraphInvocation,
InvalidEdgeError,
IterateInvocation,
NodeAlreadyInGraphError,
@ -425,19 +421,19 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
assert g.is_valid() is False
def test_graph_invalid_if_subgraph_invalid():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# def test_graph_invalid_if_subgraph_invalid():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
n1.graph.nodes[n1_1.id] = n1_1
e1 = create_edge("1", "image", "2", "image")
n1.graph.edges.append(e1)
# n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
# n1.graph.nodes[n1_1.id] = n1_1
# e1 = create_edge("1", "image", "2", "image")
# n1.graph.edges.append(e1)
g.nodes[n1.id] = n1
# g.nodes[n1.id] = n1
assert g.is_valid() is False
# assert g.is_valid() is False
def test_graph_invalid_if_has_cycle():
@ -466,108 +462,108 @@ def test_graph_invalid_with_invalid_connection():
assert g.is_valid() is False
# TODO: Subgraph operations
def test_graph_gets_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# # TODO: Subgraph operations
# def test_graph_gets_subgraph_node():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
# n1.graph.add_node(n1_1)
g.add_node(n1)
# g.add_node(n1)
result = g.get_node("1.1")
# result = g.get_node("1.1")
assert result is not None
assert result.id == "1"
assert result == n1_1
# assert result is not None
# assert result.id == "1"
# assert result == n1_1
def test_graph_expands_subgraph():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# def test_graph_expands_subgraph():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = AddInvocation(id="1", a=1, b=2)
n1_2 = SubtractInvocation(id="2", b=3)
n1.graph.add_node(n1_1)
n1.graph.add_node(n1_2)
n1.graph.add_edge(create_edge("1", "value", "2", "a"))
# n1_1 = AddInvocation(id="1", a=1, b=2)
# n1_2 = SubtractInvocation(id="2", b=3)
# n1.graph.add_node(n1_1)
# n1.graph.add_node(n1_2)
# n1.graph.add_edge(create_edge("1", "value", "2", "a"))
g.add_node(n1)
# g.add_node(n1)
n2 = AddInvocation(id="2", b=5)
g.add_node(n2)
g.add_edge(create_edge("1.2", "value", "2", "a"))
# n2 = AddInvocation(id="2", b=5)
# g.add_node(n2)
# g.add_edge(create_edge("1.2", "value", "2", "a"))
dg = g.nx_graph_flat()
assert set(dg.nodes) == {"1.1", "1.2", "2"}
assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
# dg = g.nx_graph_flat()
# assert set(dg.nodes) == {"1.1", "1.2", "2"}
# assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
def test_graph_subgraph_t2i():
g = Graph()
n1 = GraphInvocation(id="1")
# def test_graph_subgraph_t2i():
# g = Graph()
# n1 = GraphInvocation(id="1")
# Get text to image default graph
lg = create_text_to_image()
n1.graph = lg.graph
# # Get text to image default graph
# lg = create_text_to_image()
# n1.graph = lg.graph
g.add_node(n1)
# g.add_node(n1)
n2 = IntegerInvocation(id="2", value=512)
n3 = IntegerInvocation(id="3", value=256)
# n2 = IntegerInvocation(id="2", value=512)
# n3 = IntegerInvocation(id="3", value=256)
g.add_node(n2)
g.add_node(n3)
# g.add_node(n2)
# g.add_node(n3)
g.add_edge(create_edge("2", "value", "1.width", "value"))
g.add_edge(create_edge("3", "value", "1.height", "value"))
# g.add_edge(create_edge("2", "value", "1.width", "value"))
# g.add_edge(create_edge("3", "value", "1.height", "value"))
n4 = ShowImageInvocation(id="4")
g.add_node(n4)
g.add_edge(create_edge("1.8", "image", "4", "image"))
# n4 = ShowImageInvocation(id="4")
# g.add_node(n4)
# g.add_edge(create_edge("1.8", "image", "4", "image"))
# Validate
dg = g.nx_graph_flat()
assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
print(expected_edges)
print(list(dg.edges))
assert set(dg.edges) == set(expected_edges)
# # Validate
# dg = g.nx_graph_flat()
# assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
# expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
# expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
# print(expected_edges)
# print(list(dg.edges))
# assert set(dg.edges) == set(expected_edges)
def test_graph_fails_to_get_missing_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# def test_graph_fails_to_get_missing_subgraph_node():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
# n1.graph.add_node(n1_1)
g.add_node(n1)
# g.add_node(n1)
with pytest.raises(NodeNotFoundError):
_ = g.get_node("1.2")
# with pytest.raises(NodeNotFoundError):
# _ = g.get_node("1.2")
def test_graph_fails_to_enumerate_non_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
# def test_graph_fails_to_enumerate_non_subgraph_node():
# g = Graph()
# n1 = GraphInvocation(id="1")
# n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
# n1.graph.add_node(n1_1)
g.add_node(n1)
# g.add_node(n1)
n2 = ESRGANInvocation(id="2")
g.add_node(n2)
# n2 = ESRGANInvocation(id="2")
# g.add_node(n2)
with pytest.raises(NodeNotFoundError):
_ = g.get_node("2.1")
# with pytest.raises(NodeNotFoundError):
# _ = g.get_node("2.1")
def test_graph_gets_networkx_graph():

View File

@ -8,10 +8,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
NodeFieldValue,
calc_session_count,
create_session_nfv_tuples,
populate_graph,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
from tests.aa_nodes.test_nodes import PromptTestInvocation
@ -39,28 +38,28 @@ def batch_graph() -> Graph:
return g
def test_populate_graph_with_subgraph():
g1 = Graph()
g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
n1 = PromptTestInvocation(id="1", prompt="Banana snake")
subgraph = Graph()
subgraph.add_node(n1)
g1.add_node(GraphInvocation(id="3", graph=subgraph))
# def test_populate_graph_with_subgraph():
# g1 = Graph()
# g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
# g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
# n1 = PromptTestInvocation(id="1", prompt="Banana snake")
# subgraph = Graph()
# subgraph.add_node(n1)
# g1.add_node(GraphInvocation(id="3", graph=subgraph))
nfvs = [
NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
]
# nfvs = [
# NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
# NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
# NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
# ]
g2 = populate_graph(g1, nfvs)
# g2 = populate_graph(g1, nfvs)
# do not mutate g1
assert g1 is not g2
assert g2.get_node("1").prompt == "Strawberry sushi"
assert g2.get_node("2").prompt == "Strawberry sunday"
assert g2.get_node("3.1").prompt == "Strawberry snake"
# # do not mutate g1
# assert g1 is not g2
# assert g2.get_node("1").prompt == "Strawberry sushi"
# assert g2.get_node("2").prompt == "Strawberry sunday"
# assert g2.get_node("3.1").prompt == "Strawberry snake"
def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):