mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Allow cancel of running batch
This commit is contained in:
parent
1d798d4119
commit
1debc31e3d
@ -57,7 +57,7 @@ async def create_batch(
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
"{batch_process_id}/batch",
|
||||
"/batch/{batch_process_id}",
|
||||
operation_id="cancel_batch",
|
||||
responses={202: {"description": "The batch is canceled"}},
|
||||
)
|
||||
|
@ -83,7 +83,9 @@ class BatchManager(BatchManagerBase):
|
||||
batch_session.session_id,
|
||||
updateSession,
|
||||
)
|
||||
self.run_batch_process(batch_session.batch_id)
|
||||
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
|
||||
if not batch_process.canceled:
|
||||
self.run_batch_process(batch_process.batch_id)
|
||||
|
||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
||||
graph = copy.deepcopy(batch_process.graph)
|
||||
@ -147,4 +149,4 @@ class BatchManager(BatchManagerBase):
|
||||
return sessions
|
||||
|
||||
def cancel_batch_process(self, batch_process_id: str):
|
||||
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
||||
self.__batch_process_storage.cancel(batch_process_id)
|
||||
|
@ -48,6 +48,7 @@ class BatchProcess(BaseModel):
|
||||
description="List of batch configs to apply to this session",
|
||||
default_factory=list,
|
||||
)
|
||||
canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
|
||||
@ -123,6 +124,14 @@ class BatchProcessStorageBase(ABC):
|
||||
"""Gets a Batch Process record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
):
|
||||
"""Cancel Batch Process record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_session(
|
||||
self,
|
||||
@ -200,6 +209,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||
batches TEXT NOT NULL,
|
||||
graph TEXT NOT NULL,
|
||||
canceled BOOLEAN NOT NULL DEFAULT(0),
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
@ -325,12 +335,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
batch_id = session_dict.get("batch_id", "unknown")
|
||||
batches_raw = session_dict.get("batches", "unknown")
|
||||
graph_raw = session_dict.get("graph", "unknown")
|
||||
canceled = session_dict.get("canceled", 0)
|
||||
batches = json.loads(batches_raw)
|
||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||
return BatchProcess(
|
||||
batch_id=batch_id,
|
||||
batches=batches,
|
||||
graph=parse_raw_as(Graph, graph_raw),
|
||||
canceled = canceled == 1
|
||||
)
|
||||
|
||||
def get(
|
||||
@ -358,6 +370,28 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
raise BatchProcessNotFoundException
|
||||
return self._deserialize_batch_process(dict(result))
|
||||
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE batch_process
|
||||
SET canceled = 1
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
|
Loading…
Reference in New Issue
Block a user