Testing out generating a new session for each batch_index

Brandon Rising 2023-07-25 16:50:07 -04:00
parent d2f968b902
commit f080c56771
5 changed files with 28 additions and 20 deletions

View File

@@ -35,7 +35,12 @@ async def create_session(
     )
 ) -> GraphExecutionState:
     """Creates a new session, optionally initializing it with an invocation graph"""
-    session = ApiDependencies.invoker.create_execution_state(graph)
+    batch_indices = list()
+    if graph.batches:
+        for batch in graph.batches:
+            batch_indices.append(len(batch.data)-1)
+    session = ApiDependencies.invoker.create_execution_state(graph, batch_indices)
     return session
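The endpoint now seeds one counter per batch, starting at the highest index into that batch's data list. A minimal sketch of that derivation, separate from the diff, using a hypothetical Batch stand-in instead of the real graph models:

    # Hypothetical stand-in for the batch model consumed by the endpoint above.
    from dataclasses import dataclass

    @dataclass
    class Batch:
        data: list  # one entry per batch item

    batches = [Batch(data=["a", "b", "c"]), Batch(data=[1, 2])]

    # One countdown counter per batch, starting at the last valid index of its data list.
    batch_indices = [len(batch.data) - 1 for batch in batches]
    print(batch_indices)  # [2, 1]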

View File

@@ -376,8 +376,9 @@ class TextToLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.to("cpu")
-        import uuid
-        name = f'{context.graph_execution_state_id}__{self.id}_{uuid.uuid4()}'
+        torch.cuda.empty_cache()
+
+        name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.save(name, result_latents)
         return build_latents_output(latents_name=name, latents=result_latents)
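With one session per batch index, each run gets its own graph_execution_state_id, so the uuid suffix is no longer needed to keep latents names distinct. A small illustration, separate from the diff, with made-up ids:

    # Hypothetical ids: two sessions spawned for two batch indices of the same graph.
    session_ids = ["ges-aaa", "ges-bbb"]
    node_id = "noise_1"

    # One latents name per session; the pair (session id, node id) is already unique.
    names = [f"{session_id}__{node_id}" for session_id in session_ids]
    print(names)  # ['ges-aaa__noise_1', 'ges-bbb__noise_1']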

View File

@@ -818,6 +818,8 @@ class GraphExecutionState(BaseModel):
         default_factory=list,
     )

+    batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
+
     # The results of executed nodes
     results: dict[
         str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
@@ -855,14 +857,14 @@
         ]
     }

-    def next(self, batch_indices: list = list()) -> Optional[BaseInvocation]:
+    def next(self) -> Optional[BaseInvocation]:
         """Gets the next node ready to execute."""

         # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
         # possibly with a timeout?

-        # If there are no prepared nodes, prepare some nodes
         self._apply_batch_config()
+        # If there are no prepared nodes, prepare some nodes
         next_node = self._get_next_node()
         if next_node is None:
             prepared_id = self._prepare()
@@ -871,15 +873,10 @@
             while prepared_id is not None:
                 prepared_id = self._prepare()
                 next_node = self._get_next_node()

         # Get values from edges
         if next_node is not None:
             self._prepare_inputs(next_node)

-        if next_node is None and sum(self.batch_indices) != 0:
-            for index in range(len(self.batch_indices)):
-                if self.batch_indices[index] > 0:
-                    self.batch_indices[index] -= 1
-                    self.executed.clear()
-                    return self.next()
-
         # If next is still none, there's no next node, return None
         return next_node
@@ -909,7 +906,7 @@
     def is_complete(self) -> bool:
         """Returns true if the graph is complete"""
         node_ids = set(self.graph.nx_graph_flat().nodes)
-        return sum(self.batch_indices) == 0 and (self.has_error() or all((k in self.executed for k in node_ids)))
+        return self.has_error() or all((k in self.executed for k in node_ids))

     def has_error(self) -> bool:
        """Returns true if the graph has any errors"""

View File

@@ -41,14 +41,9 @@ class Invoker:
         return invocation.id

-    def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
+    def create_execution_state(self, graph: Optional[Graph] = None, batch_indices: list[int] = list()) -> GraphExecutionState:
         """Creates a new execution state for the given graph"""
-        new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
-        if graph.batches:
-            batch_indices = list()
-            for batch in graph.batches:
-                batch_indices.append(len(batch.data)-1)
-            new_state.batch_indices = batch_indices
+        new_state = GraphExecutionState(graph=Graph() if graph is None else graph, batch_indices=batch_indices)
         self.services.graph_execution_manager.set(new_state)
         return new_state

View File

@@ -6,6 +6,7 @@ from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
 from ..models.exceptions import CanceledException
+from .graph import GraphExecutionState

 import invokeai.backend.util.logging as logger


 class DefaultInvocationProcessor(InvocationProcessorABC):
@@ -73,7 +74,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         error_type=e.__class__.__name__,
                         error=traceback.format_exc(),
                     )
                     continue

                 # get the source node id to provide to clients (the prepared node id is not as useful)
                 source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
@@ -165,6 +166,15 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                                 error_type=e.__class__.__name__,
                                 error=traceback.format_exc()
                             )
+                    elif queue_item.invoke_all and sum(graph_execution_state.batch_indices) > 0:
+                        batch_indicies = graph_execution_state.batch_indices.copy()
+                        for index in range(len(batch_indicies)):
+                            if batch_indicies[index] > 0:
+                                batch_indicies[index] -= 1
+                                break
+                        new_ges = GraphExecutionState(graph=graph_execution_state.graph, batch_indices=batch_indicies)
+                        self.__invoker.services.graph_execution_manager.set(new_ges)
+                        self.__invoker.invoke(new_ges, invoke_all=True)
                     elif is_complete:
                         self.__invoker.services.events.emit_graph_execution_complete(
                             graph_execution_state.id
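To make the respawn loop concrete: each time a session finishes with counters remaining, the processor copies batch_indices, decrements the first positive entry, and queues a fresh session with that copy. A standalone sketch, separate from the diff (the countdown helper is hypothetical), of the sequence of counters successive sessions would receive, starting from [2, 1]:

    # Enumerate the batch_indices each spawned session would carry.
    def countdown(indices):
        states = [list(indices)]
        current = list(indices)
        while sum(current) > 0:
            nxt = list(current)
            for i in range(len(nxt)):
                if nxt[i] > 0:
                    nxt[i] -= 1
                    break
            states.append(nxt)
            current = nxt
        return states

    print(countdown([2, 1]))  # [[2, 1], [1, 1], [0, 1], [0, 0]] -> four sessions in total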