Allow cancel of running batch

This commit is contained in:
Brandon Rising 2023-08-11 15:52:49 -04:00
parent 1d798d4119
commit 1debc31e3d
3 changed files with 39 additions and 3 deletions

View File

@ -57,7 +57,7 @@ async def create_batch(
@session_router.delete(
"{batch_process_id}/batch",
"/batch/{batch_process_id}",
operation_id="cancel_batch",
responses={202: {"description": "The batch is canceled"}},
)

View File

@ -83,7 +83,9 @@ class BatchManager(BatchManagerBase):
batch_session.session_id,
updateSession,
)
self.run_batch_process(batch_session.batch_id)
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
if not batch_process.canceled:
self.run_batch_process(batch_process.batch_id)
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
graph = copy.deepcopy(batch_process.graph)
@ -147,4 +149,4 @@ class BatchManager(BatchManagerBase):
return sessions
def cancel_batch_process(self, batch_process_id: str):
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
self.__batch_process_storage.cancel(batch_process_id)

View File

@ -48,6 +48,7 @@ class BatchProcess(BaseModel):
description="List of batch configs to apply to this session",
default_factory=list,
)
canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
graph: Graph = Field(description="The graph being executed")
@ -123,6 +124,14 @@ class BatchProcessStorageBase(ABC):
"""Gets a Batch Process record."""
pass
@abstractmethod
def cancel(
self,
batch_id: str,
):
"""Cancel Batch Process record."""
pass
@abstractmethod
def create_session(
self,
@ -200,6 +209,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
batch_id TEXT NOT NULL PRIMARY KEY,
batches TEXT NOT NULL,
graph TEXT NOT NULL,
canceled BOOLEAN NOT NULL DEFAULT(0),
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@ -325,12 +335,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
batch_id = session_dict.get("batch_id", "unknown")
batches_raw = session_dict.get("batches", "unknown")
graph_raw = session_dict.get("graph", "unknown")
canceled = session_dict.get("canceled", 0)
batches = json.loads(batches_raw)
batches = [parse_raw_as(Batch, batch) for batch in batches]
return BatchProcess(
batch_id=batch_id,
batches=batches,
graph=parse_raw_as(Graph, graph_raw),
canceled = canceled == 1
)
def get(
@ -358,6 +370,28 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchProcessNotFoundException
return self._deserialize_batch_process(dict(result))
def cancel(
self,
batch_id: str,
):
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
UPDATE batch_process
SET canceled = 1
WHERE batch_id = ?;
""",
(batch_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionSaveException from e
finally:
self._lock.release()
def create_session(
self,
session: BatchSession,