tidy(nodes): remove GraphInvocation

`GraphInvocation` is a node that can contain a whole graph. It is removed for a number of reasons:

1. This feature was unused (the UI doesn't support it) and there is no plan for it to be used.

The use-case it served is known in other node execution engines as "node groups" or "blocks" - a self-contained group of nodes, which has group inputs and outputs. This is a planned feature that will be handled client-side.

2. It adds substantial complexity to the graph processing logic. It's probably not enough to have a measurable performance impact but it does make it harder to work in the graph logic.

3. It allows for graphs to be recursive, and the improved invocations union handling does not play well with it. Actually, it works fine within `graph.py` but not in the tests for some reason. I do not understand why. There's probably a workaround, but I took this as encouragement to remove `GraphInvocation` from the app since we don't use it.
This commit is contained in:
psychedelicious 2024-02-17 19:56:13 +11:00
parent 47b5a90177
commit 5fc745653a
4 changed files with 178 additions and 336 deletions

View File

@ -184,10 +184,6 @@ class NodeIdMismatchError(ValueError):
pass pass
class InvalidSubGraphError(ValueError):
pass
class CyclicalGraphError(ValueError): class CyclicalGraphError(ValueError):
pass pass
@ -196,25 +192,6 @@ class UnknownGraphValidationError(ValueError):
pass pass
# TODO: Create and use an Empty output?
@invocation_output("graph_output")
class GraphInvocationOutput(BaseInvocationOutput):
pass
# TODO: Fill this out and move to invocations
@invocation("graph", version="1.0.0")
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
# TODO: figure out how to create a default here
graph: "Graph" = InputField(description="The graph to run", default=None)
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
"""Invoke with provided services and return outputs."""
return GraphInvocationOutput()
@invocation_output("iterate_output") @invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput): class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output.""" """Used to connect iteration outputs. Will be expanded to a specific output."""
@ -346,41 +323,21 @@ class Graph(BaseModel):
self.nodes[node.id] = node self.nodes[node.id] = node
def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]: def delete_node(self, node_id: str) -> None:
"""Returns the graph and node id for a node path."""
# Materialized graphs may have nodes at the top level
if node_path in self.nodes:
return (self, node_path)
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
if node_id not in self.nodes:
raise NodeNotFoundError(f"Node {node_path} not found in graph")
node = self.nodes[node_id]
if not isinstance(node, GraphInvocation):
# There's more node path left but this isn't a graph - failure
raise NodeNotFoundError("Node path terminated early at a non-graph node")
return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :])
def delete_node(self, node_path: str) -> None:
"""Deletes a node from a graph""" """Deletes a node from a graph"""
try: try:
graph, node_id = self._get_graph_and_node(node_path)
# Delete edges for this node # Delete edges for this node
input_edges = self._get_input_edges_and_graphs(node_path) input_edges = self._get_input_edges(node_id)
output_edges = self._get_output_edges_and_graphs(node_path) output_edges = self._get_output_edges(node_id)
for edge_graph, _, edge in input_edges: for edge in input_edges:
edge_graph.delete_edge(edge) self.delete_edge(edge)
for edge_graph, _, edge in output_edges: for edge in output_edges:
edge_graph.delete_edge(edge) self.delete_edge(edge)
del graph.nodes[node_id] del self.nodes[node_id]
except NodeNotFoundError: except NodeNotFoundError:
pass # Ignore, not doesn't exist (should this throw?) pass # Ignore, not doesn't exist (should this throw?)
@ -430,13 +387,6 @@ class Graph(BaseModel):
if k != v.id: if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}") raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
# Validate all subgraphs
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
try:
gn.graph.validate_self()
except Exception as e:
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
# Validate that all edges match nodes and fields in the graph # Validate that all edges match nodes and fields in the graph
for edge in self.edges: for edge in self.edges:
source_node = self.nodes.get(edge.source.node_id, None) source_node = self.nodes.get(edge.source.node_id, None)
@ -498,7 +448,6 @@ class Graph(BaseModel):
except ( except (
DuplicateNodeIdError, DuplicateNodeIdError,
NodeIdMismatchError, NodeIdMismatchError,
InvalidSubGraphError,
NodeNotFoundError, NodeNotFoundError,
NodeFieldNotFoundError, NodeFieldNotFoundError,
CyclicalGraphError, CyclicalGraphError,
@ -519,7 +468,7 @@ class Graph(BaseModel):
def _validate_edge(self, edge: Edge): def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph""" """Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) # Validate that the nodes exist
try: try:
from_node = self.get_node(edge.source.node_id) from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id) to_node = self.get_node(edge.destination.node_id)
@ -586,171 +535,90 @@ class Graph(BaseModel):
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
) )
def has_node(self, node_path: str) -> bool: def has_node(self, node_id: str) -> bool:
"""Determines whether or not a node exists in the graph.""" """Determines whether or not a node exists in the graph."""
try: try:
n = self.get_node(node_path) _ = self.get_node(node_id)
if n is not None: return True
return True
else:
return False
except NodeNotFoundError: except NodeNotFoundError:
return False return False
def get_node(self, node_path: str) -> BaseInvocation: def get_node(self, node_id: str) -> BaseInvocation:
"""Gets a node from the graph using a node path.""" """Gets a node from the graph."""
# Materialized graphs may have nodes at the top level try:
graph, node_id = self._get_graph_and_node(node_path) return self.nodes[node_id]
return graph.nodes[node_id] except KeyError as e:
raise NodeNotFoundError(f"Node {node_id} not found in graph") from e
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str: def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}"
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
"""Updates a node in the graph.""" """Updates a node in the graph."""
graph, node_id = self._get_graph_and_node(node_path) node = self.nodes[node_id]
node = graph.nodes[node_id]
# Ensure the node type matches the new node # Ensure the node type matches the new node
if type(node) is not type(new_node): if type(node) is not type(new_node):
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}") raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}")
# Ensure the new id is either the same or is not in the graph # Ensure the new id is either the same or is not in the graph
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")] if new_node.id != node.id and self.has_node(new_node.id):
new_path = self._get_node_path(new_node.id, prefix=prefix) raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")
if new_node.id != node.id and self.has_node(new_path):
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
# Set the new node in the graph # Set the new node in the graph
graph.nodes[new_node.id] = new_node self.nodes[new_node.id] = new_node
if new_node.id != node.id: if new_node.id != node.id:
input_edges = self._get_input_edges_and_graphs(node_path) input_edges = self._get_input_edges(node_id)
output_edges = self._get_output_edges_and_graphs(node_path) output_edges = self._get_output_edges(node_id)
# Delete node and all edges # Delete node and all edges
graph.delete_node(node_path) self.delete_node(node_id)
# Create new edges for each input and output # Create new edges for each input and output
for graph, _, edge in input_edges: for edge in input_edges:
# Remove the graph prefix from the node path self.add_edge(
new_graph_node_path = (
new_node.id
if "." not in edge.destination.node_id
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
Edge( Edge(
source=edge.source, source=edge.source,
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field), destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
) )
) )
for graph, _, edge in output_edges: for edge in output_edges:
# Remove the graph prefix from the node path self.add_edge(
new_graph_node_path = (
new_node.id
if "." not in edge.source.node_id
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
Edge( Edge(
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field), source=EdgeConnection(node_id=new_node.id, field=edge.source.field),
destination=edge.destination, destination=edge.destination,
) )
) )
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]: def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all input edges for a node""" """Gets all input edges for a node. If field is provided, only edges to that field are returned."""
edges = self._get_input_edges_and_graphs(node_path)
# Filter to edges that match the field edges = [e for e in self.edges if e.destination.node_id == node_id]
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
# Create full node paths for each edge if field is None:
return [ return edges
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
for _, prefix, e in filtered_edges
]
def _get_input_edges_and_graphs( filtered_edges = [e for e in edges if e.destination.field == field]
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = []
# Return any input edges that appear in this graph return filtered_edges
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
node = self.nodes[node_id] """Gets all output edges for a node. If field is provided, only edges from that field are returned."""
edges = [e for e in self.edges if e.source.node_id == node_id]
if isinstance(node, GraphInvocation): if field is None:
graph = node.graph return edges
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
edges.extend(graph_edges)
return edges filtered_edges = [e for e in edges if e.source.field == field]
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]: return filtered_edges
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if e[2].source.field == field)
# Create full node paths for each edge
return [
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
for _, prefix, e in filtered_edges
]
def _get_output_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
edges.extend(graph_edges)
return edges
def _is_iterator_connection_valid( def _is_iterator_connection_valid(
self, self,
node_path: str, node_id: str,
new_input: Optional[EdgeConnection] = None, new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None,
) -> bool: ) -> bool:
inputs = [e.source for e in self._get_input_edges(node_path, "collection")] inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_path, "item")] outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
if new_input is not None: if new_input is not None:
inputs.append(new_input) inputs.append(new_input)
@ -778,12 +646,12 @@ class Graph(BaseModel):
def _is_collector_connection_valid( def _is_collector_connection_valid(
self, self,
node_path: str, node_id: str,
new_input: Optional[EdgeConnection] = None, new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None,
) -> bool: ) -> bool:
inputs = [e.source for e in self._get_input_edges(node_path, "item")] inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")] outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
if new_input is not None: if new_input is not None:
inputs.append(new_input) inputs.append(new_input)
@ -839,27 +707,17 @@ class Graph(BaseModel):
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges}) g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph: def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)""" """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
g = nx_graph or nx.DiGraph() g = nx_graph or nx.DiGraph()
# Add all nodes from this graph except graph/iteration nodes # Add all nodes from this graph except graph/iteration nodes
g.add_nodes_from( g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
[
self._get_node_path(n.id, prefix)
for n in self.nodes.values()
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
]
)
# Expand graph nodes
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded # TODO: figure out if iteration nodes need to be expanded
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges} unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]) g.add_edges_from([(e[0], e[1]) for e in unique_edges])
return g return g
@ -1017,17 +875,17 @@ class GraphExecutionState(BaseModel):
"""Returns true if the graph has any errors""" """Returns true if the graph has any errors"""
return len(self.errors) > 0 return len(self.errors) > 0
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
"""Prepares an iteration node and connects all edges, returning the new node id""" """Prepares an iteration node and connects all edges, returning the new node id"""
node = self.graph.get_node(node_path) node = self.graph.get_node(node_id)
self_iteration_count = -1 self_iteration_count = -1
# If this is an iterator node, we must create a copy for each iteration # If this is an iterator node, we must create a copy for each iteration
if isinstance(node, IterateInvocation): if isinstance(node, IterateInvocation):
# Get input collection edge (should error if there are no inputs) # Get input collection edge (should error if there are no inputs)
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection"))) input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
input_collection_prepared_node_id = next( input_collection_prepared_node_id = next(
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
) )
@ -1041,7 +899,7 @@ class GraphExecutionState(BaseModel):
return new_nodes return new_nodes
# Get all input edges # Get all input edges
input_edges = self.graph._get_input_edges(node_path) input_edges = self.graph._get_input_edges(node_id)
# Create new edges for this iteration # Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field # For collect nodes, this may contain multiple inputs to the same field
@ -1068,10 +926,10 @@ class GraphExecutionState(BaseModel):
# Add to execution graph # Add to execution graph
self.execution_graph.add_node(new_node) self.execution_graph.add_node(new_node)
self.prepared_source_mapping[new_node.id] = node_path self.prepared_source_mapping[new_node.id] = node_id
if node_path not in self.source_prepared_mapping: if node_id not in self.source_prepared_mapping:
self.source_prepared_mapping[node_path] = set() self.source_prepared_mapping[node_id] = set()
self.source_prepared_mapping[node_path].add(new_node.id) self.source_prepared_mapping[node_id].add(new_node.id)
# Add new edges to execution graph # Add new edges to execution graph
for edge in new_edges: for edge in new_edges:
@ -1175,13 +1033,13 @@ class GraphExecutionState(BaseModel):
def _get_iteration_node( def _get_iteration_node(
self, self,
source_node_path: str, source_node_id: str,
graph: nx.DiGraph, graph: nx.DiGraph,
execution_graph: nx.DiGraph, execution_graph: nx.DiGraph,
prepared_iterator_nodes: list[str], prepared_iterator_nodes: list[str],
) -> Optional[str]: ) -> Optional[str]:
"""Gets the prepared version of the specified source node that matches every iteration specified""" """Gets the prepared version of the specified source node that matches every iteration specified"""
prepared_nodes = self.source_prepared_mapping[source_node_path] prepared_nodes = self.source_prepared_mapping[source_node_id]
if len(prepared_nodes) == 1: if len(prepared_nodes) == 1:
return next(iter(prepared_nodes)) return next(iter(prepared_nodes))
@ -1192,7 +1050,7 @@ class GraphExecutionState(BaseModel):
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)] parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
return next( return next(
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)), (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
@ -1261,19 +1119,19 @@ class GraphExecutionState(BaseModel):
def add_node(self, node: BaseInvocation) -> None: def add_node(self, node: BaseInvocation) -> None:
self.graph.add_node(node) self.graph.add_node(node)
def update_node(self, node_path: str, new_node: BaseInvocation) -> None: def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
if not self._is_node_updatable(node_path): if not self._is_node_updatable(node_id):
raise NodeAlreadyExecutedError( raise NodeAlreadyExecutedError(
f"Node {node_path} has already been prepared or executed and cannot be updated" f"Node {node_id} has already been prepared or executed and cannot be updated"
) )
self.graph.update_node(node_path, new_node) self.graph.update_node(node_id, new_node)
def delete_node(self, node_path: str) -> None: def delete_node(self, node_id: str) -> None:
if not self._is_node_updatable(node_path): if not self._is_node_updatable(node_id):
raise NodeAlreadyExecutedError( raise NodeAlreadyExecutedError(
f"Node {node_path} has already been prepared or executed and cannot be deleted" f"Node {node_id} has already been prepared or executed and cannot be deleted"
) )
self.graph.delete_node(node_path) self.graph.delete_node(node_id)
def add_edge(self, edge: Edge) -> None: def add_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id): if not self._is_node_updatable(edge.destination.node_id):

