mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Allow cancel of running batch
This commit is contained in:
@ -48,6 +48,7 @@ class BatchProcess(BaseModel):
|
||||
description="List of batch configs to apply to this session",
|
||||
default_factory=list,
|
||||
)
|
||||
canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
|
||||
@ -123,6 +124,14 @@ class BatchProcessStorageBase(ABC):
|
||||
"""Gets a Batch Process record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
):
|
||||
"""Cancel Batch Process record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_session(
|
||||
self,
|
||||
@ -200,6 +209,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||
batches TEXT NOT NULL,
|
||||
graph TEXT NOT NULL,
|
||||
canceled BOOLEAN NOT NULL DEFAULT(0),
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
@ -325,12 +335,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
batch_id = session_dict.get("batch_id", "unknown")
|
||||
batches_raw = session_dict.get("batches", "unknown")
|
||||
graph_raw = session_dict.get("graph", "unknown")
|
||||
canceled = session_dict.get("canceled", 0)
|
||||
batches = json.loads(batches_raw)
|
||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||
return BatchProcess(
|
||||
batch_id=batch_id,
|
||||
batches=batches,
|
||||
graph=parse_raw_as(Graph, graph_raw),
|
||||
canceled = canceled == 1
|
||||
)
|
||||
|
||||
def get(
|
||||
@ -358,6 +370,28 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
raise BatchProcessNotFoundException
|
||||
return self._deserialize_batch_process(dict(result))
|
||||
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE batch_process
|
||||
SET canceled = 1
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
|
Reference in New Issue
Block a user