From 56fbe751db874cf32722854d4e3d465f88aeb1c6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 29 Aug 2024 17:49:06 +1000 Subject: [PATCH] feat(app): add `destination` column to `session_queue` The frontend needs to know where queue items came from (i.e. which tab), and where results are going to (i.e. send images to gallery or canvas). The `origin` column is not quite enough to represent this cleanly. A `destination` column provides the frontend what it needs to handle incoming generations. --- invokeai/app/services/events/events_common.py | 8 ++++++- .../session_queue/session_queue_common.py | 21 +++++++++++++++++-- .../session_queue/session_queue_sqlite.py | 11 ++++++---- .../migrations/migration_15.py | 3 +++ 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index c348611bab..adcb226799 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -88,7 +88,8 @@ class QueueItemEventBase(QueueEventBase): item_id: int = Field(description="The ID of the queue item") batch_id: str = Field(description="The ID of the queue batch") - origin: str | None = Field(default=None, description="The origin of the batch") + origin: str | None = Field(default=None, description="The origin of the queue item") + destination: str | None = Field(default=None, description="The destination of the queue item") class InvocationEventBase(QueueItemEventBase): @@ -114,6 +115,7 @@ class InvocationStartedEvent(InvocationEventBase): item_id=queue_item.item_id, batch_id=queue_item.batch_id, origin=queue_item.origin, + destination=queue_item.destination, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -148,6 +150,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase): item_id=queue_item.item_id, batch_id=queue_item.batch_id, origin=queue_item.origin, + destination=queue_item.destination, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -186,6 +189,7 @@ class InvocationCompleteEvent(InvocationEventBase): item_id=queue_item.item_id, batch_id=queue_item.batch_id, origin=queue_item.origin, + destination=queue_item.destination, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -219,6 +223,7 @@ class InvocationErrorEvent(InvocationEventBase): item_id=queue_item.item_id, batch_id=queue_item.batch_id, origin=queue_item.origin, + destination=queue_item.destination, session_id=queue_item.session_id, invocation=invocation, invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], @@ -257,6 +262,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase): item_id=queue_item.item_id, batch_id=queue_item.batch_id, origin=queue_item.origin, + destination=queue_item.destination, session_id=queue_item.session_id, status=queue_item.status, error_type=queue_item.error_type, diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 1a546dab9c..8e37e205d2 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -77,7 +77,14 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]] class Batch(BaseModel): batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch") - origin: str | None = Field(default=None, description="The origin of this batch.") + origin: str | None = Field( + default=None, + description="The origin of this queue item. This data is used by the frontend to determine how to handle results.", + ) + destination: str | None = Field( + default=None, + description="The origin of this queue item. This data is used by the frontend to determine how to handle results", + ) data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.") graph: Graph = Field(description="The graph to initialize the session with") workflow: Optional[WorkflowWithoutID] = Field( @@ -196,7 +203,14 @@ class SessionQueueItemWithoutGraph(BaseModel): status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item") priority: int = Field(default=0, description="The priority of this queue item") batch_id: str = Field(description="The ID of the batch associated with this queue item") - origin: str | None = Field(default=None, description="The origin of this queue item. ") + origin: str | None = Field( + default=None, + description="The origin of this queue item. This data is used by the frontend to determine how to handle results.", + ) + destination: str | None = Field( + default=None, + description="The origin of this queue item. This data is used by the frontend to determine how to handle results", + ) session_id: str = Field( description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." ) @@ -297,6 +311,7 @@ class BatchStatus(BaseModel): queue_id: str = Field(..., description="The ID of the queue") batch_id: str = Field(..., description="The ID of the batch") origin: str | None = Field(..., description="The origin of the batch") + destination: str | None = Field(..., description="The destination of the batch") pending: int = Field(..., description="Number of queue items with status 'pending'") in_progress: int = Field(..., description="Number of queue items with status 'in_progress'") completed: int = Field(..., description="Number of queue items with status 'complete'") @@ -443,6 +458,7 @@ class SessionQueueValueToInsert(NamedTuple): priority: int # priority workflow: Optional[str] # workflow json origin: str | None + destination: str | None ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert] @@ -464,6 +480,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new priority, # priority json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json) batch.origin, # origin + batch.destination, # destination ) ) return values_to_insert diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 265c6065a5..d536aeba75 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -128,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase): self.__cursor.executemany( """--sql - INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, values_to_insert, ) @@ -579,7 +579,8 @@ class SqliteSessionQueue(SessionQueueBase): session_id, batch_id, queue_id, - origin + origin, + destination FROM session_queue WHERE queue_id = ? """ @@ -659,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.acquire() self.__cursor.execute( """--sql - SELECT status, count(*), origin + SELECT status, count(*), origin, destination FROM session_queue WHERE queue_id = ? @@ -672,6 +673,7 @@ class SqliteSessionQueue(SessionQueueBase): total = sum(row[1] for row in result) counts: dict[str, int] = {row[0]: row[1] for row in result} origin = result[0]["origin"] if result else None + destination = result[0]["destination"] if result else None except Exception: self.__conn.rollback() raise @@ -681,6 +683,7 @@ class SqliteSessionQueue(SessionQueueBase): return BatchStatus( batch_id=batch_id, origin=origin, + destination=destination, queue_id=queue_id, pending=counts.get("pending", 0), in_progress=counts.get("in_progress", 0), diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py index 026df180f3..455ff71ab5 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py @@ -10,9 +10,11 @@ class Migration15Callback: def _add_origin_col(self, cursor: sqlite3.Cursor) -> None: """ - Adds `origin` column to the session queue table. + - Adds `destination` column to the session queue table. """ cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;") + cursor.execute("ALTER TABLE session_queue ADD COLUMN destination TEXT;") def build_migration_15() -> Migration: @@ -21,6 +23,7 @@ def build_migration_15() -> Migration: This migration does the following: - Adds `origin` column to the session queue table. + - Adds `destination` column to the session queue table. """ migration_15 = Migration( from_version=14,