diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index c3cb654a04..5f6228a1e3 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -48,7 +48,7 @@ async def create_session( ) async def create_batch( graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"), - batches: list[Batch] = Body(description="Batch config to apply to the given graph") + batches: list[Batch] = Body(description="Batch config to apply to the given graph"), ) -> BatchProcess: """Creates and starts a new new batch process""" session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph) diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index 2607964a27..8acd5f9a7b 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -1,4 +1,3 @@ - import networkx as nx import uuid import copy @@ -21,6 +20,8 @@ from invokeai.app.services.invoker import Invoker InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore + + class Batch(BaseModel): data: list[InvocationsUnion] = Field(description="Mapping of ") node_id: str = Field(description="ID of the node to batch") @@ -28,52 +29,44 @@ class Batch(BaseModel): class BatchProcess(BaseModel): batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch") - sessions: list[str] = Field(description="Tracker for which batch is currently being processed", default_factory=list) + sessions: list[str] = Field( + description="Tracker for which batch is currently being processed", default_factory=list + ) batches: list[Batch] = Field( description="List of batch configs to apply to this session", default_factory=list, ) - batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list) + batch_indices: list[int] = Field( + description="Tracker for which batch is currently being processed", default_factory=list + ) graph: Graph = Field(description="The graph being executed") class BatchManagerBase(ABC): @abstractmethod - def start( - self, - invoker: Invoker - ): + def start(self, invoker: Invoker): pass @abstractmethod - def run_batch_process( - self, - batches: list[Batch], - graph: Graph - ) -> BatchProcess: + def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess: pass @abstractmethod - def cancel_batch_process( - self, - batch_process_id: str - ): + def cancel_batch_process(self, batch_process_id: str): pass class BatchManager(BatchManagerBase): """Responsible for managing currently running and scheduled batch jobs""" + __invoker: Invoker __batches: list[BatchProcess] - def start(self, invoker) -> None: # if we do want multithreading at some point, we could make this configurable self.__invoker = invoker self.__batches = list() - local_handler.register( - event_name=EventServiceBase.session_event, _func=self.on_event - ) + local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event) async def on_event(self, event: Event): event_name = event[1]["event"] @@ -85,22 +78,22 @@ class BatchManager(BatchManagerBase): await self.process(event) return event - + async def process(self, event: Event): data = event[1]["data"] batchTarget = None for batch in self.__batches: - if data['graph_execution_state_id'] in batch.sessions: + if data["graph_execution_state_id"] in batch.sessions: batchTarget = batch break - + if batchTarget == None: return - + if sum(batchTarget.batch_indices) == 0: self.__batches = [batch for batch in self.__batches if batch != batchTarget] return - + batchTarget.batch_indices = self._next_batch_index(batchTarget) ges = self._next_batch_session(batchTarget) batchTarget.sessions.append(ges.id) @@ -114,15 +107,14 @@ class BatchManager(BatchManagerBase): sorted_nodes = nx.topological_sort(g) for npath in sorted_nodes: node = graph.get_node(npath) - (index, batch) = next(((i,b) for i,b in enumerate(batches) if b.node_id in node.id), (None, None)) + (index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None)) if batch: batch_index = batch_process.batch_indices[index] datum = batch.data[batch_index] datum.id = node.id graph.update_node(npath, datum) - - return GraphExecutionState(graph=graph) + return GraphExecutionState(graph=graph) def _next_batch_index(self, batch_process: BatchProcess): batch_indicies = batch_process.batch_indices.copy() @@ -132,19 +124,14 @@ class BatchManager(BatchManagerBase): break return batch_indicies - - def run_batch_process( - self, - batches: list[Batch], - graph: Graph - ) -> BatchProcess: + def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess: batch_indices = list() for batch in batches: - batch_indices.append(len(batch.data)-1) + batch_indices.append(len(batch.data) - 1) batch_process = BatchProcess( - batches = batches, - batch_indices = batch_indices, - graph = graph, + batches=batches, + batch_indices=batch_indices, + graph=graph, ) ges = self._next_batch_session(batch_process) batch_process.sessions.append(ges.id) @@ -153,8 +140,5 @@ class BatchManager(BatchManagerBase): self.__invoker.invoke(ges, invoke_all=True) return batch_process - def cancel_batch_process( - self, - batch_process_id: str - ): + def cancel_batch_process(self, batch_process_id: str): self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]