mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Return session id's on batch creation
This commit is contained in:
parent
c1dde83abb
commit
1d798d4119
@ -15,7 +15,7 @@ from ...services.graph import (
|
||||
GraphExecutionState,
|
||||
NodeAlreadyExecutedError,
|
||||
)
|
||||
from ...services.batch_manager import Batch, BatchProcess
|
||||
from ...services.batch_manager import Batch, BatchProcessResponse
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -42,18 +42,18 @@ async def create_session(
|
||||
"/batch",
|
||||
operation_id="create_batch",
|
||||
responses={
|
||||
200: {"model": BatchProcess},
|
||||
200: {"model": BatchProcessResponse},
|
||||
400: {"description": "Invalid json"},
|
||||
},
|
||||
)
|
||||
async def create_batch(
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
||||
) -> BatchProcess:
|
||||
) -> BatchProcessResponse:
|
||||
"""Creates and starts a new new batch process"""
|
||||
batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
|
||||
return {"batch_id":batch_id}
|
||||
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||
return batch_process_res
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
|
@ -19,6 +19,10 @@ from invokeai.app.services.batch_manager_storage import (
|
||||
BatchSessionChanges,
|
||||
)
|
||||
|
||||
class BatchProcessResponse(BaseModel):
|
||||
batch_id: str = Field(description="ID for the batch")
|
||||
session_ids: list[str] = Field(description="List of session IDs created for this batch")
|
||||
|
||||
|
||||
class BatchManagerBase(ABC):
|
||||
@abstractmethod
|
||||
@ -26,7 +30,7 @@ class BatchManagerBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -107,9 +111,11 @@ class BatchManager(BatchManagerBase):
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
|
||||
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
||||
# TODO: Check that the node_ids in the batches are unique
|
||||
# TODO: Validate data types are correct for each batch data
|
||||
return True
|
||||
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
||||
batch_process = BatchProcess(
|
||||
batches=batches,
|
||||
graph=graph,
|
||||
@ -117,11 +123,15 @@ class BatchManager(BatchManagerBase):
|
||||
if not self._valid_batch_config(batch_process):
|
||||
return None
|
||||
batch_process = self.__batch_process_storage.save(batch_process)
|
||||
self._create_sessions(batch_process)
|
||||
return batch_process.batch_id
|
||||
sessions = self._create_sessions(batch_process)
|
||||
return BatchProcessResponse(
|
||||
batch_id=batch_process.batch_id,
|
||||
session_ids=[session.session_id for session in sessions],
|
||||
)
|
||||
|
||||
def _create_sessions(self, batch_process: BatchProcess):
|
||||
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
||||
batch_indices = list()
|
||||
sessions = list()
|
||||
for batch in batch_process.batches:
|
||||
batch_indices.append(list(range(len(batch.data))))
|
||||
all_batch_indices = product(*batch_indices)
|
||||
@ -133,7 +143,8 @@ class BatchManager(BatchManagerBase):
|
||||
session_id=ges.id,
|
||||
state="created"
|
||||
)
|
||||
self.__batch_process_storage.create_session(batch_session)
|
||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||
return sessions
|
||||
|
||||
def cancel_batch_process(self, batch_process_id: str):
|
||||
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
||||
|
Loading…
Reference in New Issue
Block a user