diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index e3941d9ca3..cc2ea5cedb 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -1043,23 +1043,32 @@ class GraphExecutionState(BaseModel): """Gets the deepest node that is ready to be executed""" g = self.execution_graph.nx_graph() - # Depth-first search with pre-order traversal is a depth-first topological sort - sorted_nodes = nx.dfs_preorder_nodes(g) + # Perform a topological sort using depth-first search + topo_order = list(nx.dfs_postorder_nodes(g)) - next_node = next( - ( - n - for n in sorted_nodes - if n not in self.executed # the node must not already be executed... - and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed - ), - None, - ) + # Get all IterateInvocation nodes + iterate_nodes = [n for n in topo_order if isinstance(self.execution_graph.nodes[n], IterateInvocation)] - if next_node is None: - return None + # Sort the IterateInvocation nodes based on their index attribute + iterate_nodes.sort(key=lambda x: self.execution_graph.nodes[x].index) - return self.execution_graph.nodes[next_node] + # Prioritize IterateInvocation nodes and their children + for iterate_node in iterate_nodes: + if iterate_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(iterate_node))): + return self.execution_graph.nodes[iterate_node] + + # Check the children of the IterateInvocation node + for child_node in nx.dfs_postorder_nodes(g, iterate_node): + if child_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(child_node))): + return self.execution_graph.nodes[child_node] + + # If no IterateInvocation node or its children are ready, return the first ready node in the topological order + for node in topo_order: + if node not in self.executed and all((e[0] in self.executed for e in g.in_edges(node))): + return self.execution_graph.nodes[node] + + # If no node is found, return None + return None def _prepare_inputs(self, node: BaseInvocation): input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]