Return session id's on batch creation

This commit is contained in:
Brandon Rising 2023-08-11 11:45:27 -04:00
parent c1dde83abb
commit 1d798d4119
2 changed files with 23 additions and 12 deletions

View File

@ -15,7 +15,7 @@ from ...services.graph import (
GraphExecutionState,
NodeAlreadyExecutedError,
)
from ...services.batch_manager import Batch, BatchProcess
from ...services.batch_manager import Batch, BatchProcessResponse
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
@ -42,18 +42,18 @@ async def create_session(
"/batch",
operation_id="create_batch",
responses={
200: {"model": BatchProcess},
200: {"model": BatchProcessResponse},
400: {"description": "Invalid json"},
},
)
async def create_batch(
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
) -> BatchProcess:
) -> BatchProcessResponse:
"""Creates and starts a new new batch process"""
batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
return {"batch_id":batch_id}
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
return batch_process_res
@session_router.delete(

View File

@ -19,6 +19,10 @@ from invokeai.app.services.batch_manager_storage import (
BatchSessionChanges,
)
class BatchProcessResponse(BaseModel):
batch_id: str = Field(description="ID for the batch")
session_ids: list[str] = Field(description="List of session IDs created for this batch")
class BatchManagerBase(ABC):
@abstractmethod
@ -26,7 +30,7 @@ class BatchManagerBase(ABC):
pass
@abstractmethod
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
pass
@abstractmethod
@ -107,9 +111,11 @@ class BatchManager(BatchManagerBase):
self.__invoker.invoke(ges, invoke_all=True)
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
# TODO: Check that the node_ids in the batches are unique
# TODO: Validate data types are correct for each batch data
return True
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
batch_process = BatchProcess(
batches=batches,
graph=graph,
@ -117,11 +123,15 @@ class BatchManager(BatchManagerBase):
if not self._valid_batch_config(batch_process):
return None
batch_process = self.__batch_process_storage.save(batch_process)
self._create_sessions(batch_process)
return batch_process.batch_id
sessions = self._create_sessions(batch_process)
return BatchProcessResponse(
batch_id=batch_process.batch_id,
session_ids=[session.session_id for session in sessions],
)
def _create_sessions(self, batch_process: BatchProcess):
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
batch_indices = list()
sessions = list()
for batch in batch_process.batches:
batch_indices.append(list(range(len(batch.data))))
all_batch_indices = product(*batch_indices)
@ -133,7 +143,8 @@ class BatchManager(BatchManagerBase):
session_id=ges.id,
state="created"
)
self.__batch_process_storage.create_session(batch_session)
sessions.append(self.__batch_process_storage.create_session(batch_session))
return sessions
def cancel_batch_process(self, batch_process_id: str):
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]