mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Return session id's on batch creation
This commit is contained in:
parent
c1dde83abb
commit
1d798d4119
@ -15,7 +15,7 @@ from ...services.graph import (
|
|||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
NodeAlreadyExecutedError,
|
NodeAlreadyExecutedError,
|
||||||
)
|
)
|
||||||
from ...services.batch_manager import Batch, BatchProcess
|
from ...services.batch_manager import Batch, BatchProcessResponse
|
||||||
from ...services.item_storage import PaginatedResults
|
from ...services.item_storage import PaginatedResults
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
@ -42,18 +42,18 @@ async def create_session(
|
|||||||
"/batch",
|
"/batch",
|
||||||
operation_id="create_batch",
|
operation_id="create_batch",
|
||||||
responses={
|
responses={
|
||||||
200: {"model": BatchProcess},
|
200: {"model": BatchProcessResponse},
|
||||||
400: {"description": "Invalid json"},
|
400: {"description": "Invalid json"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def create_batch(
|
async def create_batch(
|
||||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
||||||
) -> BatchProcess:
|
) -> BatchProcessResponse:
|
||||||
"""Creates and starts a new new batch process"""
|
"""Creates and starts a new new batch process"""
|
||||||
batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
||||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
|
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||||
return {"batch_id":batch_id}
|
return batch_process_res
|
||||||
|
|
||||||
|
|
||||||
@session_router.delete(
|
@session_router.delete(
|
||||||
|
@ -19,6 +19,10 @@ from invokeai.app.services.batch_manager_storage import (
|
|||||||
BatchSessionChanges,
|
BatchSessionChanges,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class BatchProcessResponse(BaseModel):
|
||||||
|
batch_id: str = Field(description="ID for the batch")
|
||||||
|
session_ids: list[str] = Field(description="List of session IDs created for this batch")
|
||||||
|
|
||||||
|
|
||||||
class BatchManagerBase(ABC):
|
class BatchManagerBase(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -26,7 +30,7 @@ class BatchManagerBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -107,9 +111,11 @@ class BatchManager(BatchManagerBase):
|
|||||||
self.__invoker.invoke(ges, invoke_all=True)
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
|
|
||||||
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
||||||
|
# TODO: Check that the node_ids in the batches are unique
|
||||||
|
# TODO: Validate data types are correct for each batch data
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
||||||
batch_process = BatchProcess(
|
batch_process = BatchProcess(
|
||||||
batches=batches,
|
batches=batches,
|
||||||
graph=graph,
|
graph=graph,
|
||||||
@ -117,11 +123,15 @@ class BatchManager(BatchManagerBase):
|
|||||||
if not self._valid_batch_config(batch_process):
|
if not self._valid_batch_config(batch_process):
|
||||||
return None
|
return None
|
||||||
batch_process = self.__batch_process_storage.save(batch_process)
|
batch_process = self.__batch_process_storage.save(batch_process)
|
||||||
self._create_sessions(batch_process)
|
sessions = self._create_sessions(batch_process)
|
||||||
return batch_process.batch_id
|
return BatchProcessResponse(
|
||||||
|
batch_id=batch_process.batch_id,
|
||||||
|
session_ids=[session.session_id for session in sessions],
|
||||||
|
)
|
||||||
|
|
||||||
def _create_sessions(self, batch_process: BatchProcess):
|
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
||||||
batch_indices = list()
|
batch_indices = list()
|
||||||
|
sessions = list()
|
||||||
for batch in batch_process.batches:
|
for batch in batch_process.batches:
|
||||||
batch_indices.append(list(range(len(batch.data))))
|
batch_indices.append(list(range(len(batch.data))))
|
||||||
all_batch_indices = product(*batch_indices)
|
all_batch_indices = product(*batch_indices)
|
||||||
@ -133,7 +143,8 @@ class BatchManager(BatchManagerBase):
|
|||||||
session_id=ges.id,
|
session_id=ges.id,
|
||||||
state="created"
|
state="created"
|
||||||
)
|
)
|
||||||
self.__batch_process_storage.create_session(batch_session)
|
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||||
|
return sessions
|
||||||
|
|
||||||
def cancel_batch_process(self, batch_process_id: str):
|
def cancel_batch_process(self, batch_process_id: str):
|
||||||
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
||||||
|
Loading…
Reference in New Issue
Block a user