Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Trying different places of applying batches
commit d2f968b902
parent e81601acf3
@@ -376,9 +376,8 @@ class TextToLatentsInvocation(BaseInvocation):
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.to("cpu")
-        torch.cuda.empty_cache()
-
-        name = f'{context.graph_execution_state_id}__{self.id}'
+        import uuid
+        name = f'{context.graph_execution_state_id}__{self.id}_{uuid.uuid4()}'
         context.services.latents.save(name, result_latents)
         return build_latents_output(latents_name=name, latents=result_latents)
 
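The latents name now carries a uuid4 suffix, presumably so that repeated executions of the same node (as happens when a batch re-runs the graph) save under distinct keys instead of overwriting each other. A minimal sketch of the naming scheme outside of InvokeAI, with made-up ids standing in for context.graph_execution_state_id and self.id:

import uuid

graph_execution_state_id = "a1b2c3"        # hypothetical id
node_id = "noise_to_latents"               # hypothetical id

# Old scheme: one name per (execution state, node) pair -- a second
# batch pass would overwrite the first result.
old_name = f"{graph_execution_state_id}__{node_id}"

# New scheme: a uuid4 suffix makes every save unique.
new_name = f"{graph_execution_state_id}__{node_id}_{uuid.uuid4()}"

print(old_name)
print(new_name)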
@@ -803,8 +803,6 @@ class GraphExecutionState(BaseModel):
     # TODO: Store a reference to the graph instead of the actual graph?
     graph: Graph = Field(description="The graph being executed")
 
-    batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
-
     # The graph of materialized nodes
     execution_graph: Graph = Field(
         description="The expanded graph of activated and executed nodes",
@@ -857,13 +855,14 @@ class GraphExecutionState(BaseModel):
             ]
         }
 
-    def next(self) -> Optional[BaseInvocation]:
+    def next(self, batch_indices: list = list()) -> Optional[BaseInvocation]:
         """Gets the next node ready to execute."""
 
         # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
         # possibly with a timeout?
 
         # If there are no prepared nodes, prepare some nodes
+        self._apply_batch_config()
         next_node = self._get_next_node()
         if next_node is None:
             prepared_id = self._prepare()
@@ -872,17 +871,15 @@ class GraphExecutionState(BaseModel):
         while prepared_id is not None:
             prepared_id = self._prepare()
             next_node = self._get_next_node()
 
         # Get values from edges
         if next_node is not None:
             self._prepare_inputs(next_node)
-
-        if sum(self.batch_indices) != 0:
-            for index in self.batch_indices:
-                if self.batch_indices[index] > 0:
-                    self.executed.clear()
-                    self.batch_indices[index] -= 1
-                    return self.next(self)
-
+        if next_node is None and sum(self.batch_indices) != 0:
+            for index in range(len(self.batch_indices)):
+                if self.batch_indices[index] > 0:
+                    self.batch_indices[index] -= 1
+                    self.executed.clear()
+                    return self.next()
 
         # If next is still none, there's no next node, return None
         return next_node
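The batch bookkeeping in next() treats batch_indices as a countdown: once no further node is ready and some index is still above zero, the first such index is decremented, the executed set is cleared so the whole graph becomes runnable again, and next() recurses to start another pass. A standalone sketch of that control flow, using a plain list and set instead of the GraphExecutionState fields:

from typing import Optional

def next_batch_pass(batch_indices: list[int], executed: set[str]) -> Optional[int]:
    """Mimic the re-run step: find a batch index that still has work,
    decrement it, and reset the executed-node tracker."""
    for index in range(len(batch_indices)):
        if batch_indices[index] > 0:
            batch_indices[index] -= 1
            executed.clear()          # graph will be walked again from the start
            return index              # signal that another pass should run
    return None                       # all batches exhausted

# Example: two batched nodes with 2 and 1 remaining passes.
indices = [2, 1]
done = {"node_a", "node_b"}
passes = 0
while next_batch_pass(indices, done) is not None:
    passes += 1
print(passes, indices)  # 3 extra passes; indices end at [0, 0]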
@@ -912,7 +909,7 @@ class GraphExecutionState(BaseModel):
     def is_complete(self) -> bool:
         """Returns true if the graph is complete"""
         node_ids = set(self.graph.nx_graph_flat().nodes)
-        return self.has_error() or all((k in self.executed for k in node_ids))
+        return sum(self.batch_indices) == 0 and (self.has_error() or all((k in self.executed for k in node_ids)))
 
     def has_error(self) -> bool:
         """Returns true if the graph has any errors"""
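With the extra condition, the state only reports completion once every batch index has counted down to zero, even if all currently materialized nodes have executed. A tiny illustration of the new boolean with stand-in values:

def is_complete(batch_indices: list[int], has_error: bool,
                node_ids: set[str], executed: set[str]) -> bool:
    # Mirrors the new return expression: batches must be exhausted first.
    return sum(batch_indices) == 0 and (has_error or all(k in executed for k in node_ids))

nodes = {"a", "b"}
print(is_complete([1, 0], False, nodes, {"a", "b"}))  # False: one batch pass still pending
print(is_complete([0, 0], False, nodes, {"a", "b"}))  # True: batches done, all nodes executed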
@@ -1020,6 +1017,20 @@ class GraphExecutionState(BaseModel):
             ]
         return iterators
 
+    def _apply_batch_config(self):
+        g = self.graph.nx_graph_flat()
+        sorted_nodes = nx.topological_sort(g)
+        batchable_nodes = [n for n in sorted_nodes if n not in self.executed]
+        for npath in batchable_nodes:
+            node = self.graph.get_node(npath)
+            (index, batch) = next(((i,b) for i,b in enumerate(self.graph.batches) if b.node_id in node.id), (None, None))
+            if batch:
+                batch_index = self.batch_indices[index]
+                datum = batch.data[batch_index]
+                datum.id = node.id
+                self.graph.update_node(npath, datum)
+
+
     def _prepare(self) -> Optional[str]:
         # Get flattened source graph
         g = self.graph.nx_graph_flat()
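The new _apply_batch_config walks the flattened graph in topological order, and for every not-yet-executed node that a batch entry points at it swaps the node for the datum selected by the current batch index. graph.get_node, graph.update_node and the graph.batches entries are InvokeAI internals; the sketch below only reproduces the lookup pattern with made-up stand-in types:

from dataclasses import dataclass, field

@dataclass
class FakeNode:
    id: str
    prompt: str = ""

@dataclass
class FakeBatch:
    node_id: str                                # which node this batch drives
    data: list = field(default_factory=list)    # one replacement node per batch step

def apply_batch_config(nodes: dict[str, FakeNode], batches: list[FakeBatch],
                       batch_indices: list[int], executed: set[str]) -> None:
    """Stand-in for GraphExecutionState._apply_batch_config: rewrite batched
    nodes with the datum for the current batch index."""
    for npath, node in nodes.items():
        if npath in executed:
            continue
        # Same pattern as the diff: find the first batch whose node_id matches.
        index, batch = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id),
                            (None, None))
        if batch:
            datum = batch.data[batch_indices[index]]
            datum.id = node.id        # keep the original node id
            nodes[npath] = datum      # analogue of graph.update_node(npath, datum)

nodes = {"p1": FakeNode("p1", "a cat")}
batches = [FakeBatch("p1", [FakeNode("", "a dog"), FakeNode("", "a fox")])]
apply_batch_config(nodes, batches, batch_indices=[1], executed=set())
print(nodes["p1"].prompt)  # "a fox" -- the datum at index 1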
@@ -1051,7 +1062,6 @@ class GraphExecutionState(BaseModel):
             ),
             None,
         )
 
         if next_node_id == None:
             return None
-
@@ -25,14 +25,6 @@ class Invoker:
         invocation = graph_execution_state.next()
         if not invocation:
             return None
-        (index, batch) = next(((i,b) for i,b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id), (None, None))
-        if batch:
-            # assert(isinstance(invocation.type, batch.node_type), f"Type mismatch between nodes and batch config on {invocation.id}")
-            batch_index = graph_execution_state.batch_indices[index]
-            datum = batch.data[batch_index]
-            for param in datum.keys():
-                invocation[param] = datum[param]
-            # TODO graph.update_node
 
         # Save the execution state
         self.services.graph_execution_manager.set(graph_execution_state)
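This is the "different place" of the commit title: the parameter overwriting that used to happen here, on the already-prepared invocation inside Invoker.invoke, is deleted in favour of rewriting the source graph nodes in GraphExecutionState._apply_batch_config before preparation. For contrast, a rough sketch of what the removed approach did, with a hypothetical invocation object rather than the InvokeAI classes:

# The deleted Invoker code set batch parameters field by field on the prepared
# invocation (the original wrote invocation[param] = datum[param]); roughly
# equivalent to an attribute overwrite like this:
class FakeInvocation:
    def __init__(self, id: str, prompt: str = "", steps: int = 20):
        self.id = id
        self.prompt = prompt
        self.steps = steps

invocation = FakeInvocation("p1")
datum = {"prompt": "a dog", "steps": 30}      # one entry of batch.data

for param in datum.keys():
    setattr(invocation, param, datum[param])  # overwrite the prepared node's fields

print(invocation.prompt, invocation.steps)    # a dog 30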
@@ -74,7 +74,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         error=traceback.format_exc(),
                     )
                     continue
 
                 # get the source node id to provide to clients (the prepared node id is not as useful)
                 source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
-