Mirror of https://github.com/invoke-ai/InvokeAI
Make batch_indices in graph class more clear
commit d090be60e8
parent 4bad96d9d6
@@ -803,7 +803,7 @@ class GraphExecutionState(BaseModel):
     # TODO: Store a reference to the graph instead of the actual graph?
     graph: Graph = Field(description="The graph being executed")
 
-    batch_index: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
+    batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
 
     # The graph of materialized nodes
     execution_graph: Graph = Field(
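The rename above only changes the field's name; its pydantic declaration is otherwise untouched. A minimal, runnable sketch of that declaration follows; note that Graph is stubbed here purely for illustration and is not the real InvokeAI model.

from pydantic import BaseModel, Field


class Graph(BaseModel):  # hypothetical stand-in for invokeai's Graph model
    pass


class GraphExecutionState(BaseModel):
    graph: Graph = Field(description="The graph being executed")
    # Renamed from `batch_index`: one counter per configured batch.
    batch_indices: list[int] = Field(
        description="Tracker for which batch is currently being processed",
        default_factory=list,
    )


state = GraphExecutionState(graph=Graph())
print(state.batch_indices)  # [] until batches are configured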
@@ -877,11 +877,11 @@ class GraphExecutionState(BaseModel):
         if next_node is not None:
             self._prepare_inputs(next_node)
 
-        if sum(self.batch_index) != 0:
-            for index in self.batch_index:
-                if self.batch_index[index] > 0:
+        if sum(self.batch_indices) != 0:
+            for index in self.batch_indices:
+                if self.batch_indices[index] > 0:
                     self.executed.clear()
-                    self.batch_index[index] -= 1
+                    self.batch_indices[index] -= 1
                     return next(self)
 
         # If next is still none, there's no next node, return None
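For context, the block above treats each entry of batch_indices as a count of remaining re-runs for one batch: while any counter is non-zero, the set of executed nodes is cleared, the counter is decremented, and graph execution restarts. A standalone sketch of that countdown pattern, under that reading (an illustration only, not the project's actual method):

def advance_batches(batch_indices: list[int], executed: set[str]) -> bool:
    """Decrement the first non-zero counter and clear execution state.

    Returns True if another batch run is pending, False when all counters hit zero.
    """
    for i, remaining in enumerate(batch_indices):
        if remaining > 0:
            executed.clear()       # forget finished nodes so they re-run
            batch_indices[i] -= 1  # one fewer iteration left for this batch
            return True
    return False


# Usage: two batches with 2 and 1 pending re-runs respectively.
counters = [2, 1]
done_nodes = {"node-a", "node-b"}
while advance_batches(counters, done_nodes):
    pass
print(counters)  # [0, 0]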
|
@@ -28,7 +28,7 @@ class Invoker:
         (index, batch) = next(((i,b) for i,b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id), (None, None))
         if batch:
             # assert(isinstance(invocation.type, batch.node_type), f"Type mismatch between nodes and batch config on {invocation.id}")
-            batch_index = graph_execution_state.batch_index[index]
+            batch_index = graph_execution_state.batch_indices[index]
             datum = batch.data[batch_index]
             for param in datum.keys():
                 invocation[param] = datum[param]
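The hunk above reads one counter out of batch_indices, uses it to pick a datum from batch.data, and then copies that datum's parameters onto the invocation before it runs. A small illustrative sketch of that copy step, with a hypothetical NoiseInvocation standing in for a real InvokeAI invocation class:

from dataclasses import dataclass


@dataclass
class NoiseInvocation:  # hypothetical stand-in; real invocations are pydantic models
    id: str
    seed: int = 0
    width: int = 512


def apply_batch_datum(invocation: object, datum: dict[str, object]) -> None:
    """Overwrite the invocation's fields with the values from one batch datum."""
    for param, value in datum.items():
        setattr(invocation, param, value)


node = NoiseInvocation(id="noise-1")
apply_batch_datum(node, {"seed": 1234, "width": 768})
print(node)  # NoiseInvocation(id='noise-1', seed=1234, width=768)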
|