From f080c56771be1844789abb9875ecc14c222723c6 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 25 Jul 2023 16:50:07 -0400
Subject: [PATCH] Testing out generating a new session for each batch_index

---
 invokeai/app/api/routers/sessions.py |  7 ++++++-
 invokeai/app/invocations/latent.py   |  5 +++--
 invokeai/app/services/graph.py       | 15 ++++++---------
 invokeai/app/services/invoker.py     |  9 ++-------
 invokeai/app/services/processor.py   | 12 +++++++++++-
 5 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py
index da842a3968..0f8f1fa5f8 100644
--- a/invokeai/app/api/routers/sessions.py
+++ b/invokeai/app/api/routers/sessions.py
@@ -35,7 +35,12 @@ async def create_session(
     )
 ) -> GraphExecutionState:
     """Creates a new session, optionally initializing it with an invocation graph"""
-    session = ApiDependencies.invoker.create_execution_state(graph)
+
+    batch_indices = list()
+    if graph.batches:
+        for batch in graph.batches:
+            batch_indices.append(len(batch.data)-1)
+    session = ApiDependencies.invoker.create_execution_state(graph, batch_indices)
     return session
 
 
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index f9da560a08..6082057bd3 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -376,8 +376,9 @@ class TextToLatentsInvocation(BaseInvocation):
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.to("cpu")
-        import uuid
-        name = f'{context.graph_execution_state_id}__{self.id}_{uuid.uuid4()}'
+        torch.cuda.empty_cache()
+
+        name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.save(name, result_latents)
         return build_latents_output(latents_name=name, latents=result_latents)
 
diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py
index 97b081133c..d3a4f134fe 100644
--- a/invokeai/app/services/graph.py
+++ b/invokeai/app/services/graph.py
@@ -818,6 +818,8 @@ class GraphExecutionState(BaseModel):
         default_factory=list,
     )
 
+    batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
+
     # The results of executed nodes
     results: dict[
         str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
@@ -855,14 +857,14 @@ class GraphExecutionState(BaseModel):
             ]
         }
 
-    def next(self, batch_indices: list = list()) -> Optional[BaseInvocation]:
+    def next(self) -> Optional[BaseInvocation]:
         """Gets the next node ready to execute."""
 
         # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
         #       possibly with a timeout?
-
+        self._apply_batch_config()
         # If there are no prepared nodes, prepare some nodes
         next_node = self._get_next_node()
         if next_node is None:
             prepared_id = self._prepare()
 
@@ -871,15 +873,10 @@ class GraphExecutionState(BaseModel):
             while prepared_id is not None:
                 prepared_id = self._prepare()
                 next_node = self._get_next_node()
 
+        # Get values from edges
         if next_node is not None:
             self._prepare_inputs(next_node)
 
-        if next_node is None and sum(self.batch_indices) != 0:
-            for index in range(len(self.batch_indices)):
-                if self.batch_indices[index] > 0:
-                    self.batch_indices[index] -= 1
-                    self.executed.clear()
-                    return self.next()
-
         # If next is still none, there's no next node, return None
         return next_node
@@ -909,7 +906,7 @@ class GraphExecutionState(BaseModel):
     def is_complete(self) -> bool:
         """Returns true if the graph is complete"""
         node_ids = set(self.graph.nx_graph_flat().nodes)
-        return sum(self.batch_indices) == 0 and (self.has_error() or all((k in self.executed for k in node_ids)))
+        return self.has_error() or all((k in self.executed for k in node_ids))
 
     def has_error(self) -> bool:
         """Returns true if the graph has any errors"""
diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py
index 9c9548b837..8fa27793b7 100644
--- a/invokeai/app/services/invoker.py
+++ b/invokeai/app/services/invoker.py
@@ -41,14 +41,9 @@ class Invoker:
 
         return invocation.id
 
-    def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
+    def create_execution_state(self, graph: Optional[Graph] = None, batch_indices: list[int] = list()) -> GraphExecutionState:
         """Creates a new execution state for the given graph"""
-        new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
-        if graph.batches:
-            batch_indices = list()
-            for batch in graph.batches:
-                batch_indices.append(len(batch.data)-1)
-            new_state.batch_indices = batch_indices
+        new_state = GraphExecutionState(graph=Graph() if graph is None else graph, batch_indices=batch_indices)
         self.services.graph_execution_manager.set(new_state)
         return new_state
 
diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py
index fa5087ed23..f7ef326355 100644
--- a/invokeai/app/services/processor.py
+++ b/invokeai/app/services/processor.py
@@ -6,6 +6,7 @@ from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
 from ..models.exceptions import CanceledException
+from .graph import GraphExecutionState
 import invokeai.backend.util.logging as logger
 
 class DefaultInvocationProcessor(InvocationProcessorABC):
@@ -73,7 +74,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         error_type=e.__class__.__name__,
                         error=traceback.format_exc(),
                     )
-                        continue
+                    continue
 
                 # get the source node id to provide to clients (the prepared node id is not as useful)
                 source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
@@ -165,6 +166,15 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                             error_type=e.__class__.__name__,
                             error=traceback.format_exc()
                         )
+                elif queue_item.invoke_all and sum(graph_execution_state.batch_indices) > 0:
+                    batch_indicies = graph_execution_state.batch_indices.copy()
+                    for index in range(len(batch_indicies)):
+                        if batch_indicies[index] > 0:
+                            batch_indicies[index] -= 1
+                            break
+                    new_ges = GraphExecutionState(graph=graph_execution_state.graph, batch_indices=batch_indicies)
+                    self.__invoker.services.graph_execution_manager.set(new_ges)
+                    self.__invoker.invoke(new_ges, invoke_all=True)
                 elif is_complete:
                     self.__invoker.services.events.emit_graph_execution_complete(
                         graph_execution_state.id
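
Note (illustration only, not part of the diff above): the batch scheduling added to
processor.py can be traced with a small standalone sketch. It reuses the decrement-and-break
loop from the new elif branch; the function name "countdown" and the starting state [2, 1]
are illustrative, not names from the codebase.

    # Standalone sketch of the batch_indices countdown performed in processor.py.
    # Each loop iteration corresponds to one additional GraphExecutionState being
    # created and queued with invoke_all=True.
    def countdown(batch_indices: list[int]) -> list[list[int]]:
        """Return the sequence of batch_indices values, one entry per session."""
        states = [batch_indices.copy()]
        while sum(batch_indices) > 0:
            batch_indices = batch_indices.copy()
            # Same logic as the new elif branch: decrement the first nonzero index.
            for index in range(len(batch_indices)):
                if batch_indices[index] > 0:
                    batch_indices[index] -= 1
                    break
            states.append(batch_indices)
        return states

    if __name__ == "__main__":
        # A graph with two batches of sizes 3 and 2 starts at [2, 1]
        # (len(batch.data)-1 per batch, as computed in sessions.py above).
        print(countdown([2, 1]))
        # -> [[2, 1], [1, 1], [0, 1], [0, 0]]

Under this sketch a run touches sum(batch_indices) + 1 sessions in total, one per decrement
step, with each completed session spawning the next via graph_execution_manager.set() and
invoke().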