Make batch_indices in graph class more clear

This commit is contained in:
Brandon Rising 2023-07-24 17:43:49 -04:00
parent 4bad96d9d6
commit d090be60e8
2 changed files with 6 additions and 6 deletions

View File

@ -803,7 +803,7 @@ class GraphExecutionState(BaseModel):
# TODO: Store a reference to the graph instead of the actual graph? # TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed") graph: Graph = Field(description="The graph being executed")
batch_index: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list) batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
# The graph of materialized nodes # The graph of materialized nodes
execution_graph: Graph = Field( execution_graph: Graph = Field(
@ -877,11 +877,11 @@ class GraphExecutionState(BaseModel):
if next_node is not None: if next_node is not None:
self._prepare_inputs(next_node) self._prepare_inputs(next_node)
if sum(self.batch_index) != 0: if sum(self.batch_indices) != 0:
for index in self.batch_index: for index in self.batch_indices:
if self.batch_index[index] > 0: if self.batch_indices[index] > 0:
self.executed.clear() self.executed.clear()
self.batch_index[index] -= 1 self.batch_indices[index] -= 1
return next(self) return next(self)
# If next is still none, there's no next node, return None # If next is still none, there's no next node, return None

View File

@ -28,7 +28,7 @@ class Invoker:
(index, batch) = next(((i,b) for i,b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id), (None, None)) (index, batch) = next(((i,b) for i,b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id), (None, None))
if batch: if batch:
# assert(isinstance(invocation.type, batch.node_type), f"Type mismatch between nodes and batch config on {invocation.id}") # assert(isinstance(invocation.type, batch.node_type), f"Type mismatch between nodes and batch config on {invocation.id}")
batch_index = graph_execution_state.batch_index[index] batch_index = graph_execution_state.batch_indices[index]
datum = batch.data[batch_index] datum = batch.data[batch_index]
for param in datum.keys(): for param in datum.keys():
invocation[param] = datum[param] invocation[param] = datum[param]