prioritize iterate in _get_next_node

This commit is contained in:
Joe Kubler 2024-03-22 12:13:33 -04:00 committed by psychedelicious
parent 3cdfc6ab16
commit 83b3828b55

View File

@ -1043,23 +1043,32 @@ class GraphExecutionState(BaseModel):
"""Gets the deepest node that is ready to be executed"""
g = self.execution_graph.nx_graph()
# Depth-first search with pre-order traversal is a depth-first topological sort
sorted_nodes = nx.dfs_preorder_nodes(g)
# Perform a topological sort using depth-first search
topo_order = list(nx.dfs_postorder_nodes(g))
next_node = next(
(
n
for n in sorted_nodes
if n not in self.executed # the node must not already be executed...
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
),
None,
)
# Get all IterateInvocation nodes
iterate_nodes = [n for n in topo_order if isinstance(self.execution_graph.nodes[n], IterateInvocation)]
if next_node is None:
return None
# Sort the IterateInvocation nodes based on their index attribute
iterate_nodes.sort(key=lambda x: self.execution_graph.nodes[x].index)
return self.execution_graph.nodes[next_node]
# Prioritize IterateInvocation nodes and their children
for iterate_node in iterate_nodes:
if iterate_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(iterate_node))):
return self.execution_graph.nodes[iterate_node]
# Check the children of the IterateInvocation node
for child_node in nx.dfs_postorder_nodes(g, iterate_node):
if child_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(child_node))):
return self.execution_graph.nodes[child_node]
# If no IterateInvocation node or its children are ready, return the first ready node in the topological order
for node in topo_order:
if node not in self.executed and all((e[0] in self.executed for e in g.in_edges(node))):
return self.execution_graph.nodes[node]
# If no node is found, return None
return None
def _prepare_inputs(self, node: BaseInvocation):
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]