From 1debc31e3dc27f698b1198ad6738ca4a5cba0fbe Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Fri, 11 Aug 2023 15:52:49 -0400 Subject: [PATCH] Allow cancel of running batch --- invokeai/app/api/routers/sessions.py | 2 +- invokeai/app/services/batch_manager.py | 6 ++-- .../app/services/batch_manager_storage.py | 34 +++++++++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index 91e7d5e2b3..7eec88b4b6 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -57,7 +57,7 @@ async def create_batch( @session_router.delete( - "{batch_process_id}/batch", + "/batch/{batch_process_id}", operation_id="cancel_batch", responses={202: {"description": "The batch is canceled"}}, ) diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index cd1597a5a9..ce00fca708 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -83,7 +83,9 @@ class BatchManager(BatchManagerBase): batch_session.session_id, updateSession, ) - self.run_batch_process(batch_session.batch_id) + batch_process = self.__batch_process_storage.get(batch_session.batch_id) + if not batch_process.canceled: + self.run_batch_process(batch_process.batch_id) def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState: graph = copy.deepcopy(batch_process.graph) @@ -147,4 +149,4 @@ class BatchManager(BatchManagerBase): return sessions def cancel_batch_process(self, batch_process_id: str): - self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id] + self.__batch_process_storage.cancel(batch_process_id) diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index bfdcba236f..025ee4a338 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -48,6 +48,7 @@ class BatchProcess(BaseModel): description="List of batch configs to apply to this session", default_factory=list, ) + canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False) graph: Graph = Field(description="The graph being executed") @@ -123,6 +124,14 @@ class BatchProcessStorageBase(ABC): """Gets a Batch Process record.""" pass + @abstractmethod + def cancel( + self, + batch_id: str, + ): + """Cancel Batch Process record.""" + pass + @abstractmethod def create_session( self, @@ -200,6 +209,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): batch_id TEXT NOT NULL PRIMARY KEY, batches TEXT NOT NULL, graph TEXT NOT NULL, + canceled BOOLEAN NOT NULL DEFAULT(0), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), @@ -325,12 +335,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): batch_id = session_dict.get("batch_id", "unknown") batches_raw = session_dict.get("batches", "unknown") graph_raw = session_dict.get("graph", "unknown") + canceled = session_dict.get("canceled", 0) batches = json.loads(batches_raw) batches = [parse_raw_as(Batch, batch) for batch in batches] return BatchProcess( batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), + canceled = canceled == 1 ) def get( @@ -358,6 +370,28 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): raise BatchProcessNotFoundException return self._deserialize_batch_process(dict(result)) + + def cancel( + self, + batch_id: str, + ): + try: + self._lock.acquire() + self._cursor.execute( + f"""--sql + UPDATE batch_process + SET canceled = 1 + WHERE batch_id = ?; + """, + (batch_id,), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BatchSessionSaveException from e + finally: + self._lock.release() + def create_session( self, session: BatchSession,