From 1d798d411931a6354797505abb22f0db409e6f95 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Fri, 11 Aug 2023 11:45:27 -0400 Subject: [PATCH] Return session id's on batch creation --- invokeai/app/api/routers/sessions.py | 12 ++++++------ invokeai/app/services/batch_manager.py | 23 +++++++++++++++++------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index f2acc388c3..91e7d5e2b3 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -15,7 +15,7 @@ from ...services.graph import ( GraphExecutionState, NodeAlreadyExecutedError, ) -from ...services.batch_manager import Batch, BatchProcess +from ...services.batch_manager import Batch, BatchProcessResponse from ...services.item_storage import PaginatedResults from ..dependencies import ApiDependencies @@ -42,18 +42,18 @@ async def create_session( "/batch", operation_id="create_batch", responses={ - 200: {"model": BatchProcess}, + 200: {"model": BatchProcessResponse}, 400: {"description": "Invalid json"}, }, ) async def create_batch( graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"), batches: list[Batch] = Body(description="Batch config to apply to the given graph"), -) -> BatchProcess: +) -> BatchProcessResponse: """Creates and starts a new new batch process""" - batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph) - ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id) - return {"batch_id":batch_id} + batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph) + ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id) + return batch_process_res @session_router.delete( diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index a5c92bb1fc..cd1597a5a9 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -19,6 +19,10 @@ from invokeai.app.services.batch_manager_storage import ( BatchSessionChanges, ) +class BatchProcessResponse(BaseModel): + batch_id: str = Field(description="ID for the batch") + session_ids: list[str] = Field(description="List of session IDs created for this batch") + class BatchManagerBase(ABC): @abstractmethod @@ -26,7 +30,7 @@ class BatchManagerBase(ABC): pass @abstractmethod - def create_batch_process(self, batches: list[Batch], graph: Graph) -> str: + def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse: pass @abstractmethod @@ -107,9 +111,11 @@ class BatchManager(BatchManagerBase): self.__invoker.invoke(ges, invoke_all=True) def _valid_batch_config(self, batch_process: BatchProcess) -> bool: + # TODO: Check that the node_ids in the batches are unique + # TODO: Validate data types are correct for each batch data return True - def create_batch_process(self, batches: list[Batch], graph: Graph) -> str: + def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse: batch_process = BatchProcess( batches=batches, graph=graph, @@ -117,11 +123,15 @@ class BatchManager(BatchManagerBase): if not self._valid_batch_config(batch_process): return None batch_process = self.__batch_process_storage.save(batch_process) - self._create_sessions(batch_process) - return batch_process.batch_id + sessions = self._create_sessions(batch_process) + return BatchProcessResponse( + batch_id=batch_process.batch_id, + session_ids=[session.session_id for session in sessions], + ) - def _create_sessions(self, batch_process: BatchProcess): + def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]: batch_indices = list() + sessions = list() for batch in batch_process.batches: batch_indices.append(list(range(len(batch.data)))) all_batch_indices = product(*batch_indices) @@ -133,7 +143,8 @@ class BatchManager(BatchManagerBase): session_id=ges.id, state="created" ) - self.__batch_process_storage.create_session(batch_session) + sessions.append(self.__batch_process_storage.create_session(batch_session)) + return sessions def cancel_batch_process(self, batch_process_id: str): self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]