From b939192b1670eea71f9337e39f8973bdadf7faf5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 1 Jul 2024 19:23:41 +1000 Subject: [PATCH] feat(app): add origin to session queue The origin is an optional field indicating the queue item's origin. For example, "canvas" when the queue item originated from the canvas or "workflows" when the queue item originated from the workflows tab. If omitted, we assume the queue item originated from the API directly. - Add migration to add the nullable column to the `session_queue` table. - Update relevant event payloads with the new field. - Add `cancel_by_origin` method to `session_queue` service and corresponding route. This is required for the canvas to bail out early when staging images. - Add `origin` to both `SessionQueueItem` and `Batch` - it needs to be provided initially via the batch and then passed onto the queue item. - --- invokeai/app/api/routers/session_queue.py | 15 +++++ invokeai/app/services/events/events_common.py | 11 +++- .../session_queue/session_queue_base.py | 7 +++ .../session_queue/session_queue_common.py | 20 +++++++ .../session_queue/session_queue_sqlite.py | 59 ++++++++++++++++--- .../app/services/shared/sqlite/sqlite_util.py | 2 + .../migrations/migration_15.py | 31 ++++++++++ 7 files changed, 134 insertions(+), 11 deletions(-) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 5dd4693795..f7d29a88c5 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -11,9 +11,11 @@ from invokeai.app.services.session_queue.session_queue_common import ( Batch, BatchStatus, CancelByBatchIDsResult, + CancelByOriginResult, ClearResult, EnqueueBatchResult, PruneResult, + QueueItemOrigin, SessionQueueItem, SessionQueueItemDTO, SessionQueueStatus, @@ -105,6 +107,19 @@ async def cancel_by_batch_ids( return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids) +@session_queue_router.put( + "/{queue_id}/cancel_by_origin", + operation_id="cancel_by_origin", + responses={200: {"model": CancelByBatchIDsResult}}, +) +async def cancel_by_origin( + queue_id: str = Path(description="The queue id to perform this operation on"), + origin: QueueItemOrigin = Query(description="The origin to cancel all queue items for"), +) -> CancelByOriginResult: + """Immediately cancels all queue items with the given origin""" + return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin) + + @session_queue_router.put( "/{queue_id}/clear", operation_id="clear", diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index c6a867fb08..a4570fa8e5 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( QUEUE_ITEM_STATUS, BatchStatus, EnqueueBatchResult, + QueueItemOrigin, SessionQueueItem, SessionQueueStatus, ) @@ -88,6 +89,7 @@ class QueueItemEventBase(QueueEventBase): item_id: int = Field(description="The ID of the queue item") batch_id: str = Field(description="The ID of the queue batch") + origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch") class InvocationEventBase(QueueItemEventBase): @@ -95,8 +97,6 @@ class InvocationEventBase(QueueItemEventBase): session_id: str = Field(description="The ID of the session (aka graph execution state)") queue_id: str = Field(description="The ID of the queue") - item_id: int = Field(description="The ID of the queue item") - batch_id: str = Field(description="The ID of the queue batch") session_id: str = Field(description="The ID of the session (aka graph execution state)") invocation: AnyInvocation = Field(description="The ID of the invocation") invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") @@ -114,6 +114,7 @@ class InvocationStartedEvent(InvocationEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -147,6 +148,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -184,6 +186,7 @@ class InvocationCompleteEvent(InvocationEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -216,6 +219,7 @@ class InvocationErrorEvent(InvocationEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -253,6 +257,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, status=queue_item.status, error_type=queue_item.error_type, @@ -279,12 +284,14 @@ class BatchEnqueuedEvent(QueueEventBase): description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)" ) priority: int = Field(description="The priority of the batch") + origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch") @classmethod def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": return cls( queue_id=enqueue_result.queue_id, batch_id=enqueue_result.batch.batch_id, + origin=enqueue_result.batch.origin, enqueued=enqueue_result.enqueued, requested=enqueue_result.requested, priority=enqueue_result.priority, diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 341e034487..9658117048 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -6,12 +6,14 @@ from invokeai.app.services.session_queue.session_queue_common import ( Batch, BatchStatus, CancelByBatchIDsResult, + CancelByOriginResult, CancelByQueueIDResult, ClearResult, EnqueueBatchResult, IsEmptyResult, IsFullResult, PruneResult, + QueueItemOrigin, SessionQueueItem, SessionQueueItemDTO, SessionQueueStatus, @@ -95,6 +97,11 @@ class SessionQueueBase(ABC): """Cancels all queue items with matching batch IDs""" pass + @abstractmethod + def cancel_by_origin(self, queue_id: str, origin: QueueItemOrigin) -> CancelByOriginResult: + """Cancels all queue items with the given batch origin""" + pass + @abstractmethod def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: """Cancels all queue items with matching queue ID""" diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 7f4601eba7..5348339e71 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -1,5 +1,6 @@ import datetime import json +from enum import Enum from itertools import chain, product from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast @@ -21,6 +22,7 @@ from invokeai.app.services.workflow_records.workflow_records_common import ( WorkflowWithoutID, WorkflowWithoutIDValidator, ) +from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string # region Errors @@ -58,6 +60,13 @@ BatchDataType = Union[ ] +class QueueItemOrigin(str, Enum, metaclass=MetaEnum): + """The origin of a batch. For example, a batch can be created from the canvas or workflows tab.""" + + CANVAS = "canvas" + WORKFLOWS = "workflows" + + class NodeFieldValue(BaseModel): node_path: str = Field(description="The node into which this batch data item will be substituted.") field_name: str = Field(description="The field into which this batch data item will be substituted.") @@ -77,6 +86,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]] class Batch(BaseModel): batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch") + origin: QueueItemOrigin | None = Field(default=None, description="The origin of this batch.") data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.") graph: Graph = Field(description="The graph to initialize the session with") workflow: Optional[WorkflowWithoutID] = Field( @@ -195,6 +205,7 @@ class SessionQueueItemWithoutGraph(BaseModel): status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item") priority: int = Field(default=0, description="The priority of this queue item") batch_id: str = Field(description="The ID of the batch associated with this queue item") + origin: QueueItemOrigin | None = Field(default=None, description="The origin of this queue item. ") session_id: str = Field( description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." ) @@ -294,6 +305,7 @@ class SessionQueueStatus(BaseModel): class BatchStatus(BaseModel): queue_id: str = Field(..., description="The ID of the queue") batch_id: str = Field(..., description="The ID of the batch") + origin: QueueItemOrigin | None = Field(..., description="The origin of the batch") pending: int = Field(..., description="Number of queue items with status 'pending'") in_progress: int = Field(..., description="Number of queue items with status 'in_progress'") completed: int = Field(..., description="Number of queue items with status 'complete'") @@ -328,6 +340,12 @@ class CancelByBatchIDsResult(BaseModel): canceled: int = Field(..., description="Number of queue items canceled") +class CancelByOriginResult(BaseModel): + """Result of canceling by list of batch ids""" + + canceled: int = Field(..., description="Number of queue items canceled") + + class CancelByQueueIDResult(CancelByBatchIDsResult): """Result of canceling by queue id""" @@ -433,6 +451,7 @@ class SessionQueueValueToInsert(NamedTuple): field_values: Optional[str] # field_values json priority: int # priority workflow: Optional[str] # workflow json + origin: QueueItemOrigin | None ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert] @@ -453,6 +472,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json) priority, # priority json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json) + batch.origin, # origin ) ) return values_to_insert diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index a3a7004c94..38f8eaa422 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -10,12 +10,14 @@ from invokeai.app.services.session_queue.session_queue_common import ( Batch, BatchStatus, CancelByBatchIDsResult, + CancelByOriginResult, CancelByQueueIDResult, ClearResult, EnqueueBatchResult, IsEmptyResult, IsFullResult, PruneResult, + QueueItemOrigin, SessionQueueItem, SessionQueueItemDTO, SessionQueueItemNotFoundError, @@ -127,8 +129,8 @@ class SqliteSessionQueue(SessionQueueBase): self.__cursor.executemany( """--sql - INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, values_to_insert, ) @@ -417,11 +419,7 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) - queue_status = self.get_queue_status(queue_id=queue_id) - self.__invoker.services.events.emit_queue_item_status_changed( - current_queue_item, batch_status, queue_status - ) + self._set_queue_item_status(current_queue_item.item_id, "canceled") except Exception: self.__conn.rollback() raise @@ -429,6 +427,46 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.release() return CancelByBatchIDsResult(canceled=count) + def cancel_by_origin(self, queue_id: str, origin: QueueItemOrigin) -> CancelByOriginResult: + try: + current_queue_item = self.get_current(queue_id) + self.__lock.acquire() + where = """--sql + WHERE + queue_id == ? + AND origin == ? + AND status != 'canceled' + AND status != 'completed' + AND status != 'failed' + """ + params = (queue_id, origin) + self.__cursor.execute( + f"""--sql + SELECT COUNT(*) + FROM session_queue + {where}; + """, + params, + ) + count = self.__cursor.fetchone()[0] + self.__cursor.execute( + f"""--sql + UPDATE session_queue + SET status = 'canceled' + {where}; + """, + params, + ) + self.__conn.commit() + if current_queue_item is not None and current_queue_item.origin == origin: + self._set_queue_item_status(current_queue_item.item_id, "canceled") + except Exception: + self.__conn.rollback() + raise + finally: + self.__lock.release() + return CancelByOriginResult(canceled=count) + def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: try: current_queue_item = self.get_current(queue_id) @@ -541,7 +579,8 @@ class SqliteSessionQueue(SessionQueueBase): started_at, session_id, batch_id, - queue_id + queue_id, + origin FROM session_queue WHERE queue_id = ? """ @@ -621,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.acquire() self.__cursor.execute( """--sql - SELECT status, count(*) + SELECT status, count(*), origin FROM session_queue WHERE queue_id = ? @@ -633,6 +672,7 @@ class SqliteSessionQueue(SessionQueueBase): result = cast(list[sqlite3.Row], self.__cursor.fetchall()) total = sum(row[1] for row in result) counts: dict[str, int] = {row[0]: row[1] for row in result} + origin = result[0]["origin"] if result else None except Exception: self.__conn.rollback() raise @@ -641,6 +681,7 @@ class SqliteSessionQueue(SessionQueueBase): return BatchStatus( batch_id=batch_id, + origin=origin, queue_id=queue_id, pending=counts.get("pending", 0), in_progress=counts.get("in_progress", 0), diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index e35c351ff0..5e1df29602 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -17,6 +17,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -51,6 +52,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_12(app_config=config)) migrator.register_migration(build_migration_13()) migrator.register_migration(build_migration_14()) + migrator.register_migration(build_migration_15()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py new file mode 100644 index 0000000000..026df180f3 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py @@ -0,0 +1,31 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration15Callback: + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._add_origin_col(cursor) + + def _add_origin_col(self, cursor: sqlite3.Cursor) -> None: + """ + - Adds `origin` column to the session queue table. + """ + + cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;") + + +def build_migration_15() -> Migration: + """ + Build the migration from database version 14 to 15. + + This migration does the following: + - Adds `origin` column to the session queue table. + """ + migration_15 = Migration( + from_version=14, + to_version=15, + callback=Migration15Callback(), + ) + + return migration_15