Break apart create/start logic

This commit is contained in:
Brandon Rising 2023-08-15 16:28:47 -04:00
parent 6cb90e01de
commit 15e7ca1baa
3 changed files with 47 additions and 5 deletions

View File

@ -52,10 +52,25 @@ async def create_batch(
) -> BatchProcessResponse:
"""Creates and starts a new new batch process"""
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
return batch_process_res
@session_router.put(
"/batch/{batch_process_id}/invoke",
operation_id="start_batch",
responses={
200: {"model": BatchProcessResponse},
400: {"description": "Invalid json"},
},
)
async def start_batch(
batch_process_id: str = Path(description="ID of Batch to start"),
) -> BatchProcessResponse:
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
return Response(status_code=202)
@session_router.delete(
"/batch/{batch_process_id}",
operation_id="cancel_batch",

View File

@ -47,7 +47,6 @@ class BatchManager(BatchManagerBase):
"""Responsible for managing currently running and scheduled batch jobs"""
__invoker: Invoker
__batches: list[BatchProcess]
__batch_process_storage: BatchProcessStorageBase
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
@ -55,9 +54,7 @@ class BatchManager(BatchManagerBase):
self.__batch_process_storage = batch_process_storage
def start(self, invoker: Invoker) -> None:
# if we do want multithreading at some point, we could make this configurable
self.__invoker = invoker
self.__batches = list()
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
async def on_event(self, event: Event):
@ -87,7 +84,7 @@ class BatchManager(BatchManagerBase):
self.run_batch_process(batch_process.batch_id)
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
graph = copy.deepcopy(batch_process.graph)
graph = batch_process.graph.copy(deep=True)
batches = batch_process.batches
g = graph.nx_graph_flat()
sorted_nodes = nx.topological_sort(g)
@ -104,6 +101,7 @@ class BatchManager(BatchManagerBase):
return GraphExecutionState(graph=graph)
def run_batch_process(self, batch_id: str):
self.__batch_process_storage.start(batch_id)
try:
created_session = self.__batch_process_storage.get_created_session(batch_id)
except BatchSessionNotFoundException:

View File

@ -126,6 +126,14 @@ class BatchProcessStorageBase(ABC):
"""Gets a Batch Process record."""
pass
@abstractmethod
def start(
self,
batch_id: str,
):
"""Start Batch Process record."""
pass
@abstractmethod
def cancel(
self,
@ -360,6 +368,27 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchProcessNotFoundException
return self._deserialize_batch_process(dict(result))
def start(
self,
batch_id: str,
):
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
UPDATE batch_process
SET canceled = 0
WHERE batch_id = ?;
""",
(batch_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionSaveException from e
finally:
self._lock.release()
def cancel(
self,
batch_id: str,