diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 6082057bd3..f9da560a08 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -376,9 +376,8 @@ class TextToLatentsInvocation(BaseInvocation):
 
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
             result_latents = result_latents.to("cpu")
-            torch.cuda.empty_cache()
-
-            name = f'{context.graph_execution_state_id}__{self.id}'
+            import uuid
+            name = f'{context.graph_execution_state_id}__{self.id}_{uuid.uuid4()}'
             context.services.latents.save(name, result_latents)
             return build_latents_output(latents_name=name, latents=result_latents)
 
diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py
index 3e44415d3d..97b081133c 100644
--- a/invokeai/app/services/graph.py
+++ b/invokeai/app/services/graph.py
@@ -803,8 +803,6 @@ class GraphExecutionState(BaseModel):
     # TODO: Store a reference to the graph instead of the actual graph?
     graph: Graph = Field(description="The graph being executed")
 
-    batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
-
     # The graph of materialized nodes
     execution_graph: Graph = Field(
         description="The expanded graph of activated and executed nodes",
@@ -857,13 +855,14 @@ class GraphExecutionState(BaseModel):
             ]
         }
 
-    def next(self) -> Optional[BaseInvocation]:
+    def next(self, batch_indices: list = list()) -> Optional[BaseInvocation]:
         """Gets the next node ready to execute."""
 
         # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
         #       possibly with a timeout?
 
         # If there are no prepared nodes, prepare some nodes
+        self._apply_batch_config()
         next_node = self._get_next_node()
         if next_node is None:
             prepared_id = self._prepare()
@@ -872,17 +871,15 @@ class GraphExecutionState(BaseModel):
             while prepared_id is not None:
                 prepared_id = self._prepare()
                 next_node = self._get_next_node()
-
         # Get values from edges
         if next_node is not None:
             self._prepare_inputs(next_node)
-
-        if sum(self.batch_indices) != 0:
-            for index in self.batch_indices:
+        if next_node is None and sum(self.batch_indices) != 0:
+            for index in range(len(self.batch_indices)):
                 if self.batch_indices[index] > 0:
-                    self.executed.clear()
                     self.batch_indices[index] -= 1
-                    return self.next(self)
+                    self.executed.clear()
+                    return self.next()
 
         # If next is still none, there's no next node, return None
         return next_node
@@ -912,7 +909,7 @@ class GraphExecutionState(BaseModel):
     def is_complete(self) -> bool:
         """Returns true if the graph is complete"""
         node_ids = set(self.graph.nx_graph_flat().nodes)
-        return self.has_error() or all((k in self.executed for k in node_ids))
+        return sum(self.batch_indices) == 0 and (self.has_error() or all((k in self.executed for k in node_ids)))
 
     def has_error(self) -> bool:
         """Returns true if the graph has any errors"""
@@ -1020,6 +1017,20 @@ class GraphExecutionState(BaseModel):
         ]
         return iterators
 
+    def _apply_batch_config(self):
+        g = self.graph.nx_graph_flat()
+        sorted_nodes = nx.topological_sort(g)
+        batchable_nodes = [n for n in sorted_nodes if n not in self.executed]
+        for npath in batchable_nodes:
+            node = self.graph.get_node(npath)
+            (index, batch) = next(((i,b) for i,b in enumerate(self.graph.batches) if b.node_id in node.id), (None, None))
+            if batch:
+                batch_index = self.batch_indices[index]
+                datum = batch.data[batch_index]
+                datum.id = node.id
+                self.graph.update_node(npath, datum)
+
+
     def _prepare(self) -> Optional[str]:
         # Get flattened source graph
         g = self.graph.nx_graph_flat()
@@ -1051,7 +1062,6 @@ class GraphExecutionState(BaseModel):
                 ),
                 None,
             )
-
             if next_node_id == None:
                 return None
 
diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py
index 4ca98f31bd..9c9548b837 100644
--- a/invokeai/app/services/invoker.py
+++ b/invokeai/app/services/invoker.py
@@ -25,14 +25,6 @@ class Invoker:
         invocation = graph_execution_state.next()
         if not invocation:
             return None
-        (index, batch) = next(((i,b) for i,b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id), (None, None))
-        if batch:
-            # assert(isinstance(invocation.type, batch.node_type), f"Type mismatch between nodes and batch config on {invocation.id}")
-            batch_index = graph_execution_state.batch_indices[index]
-            datum = batch.data[batch_index]
-            for param in datum.keys():
-                invocation[param] = datum[param]
-            # TODO graph.update_node
 
         # Save the execution state
         self.services.graph_execution_manager.set(graph_execution_state)
diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py
index 5995e4ffc3..fa5087ed23 100644
--- a/invokeai/app/services/processor.py
+++ b/invokeai/app/services/processor.py
@@ -73,8 +73,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         error_type=e.__class__.__name__,
                         error=traceback.format_exc(),
                     )
-                    continue
-
+                    continue
                 # get the source node id to provide to clients (the prepared node id is not as useful)
                 source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
 
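Note on the batch re-execution flow in GraphExecutionState.next() above (illustrative, not part of the diff): when the graph has no next node and any entry of batch_indices is still positive, one index is decremented, executed is cleared, and next() recurses so the whole graph runs again with the datum that _apply_batch_config() writes into the un-executed nodes. A minimal standalone sketch of that countdown follows; BatchCountdown, _run_graph_once, and the node names are hypothetical stand-ins, not InvokeAI APIs.

from typing import Optional

class BatchCountdown:
    """Illustrative stand-in for the batch countdown in GraphExecutionState.next()."""

    def __init__(self, batch_indices: list[int]):
        self.batch_indices = batch_indices
        self.executed: set[str] = set()
        self.runs = 0

    def _run_graph_once(self) -> None:
        # Stand-in for preparing and executing every node in the graph.
        self.executed.update({"noise", "t2l", "l2i"})
        self.runs += 1

    def next(self) -> Optional[str]:
        self._run_graph_once()
        # Mirrors the hunk above: once the graph is exhausted, step one batch
        # index down, clear `executed`, and recurse so the graph runs again.
        if sum(self.batch_indices) != 0:
            for index in range(len(self.batch_indices)):
                if self.batch_indices[index] > 0:
                    self.batch_indices[index] -= 1
                    self.executed.clear()
                    return self.next()
        return None

countdown = BatchCountdown([2, 1])
countdown.next()
print(countdown.runs)  # 4: the initial pass plus one pass per decrement

With batch_indices = [2, 1] the graph executes four times in total, which is also why is_complete() in the diff only reports completion once sum(self.batch_indices) == 0.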