feat(app): add destination column to session_queue

The frontend needs to know where queue items came from (i.e. which tab), and where results are going to (i.e. send images to gallery or canvas). The `origin` column is not quite enough to represent this cleanly.

A `destination` column provides the frontend what it needs to handle incoming generations.
This commit is contained in:
psychedelicious 2024-08-29 17:49:06 +10:00
parent 93f1d67fbf
commit 56fbe751db
4 changed files with 36 additions and 7 deletions

View File

@ -88,7 +88,8 @@ class QueueItemEventBase(QueueEventBase):
item_id: int = Field(description="The ID of the queue item") item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch") batch_id: str = Field(description="The ID of the queue batch")
origin: str | None = Field(default=None, description="The origin of the batch") origin: str | None = Field(default=None, description="The origin of the queue item")
destination: str | None = Field(default=None, description="The destination of the queue item")
class InvocationEventBase(QueueItemEventBase): class InvocationEventBase(QueueItemEventBase):
@ -114,6 +115,7 @@ class InvocationStartedEvent(InvocationEventBase):
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin, origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -148,6 +150,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin, origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -186,6 +189,7 @@ class InvocationCompleteEvent(InvocationEventBase):
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin, origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -219,6 +223,7 @@ class InvocationErrorEvent(InvocationEventBase):
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin, origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id, session_id=queue_item.session_id,
invocation=invocation, invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -257,6 +262,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
item_id=queue_item.item_id, item_id=queue_item.item_id,
batch_id=queue_item.batch_id, batch_id=queue_item.batch_id,
origin=queue_item.origin, origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id, session_id=queue_item.session_id,
status=queue_item.status, status=queue_item.status,
error_type=queue_item.error_type, error_type=queue_item.error_type,

View File

@ -77,7 +77,14 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel): class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch") batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
origin: str | None = Field(default=None, description="The origin of this batch.") origin: str | None = Field(
default=None,
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
)
destination: str | None = Field(
default=None,
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
)
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.") data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with") graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field( workflow: Optional[WorkflowWithoutID] = Field(
@ -196,7 +203,14 @@ class SessionQueueItemWithoutGraph(BaseModel):
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item") status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item") priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item") batch_id: str = Field(description="The ID of the batch associated with this queue item")
origin: str | None = Field(default=None, description="The origin of this queue item. ") origin: str | None = Field(
default=None,
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
)
destination: str | None = Field(
default=None,
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
)
session_id: str = Field( session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
) )
@ -297,6 +311,7 @@ class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue") queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch") batch_id: str = Field(..., description="The ID of the batch")
origin: str | None = Field(..., description="The origin of the batch") origin: str | None = Field(..., description="The origin of the batch")
destination: str | None = Field(..., description="The destination of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'") pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'") in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'") completed: int = Field(..., description="Number of queue items with status 'complete'")
@ -443,6 +458,7 @@ class SessionQueueValueToInsert(NamedTuple):
priority: int # priority priority: int # priority
workflow: Optional[str] # workflow json workflow: Optional[str] # workflow json
origin: str | None origin: str | None
destination: str | None
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert] ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@ -464,6 +480,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
priority, # priority priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json) json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
batch.origin, # origin batch.origin, # origin
batch.destination, # destination
) )
) )
return values_to_insert return values_to_insert

View File

@ -128,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):
self.__cursor.executemany( self.__cursor.executemany(
"""--sql """--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin) INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
values_to_insert, values_to_insert,
) )
@ -579,7 +579,8 @@ class SqliteSessionQueue(SessionQueueBase):
session_id, session_id,
batch_id, batch_id,
queue_id, queue_id,
origin origin,
destination
FROM session_queue FROM session_queue
WHERE queue_id = ? WHERE queue_id = ?
""" """
@ -659,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.acquire() self.__lock.acquire()
self.__cursor.execute( self.__cursor.execute(
"""--sql """--sql
SELECT status, count(*), origin SELECT status, count(*), origin, destination
FROM session_queue FROM session_queue
WHERE WHERE
queue_id = ? queue_id = ?
@ -672,6 +673,7 @@ class SqliteSessionQueue(SessionQueueBase):
total = sum(row[1] for row in result) total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result} counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
except Exception: except Exception:
self.__conn.rollback() self.__conn.rollback()
raise raise
@ -681,6 +683,7 @@ class SqliteSessionQueue(SessionQueueBase):
return BatchStatus( return BatchStatus(
batch_id=batch_id, batch_id=batch_id,
origin=origin, origin=origin,
destination=destination,
queue_id=queue_id, queue_id=queue_id,
pending=counts.get("pending", 0), pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0), in_progress=counts.get("in_progress", 0),

View File

@ -10,9 +10,11 @@ class Migration15Callback:
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None: def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
""" """
- Adds `origin` column to the session queue table. - Adds `origin` column to the session queue table.
- Adds `destination` column to the session queue table.
""" """
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;") cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
cursor.execute("ALTER TABLE session_queue ADD COLUMN destination TEXT;")
def build_migration_15() -> Migration: def build_migration_15() -> Migration:
@ -21,6 +23,7 @@ def build_migration_15() -> Migration:
This migration does the following: This migration does the following:
- Adds `origin` column to the session queue table. - Adds `origin` column to the session queue table.
- Adds `destination` column to the session queue table.
""" """
migration_15 = Migration( migration_15 = Migration(
from_version=14, from_version=14,