View File

@ -23,7 +23,7 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation from invokeai.app.services.shared.graph import Graph, GraphExecutionState
@pytest.fixture @pytest.fixture
@ -35,17 +35,6 @@ def simple_graph():
return g return g
@pytest.fixture
def graph_with_subgraph():
sub_g = Graph()
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
sub_g.add_node(TextToImageTestInvocation(id="2"))
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
g = Graph()
g.add_node(GraphInvocation(id="1", graph=sub_g))
return g
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types # This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations. # the test invocations.

View File

@ -8,8 +8,6 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import ( from invokeai.app.invocations.primitives import (
FloatCollectionInvocation, FloatCollectionInvocation,
FloatInvocation, FloatInvocation,
@ -17,13 +15,11 @@ from invokeai.app.invocations.primitives import (
StringInvocation, StringInvocation,
) )
from invokeai.app.invocations.upscale import ESRGANInvocation from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.services.shared.default_graphs import create_text_to_image
from invokeai.app.services.shared.graph import ( from invokeai.app.services.shared.graph import (
CollectInvocation, CollectInvocation,
Edge, Edge,
EdgeConnection, EdgeConnection,
Graph, Graph,
GraphInvocation,
InvalidEdgeError, InvalidEdgeError,
IterateInvocation, IterateInvocation,
NodeAlreadyInGraphError, NodeAlreadyInGraphError,
@ -425,19 +421,19 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
assert g.is_valid() is False assert g.is_valid() is False
def test_graph_invalid_if_subgraph_invalid(): # def test_graph_invalid_if_subgraph_invalid():
g = Graph() # g = Graph()
n1 = GraphInvocation(id="1") # n1 = GraphInvocation(id="1")
n1.graph = Graph() # n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi") # n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
n1.graph.nodes[n1_1.id] = n1_1 # n1.graph.nodes[n1_1.id] = n1_1
e1 = create_edge("1", "image", "2", "image") # e1 = create_edge("1", "image", "2", "image")
n1.graph.edges.append(e1) # n1.graph.edges.append(e1)
g.nodes[n1.id] = n1 # g.nodes[n1.id] = n1
assert g.is_valid() is False # assert g.is_valid() is False
def test_graph_invalid_if_has_cycle(): def test_graph_invalid_if_has_cycle():
@ -466,108 +462,108 @@ def test_graph_invalid_with_invalid_connection():
assert g.is_valid() is False assert g.is_valid() is False
# TODO: Subgraph operations # # TODO: Subgraph operations
def test_graph_gets_subgraph_node(): # def test_graph_gets_subgraph_node():
g = Graph() # g = Graph()
n1 = GraphInvocation(id="1") # n1 = GraphInvocation(id="1")
n1.graph = Graph() # n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") # n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1) # n1.graph.add_node(n1_1)
g.add_node(n1) # g.add_node(n1)
result = g.get_node("1.1") # result = g.get_node("1.1")
assert result is not None # assert result is not None
assert result.id == "1" # assert result.id == "1"
assert result == n1_1 # assert result == n1_1
def test_graph_expands_subgraph(): # def test_graph_expands_subgraph():
g = Graph() # g = Graph()
n1 = GraphInvocation(id="1") # n1 = GraphInvocation(id="1")
n1.graph = Graph() # n1.graph = Graph()
n1_1 = AddInvocation(id="1", a=1, b=2) # n1_1 = AddInvocation(id="1", a=1, b=2)
n1_2 = SubtractInvocation(id="2", b=3) # n1_2 = SubtractInvocation(id="2", b=3)
n1.graph.add_node(n1_1) # n1.graph.add_node(n1_1)
n1.graph.add_node(n1_2) # n1.graph.add_node(n1_2)
n1.graph.add_edge(create_edge("1", "value", "2", "a")) # n1.graph.add_edge(create_edge("1", "value", "2", "a"))
g.add_node(n1) # g.add_node(n1)
n2 = AddInvocation(id="2", b=5) # n2 = AddInvocation(id="2", b=5)
g.add_node(n2) # g.add_node(n2)
g.add_edge(create_edge("1.2", "value", "2", "a")) # g.add_edge(create_edge("1.2", "value", "2", "a"))
dg = g.nx_graph_flat() # dg = g.nx_graph_flat()
assert set(dg.nodes) == {"1.1", "1.2", "2"} # assert set(dg.nodes) == {"1.1", "1.2", "2"}
assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")} # assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
def test_graph_subgraph_t2i(): # def test_graph_subgraph_t2i():
g = Graph() # g = Graph()
n1 = GraphInvocation(id="1") # n1 = GraphInvocation(id="1")
# Get text to image default graph # # Get text to image default graph
lg = create_text_to_image() # lg = create_text_to_image()
n1.graph = lg.graph # n1.graph = lg.graph
g.add_node(n1) # g.add_node(n1)
n2 = IntegerInvocation(id="2", value=512) # n2 = IntegerInvocation(id="2", value=512)
n3 = IntegerInvocation(id="3", value=256) # n3 = IntegerInvocation(id="3", value=256)
g.add_node(n2) # g.add_node(n2)
g.add_node(n3) # g.add_node(n3)
g.add_edge(create_edge("2", "value", "1.width", "value")) # g.add_edge(create_edge("2", "value", "1.width", "value"))
g.add_edge(create_edge("3", "value", "1.height", "value")) # g.add_edge(create_edge("3", "value", "1.height", "value"))
n4 = ShowImageInvocation(id="4") # n4 = ShowImageInvocation(id="4")
g.add_node(n4) # g.add_node(n4)
g.add_edge(create_edge("1.8", "image", "4", "image")) # g.add_edge(create_edge("1.8", "image", "4", "image"))
# Validate # # Validate
dg = g.nx_graph_flat() # dg = g.nx_graph_flat()
assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"} # assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] # expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) # expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
print(expected_edges) # print(expected_edges)
print(list(dg.edges)) # print(list(dg.edges))
assert set(dg.edges) == set(expected_edges) # assert set(dg.edges) == set(expected_edges)
def test_graph_fails_to_get_missing_subgraph_node(): # def test_graph_fails_to_get_missing_subgraph_node():
g = Graph() # g = Graph()
n1 = GraphInvocation(id="1") # n1 = GraphInvocation(id="1")
n1.graph = Graph() # n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") # n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1) # n1.graph.add_node(n1_1)
g.add_node(n1) # g.add_node(n1)
with pytest.raises(NodeNotFoundError): # with pytest.raises(NodeNotFoundError):
_ = g.get_node("1.2") # _ = g.get_node("1.2")
def test_graph_fails_to_enumerate_non_subgraph_node(): # def test_graph_fails_to_enumerate_non_subgraph_node():
g = Graph() # g = Graph()
n1 = GraphInvocation(id="1") # n1 = GraphInvocation(id="1")
n1.graph = Graph() # n1.graph = Graph()
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") # n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1) # n1.graph.add_node(n1_1)
g.add_node(n1) # g.add_node(n1)
n2 = ESRGANInvocation(id="2") # n2 = ESRGANInvocation(id="2")
g.add_node(n2) # g.add_node(n2)
with pytest.raises(NodeNotFoundError): # with pytest.raises(NodeNotFoundError):
_ = g.get_node("2.1") # _ = g.get_node("2.1")
def test_graph_gets_networkx_graph(): def test_graph_gets_networkx_graph():

View File

@ -8,10 +8,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
NodeFieldValue, NodeFieldValue,
calc_session_count, calc_session_count,
create_session_nfv_tuples, create_session_nfv_tuples,
populate_graph,
prepare_values_to_insert, prepare_values_to_insert,
) )
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation from invokeai.app.services.shared.graph import Graph, GraphExecutionState
from tests.aa_nodes.test_nodes import PromptTestInvocation from tests.aa_nodes.test_nodes import PromptTestInvocation
@ -39,28 +38,28 @@ def batch_graph() -> Graph:
return g return g
def test_populate_graph_with_subgraph(): # def test_populate_graph_with_subgraph():
g1 = Graph() # g1 = Graph()
g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) # g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi")) # g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
n1 = PromptTestInvocation(id="1", prompt="Banana snake") # n1 = PromptTestInvocation(id="1", prompt="Banana snake")
subgraph = Graph() # subgraph = Graph()
subgraph.add_node(n1) # subgraph.add_node(n1)
g1.add_node(GraphInvocation(id="3", graph=subgraph)) # g1.add_node(GraphInvocation(id="3", graph=subgraph))
nfvs = [ # nfvs = [
NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"), # NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"), # NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"), # NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
] # ]
g2 = populate_graph(g1, nfvs) # g2 = populate_graph(g1, nfvs)
# do not mutate g1 # # do not mutate g1
assert g1 is not g2 # assert g1 is not g2
assert g2.get_node("1").prompt == "Strawberry sushi" # assert g2.get_node("1").prompt == "Strawberry sushi"
assert g2.get_node("2").prompt == "Strawberry sunday" # assert g2.get_node("2").prompt == "Strawberry sunday"
assert g2.get_node("3.1").prompt == "Strawberry snake" # assert g2.get_node("3.1").prompt == "Strawberry snake"
def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph): def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):