mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
prioritize iterate in _get_next_node
This commit is contained in:
parent
3cdfc6ab16
commit
83b3828b55
@ -1043,24 +1043,33 @@ class GraphExecutionState(BaseModel):
|
||||
"""Gets the deepest node that is ready to be executed"""
|
||||
g = self.execution_graph.nx_graph()
|
||||
|
||||
# Depth-first search with pre-order traversal is a depth-first topological sort
|
||||
sorted_nodes = nx.dfs_preorder_nodes(g)
|
||||
# Perform a topological sort using depth-first search
|
||||
topo_order = list(nx.dfs_postorder_nodes(g))
|
||||
|
||||
next_node = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
if n not in self.executed # the node must not already be executed...
|
||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Get all IterateInvocation nodes
|
||||
iterate_nodes = [n for n in topo_order if isinstance(self.execution_graph.nodes[n], IterateInvocation)]
|
||||
|
||||
if next_node is None:
|
||||
# Sort the IterateInvocation nodes based on their index attribute
|
||||
iterate_nodes.sort(key=lambda x: self.execution_graph.nodes[x].index)
|
||||
|
||||
# Prioritize IterateInvocation nodes and their children
|
||||
for iterate_node in iterate_nodes:
|
||||
if iterate_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(iterate_node))):
|
||||
return self.execution_graph.nodes[iterate_node]
|
||||
|
||||
# Check the children of the IterateInvocation node
|
||||
for child_node in nx.dfs_postorder_nodes(g, iterate_node):
|
||||
if child_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(child_node))):
|
||||
return self.execution_graph.nodes[child_node]
|
||||
|
||||
# If no IterateInvocation node or its children are ready, return the first ready node in the topological order
|
||||
for node in topo_order:
|
||||
if node not in self.executed and all((e[0] in self.executed for e in g.in_edges(node))):
|
||||
return self.execution_graph.nodes[node]
|
||||
|
||||
# If no node is found, return None
|
||||
return None
|
||||
|
||||
return self.execution_graph.nodes[next_node]
|
||||
|
||||
def _prepare_inputs(self, node: BaseInvocation):
|
||||
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
|
||||
# Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
|
||||
|
Loading…
Reference in New Issue
Block a user