diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 5dd4693795..f7d29a88c5 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -11,9 +11,11 @@ from invokeai.app.services.session_queue.session_queue_common import ( Batch, BatchStatus, CancelByBatchIDsResult, + CancelByOriginResult, ClearResult, EnqueueBatchResult, PruneResult, + QueueItemOrigin, SessionQueueItem, SessionQueueItemDTO, SessionQueueStatus, @@ -105,6 +107,19 @@ async def cancel_by_batch_ids( return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids) +@session_queue_router.put( + "/{queue_id}/cancel_by_origin", + operation_id="cancel_by_origin", + responses={200: {"model": CancelByBatchIDsResult}}, +) +async def cancel_by_origin( + queue_id: str = Path(description="The queue id to perform this operation on"), + origin: QueueItemOrigin = Query(description="The origin to cancel all queue items for"), +) -> CancelByOriginResult: + """Immediately cancels all queue items with the given origin""" + return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin) + + @session_queue_router.put( "/{queue_id}/clear", operation_id="clear", diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index c6a867fb08..a4570fa8e5 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( QUEUE_ITEM_STATUS, BatchStatus, EnqueueBatchResult, + QueueItemOrigin, SessionQueueItem, SessionQueueStatus, ) @@ -88,6 +89,7 @@ class QueueItemEventBase(QueueEventBase): item_id: int = Field(description="The ID of the queue item") batch_id: str = Field(description="The ID of the queue batch") + origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch") class InvocationEventBase(QueueItemEventBase): @@ -95,8 +97,6 @@ class InvocationEventBase(QueueItemEventBase): session_id: str = Field(description="The ID of the session (aka graph execution state)") queue_id: str = Field(description="The ID of the queue") - item_id: int = Field(description="The ID of the queue item") - batch_id: str = Field(description="The ID of the queue batch") session_id: str = Field(description="The ID of the session (aka graph execution state)") invocation: AnyInvocation = Field(description="The ID of the invocation") invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") @@ -114,6 +114,7 @@ class InvocationStartedEvent(InvocationEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -147,6 +148,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -184,6 +186,7 @@ class InvocationCompleteEvent(InvocationEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -216,6 +219,7 @@ class InvocationErrorEvent(InvocationEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -253,6 +257,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase): queue_id=queue_item.queue_id, item_id=queue_item.item_id, batch_id=queue_item.batch_id, + origin=queue_item.origin, session_id=queue_item.session_id, status=queue_item.status, error_type=queue_item.error_type, @@ -279,12 +284,14 @@ class BatchEnqueuedEvent(QueueEventBase): description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)" ) priority: int = Field(description="The priority of the batch") + origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch") @classmethod def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": return cls( queue_id=enqueue_result.queue_id, batch_id=enqueue_result.batch.batch_id, + origin=enqueue_result.batch.origin, enqueued=enqueue_result.enqueued, requested=enqueue_result.requested, priority=enqueue_result.priority, diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 341e034487..9658117048 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -6,12 +6,14 @@ from invokeai.app.services.session_queue.session_queue_common import ( Batch, BatchStatus, CancelByBatchIDsResult, + CancelByOriginResult, CancelByQueueIDResult, ClearResult, EnqueueBatchResult, IsEmptyResult, IsFullResult, PruneResult, + QueueItemOrigin, SessionQueueItem, SessionQueueItemDTO, SessionQueueStatus, @@ -95,6 +97,11 @@ class SessionQueueBase(ABC): """Cancels all queue items with matching batch IDs""" pass + @abstractmethod + def cancel_by_origin(self, queue_id: str, origin: QueueItemOrigin) -> CancelByOriginResult: + """Cancels all queue items with the given batch origin""" + pass + @abstractmethod def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: """Cancels all queue items with matching queue ID""" diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 7f4601eba7..5348339e71 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -1,5 +1,6 @@ import datetime import json +from enum import Enum from itertools import chain, product from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast @@ -21,6 +22,7 @@ from invokeai.app.services.workflow_records.workflow_records_common import ( WorkflowWithoutID, WorkflowWithoutIDValidator, ) +from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string # region Errors @@ -58,6 +60,13 @@ BatchDataType = Union[ ] +class QueueItemOrigin(str, Enum, metaclass=MetaEnum): + """The origin of a batch. For example, a batch can be created from the canvas or workflows tab.""" + + CANVAS = "canvas" + WORKFLOWS = "workflows" + + class NodeFieldValue(BaseModel): node_path: str = Field(description="The node into which this batch data item will be substituted.") field_name: str = Field(description="The field into which this batch data item will be substituted.") @@ -77,6 +86,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]] class Batch(BaseModel): batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch") + origin: QueueItemOrigin | None = Field(default=None, description="The origin of this batch.") data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.") graph: Graph = Field(description="The graph to initialize the session with") workflow: Optional[WorkflowWithoutID] = Field( @@ -195,6 +205,7 @@ class SessionQueueItemWithoutGraph(BaseModel): status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item") priority: int = Field(default=0, description="The priority of this queue item") batch_id: str = Field(description="The ID of the batch associated with this queue item") + origin: QueueItemOrigin | None = Field(default=None, description="The origin of this queue item. ") session_id: str = Field( description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." ) @@ -294,6 +305,7 @@ class SessionQueueStatus(BaseModel): class BatchStatus(BaseModel): queue_id: str = Field(..., description="The ID of the queue") batch_id: str = Field(..., description="The ID of the batch") + origin: QueueItemOrigin | None = Field(..., description="The origin of the batch") pending: int = Field(..., description="Number of queue items with status 'pending'") in_progress: int = Field(..., description="Number of queue items with status 'in_progress'") completed: int = Field(..., description="Number of queue items with status 'complete'") @@ -328,6 +340,12 @@ class CancelByBatchIDsResult(BaseModel): canceled: int = Field(..., description="Number of queue items canceled") +class CancelByOriginResult(BaseModel): + """Result of canceling by list of batch ids""" + + canceled: int = Field(..., description="Number of queue items canceled") + + class CancelByQueueIDResult(CancelByBatchIDsResult): """Result of canceling by queue id""" @@ -433,6 +451,7 @@ class SessionQueueValueToInsert(NamedTuple): field_values: Optional[str] # field_values json priority: int # priority workflow: Optional[str] # workflow json + origin: QueueItemOrigin | None ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert] @@ -453,6 +472,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json) priority, # priority json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json) + batch.origin, # origin ) ) return values_to_insert diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index a3a7004c94..38f8eaa422 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -10,12 +10,14 @@ from invokeai.app.services.session_queue.session_queue_common import ( Batch, BatchStatus, CancelByBatchIDsResult, + CancelByOriginResult, CancelByQueueIDResult, ClearResult, EnqueueBatchResult, IsEmptyResult, IsFullResult, PruneResult, + QueueItemOrigin, SessionQueueItem, SessionQueueItemDTO, SessionQueueItemNotFoundError, @@ -127,8 +129,8 @@ class SqliteSessionQueue(SessionQueueBase): self.__cursor.executemany( """--sql - INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, values_to_insert, ) @@ -417,11 +419,7 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) - queue_status = self.get_queue_status(queue_id=queue_id) - self.__invoker.services.events.emit_queue_item_status_changed( - current_queue_item, batch_status, queue_status - ) + self._set_queue_item_status(current_queue_item.item_id, "canceled") except Exception: self.__conn.rollback() raise @@ -429,6 +427,46 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.release() return CancelByBatchIDsResult(canceled=count) + def cancel_by_origin(self, queue_id: str, origin: QueueItemOrigin) -> CancelByOriginResult: + try: + current_queue_item = self.get_current(queue_id) + self.__lock.acquire() + where = """--sql + WHERE + queue_id == ? + AND origin == ? + AND status != 'canceled' + AND status != 'completed' + AND status != 'failed' + """ + params = (queue_id, origin) + self.__cursor.execute( + f"""--sql + SELECT COUNT(*) + FROM session_queue + {where}; + """, + params, + ) + count = self.__cursor.fetchone()[0] + self.__cursor.execute( + f"""--sql + UPDATE session_queue + SET status = 'canceled' + {where}; + """, + params, + ) + self.__conn.commit() + if current_queue_item is not None and current_queue_item.origin == origin: + self._set_queue_item_status(current_queue_item.item_id, "canceled") + except Exception: + self.__conn.rollback() + raise + finally: + self.__lock.release() + return CancelByOriginResult(canceled=count) + def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: try: current_queue_item = self.get_current(queue_id) @@ -541,7 +579,8 @@ class SqliteSessionQueue(SessionQueueBase): started_at, session_id, batch_id, - queue_id + queue_id, + origin FROM session_queue WHERE queue_id = ? """ @@ -621,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.acquire() self.__cursor.execute( """--sql - SELECT status, count(*) + SELECT status, count(*), origin FROM session_queue WHERE queue_id = ? @@ -633,6 +672,7 @@ class SqliteSessionQueue(SessionQueueBase): result = cast(list[sqlite3.Row], self.__cursor.fetchall()) total = sum(row[1] for row in result) counts: dict[str, int] = {row[0]: row[1] for row in result} + origin = result[0]["origin"] if result else None except Exception: self.__conn.rollback() raise @@ -641,6 +681,7 @@ class SqliteSessionQueue(SessionQueueBase): return BatchStatus( batch_id=batch_id, + origin=origin, queue_id=queue_id, pending=counts.get("pending", 0), in_progress=counts.get("in_progress", 0), diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index e35c351ff0..5e1df29602 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -17,6 +17,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -51,6 +52,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_12(app_config=config)) migrator.register_migration(build_migration_13()) migrator.register_migration(build_migration_14()) + migrator.register_migration(build_migration_15()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py new file mode 100644 index 0000000000..026df180f3 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py @@ -0,0 +1,31 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration15Callback: + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._add_origin_col(cursor) + + def _add_origin_col(self, cursor: sqlite3.Cursor) -> None: + """ + - Adds `origin` column to the session queue table. + """ + + cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;") + + +def build_migration_15() -> Migration: + """ + Build the migration from database version 14 to 15. + + This migration does the following: + - Adds `origin` column to the session queue table. + """ + migration_15 = Migration( + from_version=14, + to_version=15, + callback=Migration15Callback(), + ) + + return migration_15