From 15e7ca1baa263cb044442106a0cc59322c43ee43 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 15 Aug 2023 16:28:47 -0400
Subject: [PATCH] Break apart create/start logic

---
 invokeai/app/api/routers/sessions.py       | 17 ++++++++++-
 invokeai/app/services/batch_manager.py     |  6 ++--
 .../app/services/batch_manager_storage.py  | 29 +++++++++++++++++++
 3 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py
index a82b5cb949..c533b95afe 100644
--- a/invokeai/app/api/routers/sessions.py
+++ b/invokeai/app/api/routers/sessions.py
@@ -52,10 +52,25 @@ async def create_batch(
 ) -> BatchProcessResponse:
     """Creates and starts a new batch process"""
     batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
-    ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
     return batch_process_res
 
 
+@session_router.put(
+    "/batch/{batch_process_id}/invoke",
+    operation_id="start_batch",
+    responses={
+        202: {"description": "The batch process is started"},
+        400: {"description": "Invalid json"},
+    },
+)
+async def start_batch(
+    batch_process_id: str = Path(description="ID of Batch to start"),
+) -> Response:
+    ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
+
+    return Response(status_code=202)
+
+
 @session_router.delete(
     "/batch/{batch_process_id}",
     operation_id="cancel_batch",
diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py
index 1a67916cf7..9b4e2ce57c 100644
--- a/invokeai/app/services/batch_manager.py
+++ b/invokeai/app/services/batch_manager.py
@@ -47,7 +47,6 @@ class BatchManager(BatchManagerBase):
     """Responsible for managing currently running and scheduled batch jobs"""
 
     __invoker: Invoker
-    __batches: list[BatchProcess]
     __batch_process_storage: BatchProcessStorageBase
 
     def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
@@ -55,9 +54,7 @@ class BatchManager(BatchManagerBase):
         self.__batch_process_storage = batch_process_storage
 
     def start(self, invoker: Invoker) -> None:
-        # if we do want multithreading at some point, we could make this configurable
         self.__invoker = invoker
-        self.__batches = list()
         local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
 
     async def on_event(self, event: Event):
@@ -87,7 +84,7 @@ class BatchManager(BatchManagerBase):
             self.run_batch_process(batch_process.batch_id)
 
     def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
-        graph = copy.deepcopy(batch_process.graph)
+        graph = batch_process.graph.copy(deep=True)
         batches = batch_process.batches
         g = graph.nx_graph_flat()
         sorted_nodes = nx.topological_sort(g)
@@ -104,6 +101,7 @@ class BatchManager(BatchManagerBase):
         return GraphExecutionState(graph=graph)
 
     def run_batch_process(self, batch_id: str):
+        self.__batch_process_storage.start(batch_id)
         try:
             created_session = self.__batch_process_storage.get_created_session(batch_id)
         except BatchSessionNotFoundException:
diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py
index 5d25597623..7d16a73ebd 100644
--- a/invokeai/app/services/batch_manager_storage.py
+++ b/invokeai/app/services/batch_manager_storage.py
@@ -126,6 +126,14 @@ class BatchProcessStorageBase(ABC):
         """Gets a Batch Process record."""
         pass
 
+    @abstractmethod
+    def start(
+        self,
+        batch_id: str,
+    ):
+        """Start Batch Process record."""
+        pass
+
     @abstractmethod
     def cancel(
         self,
         batch_id: str,
     ):
@@ -360,6 +368,27 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
             raise BatchProcessNotFoundException
         return self._deserialize_batch_process(dict(result))
 
+    def start(
+        self,
+        batch_id: str,
+    ):
+        try:
+            self._lock.acquire()
+            self._cursor.execute(
+                f"""--sql
+                UPDATE batch_process
+                SET canceled = 0
+                WHERE batch_id = ?;
+                """,
+                (batch_id,),
+            )
+            self._conn.commit()
+        except sqlite3.Error as e:
+            self._conn.rollback()
+            raise BatchSessionSaveException from e
+        finally:
+            self._lock.release()
+
     def cancel(
         self,
         batch_id: str,
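-- 
Note: with this patch, creating a batch no longer starts it. Clients create the batch first and then start it through the new PUT /batch/{batch_process_id}/invoke route (or BatchManager.run_batch_process). The sketch below is illustrative only and is not part of the patch; the /api/v1/sessions prefix, the port, the POST /batch create route, and the request payload shape are assumptions about how the session router is mounted, not something this diff defines.

# Illustrative client sketch -- not part of the patch.
# Assumptions: the session router is mounted at /api/v1/sessions on localhost:9090,
# the existing create route is POST /batch, and it accepts "batches" and "graph"
# fields in the JSON body. Adjust to match the actual deployment.
import requests

BASE = "http://localhost:9090/api/v1/sessions"

batches: list = []  # placeholder for real Batch definitions
graph: dict = {}    # placeholder for a real Graph payload

# Step 1: create the batch process; after this patch it is no longer auto-started.
created = requests.post(f"{BASE}/batch", json={"batches": batches, "graph": graph})
batch_id = created.json()["batch_id"]  # BatchProcessResponse.batch_id

# Step 2: start it explicitly via the new endpoint; 202 means the run was accepted.
started = requests.put(f"{BASE}/batch/{batch_id}/invoke")
assert started.status_code == 202

On the storage side, start() clears the canceled flag before run_batch_process begins creating sessions, which suggests a previously canceled batch can be started again.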