feat(app): add origin to session queue

The origin is an optional field indicating the queue item's origin. For example, "canvas" when the queue item originated from the canvas or "workflows" when the queue item originated from the workflows tab. If omitted, we assume the queue item originated from the API directly.

- Add migration to add the nullable column to the `session_queue` table.
- Update relevant event payloads with the new field.
- Add `cancel_by_origin` method to `session_queue` service and corresponding route. This is required for the canvas to bail out early when staging images.
- Add `origin` to both `SessionQueueItem` and `Batch` - it needs to be provided initially via the batch and then passed onto the queue item.
-
This commit is contained in:
psychedelicious 2024-07-01 19:23:41 +10:00
parent 7ccf559a06
commit b939192b16
7 changed files with 134 additions and 11 deletions

View File

@ -11,9 +11,11 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch, Batch,
BatchStatus, BatchStatus,
CancelByBatchIDsResult, CancelByBatchIDsResult,
CancelByOriginResult,
ClearResult, ClearResult,
EnqueueBatchResult, EnqueueBatchResult,
PruneResult, PruneResult,
QueueItemOrigin,
SessionQueueItem, SessionQueueItem,
SessionQueueItemDTO, SessionQueueItemDTO,
SessionQueueStatus, SessionQueueStatus,
@ -105,6 +107,19 @@ async def cancel_by_batch_ids(
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids) return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
"/{queue_id}/cancel_by_origin",
operation_id="cancel_by_origin",
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_origin(
queue_id: str = Path(description="The queue id to perform this operation on"),
origin: QueueItemOrigin = Query(description="The origin to cancel all queue items for"),
) -> CancelByOriginResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
@session_queue_router.put( @session_queue_router.put(
"/{queue_id}/clear", "/{queue_id}/clear",
operation_id="clear", operation_id="clear",

View File

@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS, QUEUE_ITEM_STATUS,
BatchStatus, BatchStatus,
EnqueueBatchResult, EnqueueBatchResult,
QueueItemOrigin,
SessionQueueItem, SessionQueueItem,
SessionQueueStatus, SessionQueueStatus,
) )
@ -88,6 +89,7 @@ class QueueItemEventBase(QueueEventBase):
item_id: int = Field(description="The ID of the queue item") item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch") batch_id: str = Field(description="The ID of the queue batch")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch")
class InvocationEventBase(QueueItemEventBase): class InvocationEventBase(QueueItemEventBase):
@ -95,8 +97,6 @@ class InvocationEventBase(QueueItemEventBase):
session_id: str = Field(description="The ID of the session (aka graph execution state)") session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue") queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)") session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The ID of the invocation") invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@ -114,6 +114,7 @@ class InvocationStartedEvent(InvocationEventBase):
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -147,6 +148,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -184,6 +186,7 @@ class InvocationCompleteEvent(InvocationEventBase):
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -216,6 +219,7 @@ class InvocationErrorEvent(InvocationEventBase):
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -253,6 +257,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id, session_id=queue_item.session_id,
status=queue_item.status, status=queue_item.status,
error_type=queue_item.error_type, error_type=queue_item.error_type,
@ -279,12 +284,14 @@ class BatchEnqueuedEvent(QueueEventBase):
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)" description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
) )
priority: int = Field(description="The priority of the batch") priority: int = Field(description="The priority of the batch")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch")
@classmethod @classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls( return cls(
queue_id=enqueue_result.queue_id, queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id, batch_id=enqueue_result.batch.batch_id,
origin=enqueue_result.batch.origin,
enqueued=enqueue_result.enqueued, enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested, requested=enqueue_result.requested,
priority=enqueue_result.priority, priority=enqueue_result.priority,

View File

@ -6,12 +6,14 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch, Batch,
BatchStatus, BatchStatus,
CancelByBatchIDsResult, CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult, CancelByQueueIDResult,
ClearResult, ClearResult,
EnqueueBatchResult, EnqueueBatchResult,
IsEmptyResult, IsEmptyResult,
IsFullResult, IsFullResult,
PruneResult, PruneResult,
QueueItemOrigin,
SessionQueueItem, SessionQueueItem,
SessionQueueItemDTO, SessionQueueItemDTO,
SessionQueueStatus, SessionQueueStatus,
@ -95,6 +97,11 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with matching batch IDs""" """Cancels all queue items with matching batch IDs"""
pass pass
@abstractmethod
def cancel_by_origin(self, queue_id: str, origin: QueueItemOrigin) -> CancelByOriginResult:
"""Cancels all queue items with the given batch origin"""
pass
@abstractmethod @abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID""" """Cancels all queue items with matching queue ID"""

View File

@ -1,5 +1,6 @@
import datetime import datetime
import json import json
from enum import Enum
from itertools import chain, product from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
@ -21,6 +22,7 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowWithoutID, WorkflowWithoutID,
WorkflowWithoutIDValidator, WorkflowWithoutIDValidator,
) )
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
# region Errors # region Errors
@ -58,6 +60,13 @@ BatchDataType = Union[
] ]
class QueueItemOrigin(str, Enum, metaclass=MetaEnum):
"""The origin of a batch. For example, a batch can be created from the canvas or workflows tab."""
CANVAS = "canvas"
WORKFLOWS = "workflows"
class NodeFieldValue(BaseModel): class NodeFieldValue(BaseModel):
node_path: str = Field(description="The node into which this batch data item will be substituted.") node_path: str = Field(description="The node into which this batch data item will be substituted.")
field_name: str = Field(description="The field into which this batch data item will be substituted.") field_name: str = Field(description="The field into which this batch data item will be substituted.")
@ -77,6 +86,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel): class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch") batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of this batch.")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.") data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with") graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field( workflow: Optional[WorkflowWithoutID] = Field(
@ -195,6 +205,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item") status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item") priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item") batch_id: str = Field(description="The ID of the batch associated with this queue item")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of this queue item. ")
session_id: str = Field( session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
) )
@ -294,6 +305,7 @@ class SessionQueueStatus(BaseModel):
class BatchStatus(BaseModel): class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue") queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch") batch_id: str = Field(..., description="The ID of the batch")
origin: QueueItemOrigin | None = Field(..., description="The origin of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'") pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'") in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'") completed: int = Field(..., description="Number of queue items with status 'complete'")
@ -328,6 +340,12 @@ class CancelByBatchIDsResult(BaseModel):
canceled: int = Field(..., description="Number of queue items canceled") canceled: int = Field(..., description="Number of queue items canceled")
class CancelByOriginResult(BaseModel):
"""Result of canceling by list of batch ids"""
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByQueueIDResult(CancelByBatchIDsResult): class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id""" """Result of canceling by queue id"""
@ -433,6 +451,7 @@ class SessionQueueValueToInsert(NamedTuple):
field_values: Optional[str] # field_values json field_values: Optional[str] # field_values json
priority: int # priority priority: int # priority
workflow: Optional[str] # workflow json workflow: Optional[str] # workflow json
origin: QueueItemOrigin | None
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert] ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@ -453,6 +472,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json) json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json) json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
batch.origin, # origin
) )
) )
return values_to_insert return values_to_insert

View File

@ -10,12 +10,14 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch, Batch,
BatchStatus, BatchStatus,
CancelByBatchIDsResult, CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult, CancelByQueueIDResult,
ClearResult, ClearResult,
EnqueueBatchResult, EnqueueBatchResult,
IsEmptyResult, IsEmptyResult,
IsFullResult, IsFullResult,
PruneResult, PruneResult,
QueueItemOrigin,
SessionQueueItem, SessionQueueItem,
SessionQueueItemDTO, SessionQueueItemDTO,
SessionQueueItemNotFoundError, SessionQueueItemNotFoundError,
@ -127,8 +129,8 @@ class SqliteSessionQueue(SessionQueueBase):
self.__cursor.executemany( self.__cursor.executemany(
"""--sql """--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow) INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", """,
values_to_insert, values_to_insert,
) )
@ -417,11 +419,7 @@ class SqliteSessionQueue(SessionQueueBase):
) )
self.__conn.commit() self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids: if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) self._set_queue_item_status(current_queue_item.item_id, "canceled")
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
except Exception: except Exception:
self.__conn.rollback() self.__conn.rollback()
raise raise
@ -429,6 +427,46 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release() self.__lock.release()
return CancelByBatchIDsResult(canceled=count) return CancelByBatchIDsResult(canceled=count)
def cancel_by_origin(self, queue_id: str, origin: QueueItemOrigin) -> CancelByOriginResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id == ?
AND origin == ?
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = (queue_id, origin)
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
params,
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
params,
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.origin == origin:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByOriginResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try: try:
current_queue_item = self.get_current(queue_id) current_queue_item = self.get_current(queue_id)
@ -541,7 +579,8 @@ class SqliteSessionQueue(SessionQueueBase):
started_at, started_at,
session_id, session_id,
batch_id, batch_id,
queue_id queue_id,
origin
FROM session_queue FROM session_queue
WHERE queue_id = ? WHERE queue_id = ?
""" """
@ -621,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.acquire() self.__lock.acquire()
self.__cursor.execute( self.__cursor.execute(
"""--sql """--sql
SELECT status, count(*) SELECT status, count(*), origin
FROM session_queue FROM session_queue
WHERE WHERE
queue_id = ? queue_id = ?
@ -633,6 +672,7 @@ class SqliteSessionQueue(SessionQueueBase):
result = cast(list[sqlite3.Row], self.__cursor.fetchall()) result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result) total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result} counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
except Exception: except Exception:
self.__conn.rollback() self.__conn.rollback()
raise raise
@ -641,6 +681,7 @@ class SqliteSessionQueue(SessionQueueBase):
return BatchStatus( return BatchStatus(
batch_id=batch_id, batch_id=batch_id,
origin=origin,
queue_id=queue_id, queue_id=queue_id,
pending=counts.get("pending", 0), pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0), in_progress=counts.get("in_progress", 0),

View File

@ -17,6 +17,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -51,6 +52,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_12(app_config=config)) migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13()) migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14()) migrator.register_migration(build_migration_14())
migrator.register_migration(build_migration_15())
migrator.run_migrations() migrator.run_migrations()
return db return db

View File

@ -0,0 +1,31 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration15Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_origin_col(cursor)
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `origin` column to the session queue table.
"""
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
def build_migration_15() -> Migration:
"""
Build the migration from database version 14 to 15.
This migration does the following:
- Adds `origin` column to the session queue table.
"""
migration_15 = Migration(
from_version=14,
to_version=15,
callback=Migration15Callback(),
)
return migration_15