mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Allow cancel of running batch
This commit is contained in:
parent
1d798d4119
commit
1debc31e3d
@ -57,7 +57,7 @@ async def create_batch(
|
|||||||
|
|
||||||
|
|
||||||
@session_router.delete(
|
@session_router.delete(
|
||||||
"{batch_process_id}/batch",
|
"/batch/{batch_process_id}",
|
||||||
operation_id="cancel_batch",
|
operation_id="cancel_batch",
|
||||||
responses={202: {"description": "The batch is canceled"}},
|
responses={202: {"description": "The batch is canceled"}},
|
||||||
)
|
)
|
||||||
|
@ -83,7 +83,9 @@ class BatchManager(BatchManagerBase):
|
|||||||
batch_session.session_id,
|
batch_session.session_id,
|
||||||
updateSession,
|
updateSession,
|
||||||
)
|
)
|
||||||
self.run_batch_process(batch_session.batch_id)
|
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
|
||||||
|
if not batch_process.canceled:
|
||||||
|
self.run_batch_process(batch_process.batch_id)
|
||||||
|
|
||||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
||||||
graph = copy.deepcopy(batch_process.graph)
|
graph = copy.deepcopy(batch_process.graph)
|
||||||
@ -147,4 +149,4 @@ class BatchManager(BatchManagerBase):
|
|||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
def cancel_batch_process(self, batch_process_id: str):
|
def cancel_batch_process(self, batch_process_id: str):
|
||||||
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
self.__batch_process_storage.cancel(batch_process_id)
|
||||||
|
@ -48,6 +48,7 @@ class BatchProcess(BaseModel):
|
|||||||
description="List of batch configs to apply to this session",
|
description="List of batch configs to apply to this session",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
|
||||||
graph: Graph = Field(description="The graph being executed")
|
graph: Graph = Field(description="The graph being executed")
|
||||||
|
|
||||||
|
|
||||||
@ -123,6 +124,14 @@ class BatchProcessStorageBase(ABC):
|
|||||||
"""Gets a Batch Process record."""
|
"""Gets a Batch Process record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
):
|
||||||
|
"""Cancel Batch Process record."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_session(
|
def create_session(
|
||||||
self,
|
self,
|
||||||
@ -200,6 +209,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
batch_id TEXT NOT NULL PRIMARY KEY,
|
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||||
batches TEXT NOT NULL,
|
batches TEXT NOT NULL,
|
||||||
graph TEXT NOT NULL,
|
graph TEXT NOT NULL,
|
||||||
|
canceled BOOLEAN NOT NULL DEFAULT(0),
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- Updated via trigger
|
-- Updated via trigger
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
@ -325,12 +335,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
batch_id = session_dict.get("batch_id", "unknown")
|
batch_id = session_dict.get("batch_id", "unknown")
|
||||||
batches_raw = session_dict.get("batches", "unknown")
|
batches_raw = session_dict.get("batches", "unknown")
|
||||||
graph_raw = session_dict.get("graph", "unknown")
|
graph_raw = session_dict.get("graph", "unknown")
|
||||||
|
canceled = session_dict.get("canceled", 0)
|
||||||
batches = json.loads(batches_raw)
|
batches = json.loads(batches_raw)
|
||||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||||
return BatchProcess(
|
return BatchProcess(
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
batches=batches,
|
batches=batches,
|
||||||
graph=parse_raw_as(Graph, graph_raw),
|
graph=parse_raw_as(Graph, graph_raw),
|
||||||
|
canceled = canceled == 1
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
@ -358,6 +370,28 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
raise BatchProcessNotFoundException
|
raise BatchProcessNotFoundException
|
||||||
return self._deserialize_batch_process(dict(result))
|
return self._deserialize_batch_process(dict(result))
|
||||||
|
|
||||||
|
|
||||||
|
def cancel(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE batch_process
|
||||||
|
SET canceled = 1
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
def create_session(
|
def create_session(
|
||||||
self,
|
self,
|
||||||
session: BatchSession,
|
session: BatchSession,
|
||||||
|
Loading…
Reference in New Issue
Block a user