fix: canvas not working on queue

Add `batch_id` to outbound events. This necessitates adding it to both `InvocationContext` and `InvocationQueueItem`. This allows the canvas to receive images.

When the user enqueues a batch on the canvas, it is expected that all images from that batch are directed to the canvas.

The simplest, most flexible solution is to add the `batch_id` to the invocation context-y stuff. Then everything knows what batch it came from, and we can have the canvas pick up images associated with its list of canvas `batch_id`s.
This commit is contained in:
psychedelicious
2023-09-20 22:29:44 +10:00
committed by Kent Keirsey
parent 1c38cce16d
commit bdfdf854fc
20 changed files with 129 additions and 50 deletions

View File

@ -425,12 +425,21 @@ class InvocationContext:
graph_execution_state_id: str graph_execution_state_id: str
queue_id: str queue_id: str
queue_item_id: int queue_item_id: int
queue_batch_id: str
def __init__(self, services: InvocationServices, queue_id: str, queue_item_id: int, graph_execution_state_id: str): def __init__(
self,
services: InvocationServices,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
):
self.services = services self.services = services
self.graph_execution_state_id = graph_execution_state_id self.graph_execution_state_id = graph_execution_state_id
self.queue_id = queue_id self.queue_id = queue_id
self.queue_item_id = queue_item_id self.queue_item_id = queue_item_id
self.queue_batch_id = queue_batch_id
class BaseInvocationOutput(BaseModel): class BaseInvocationOutput(BaseModel):

View File

@ -30,6 +30,7 @@ class EventServiceBase:
self, self,
queue_id: str, queue_id: str,
queue_item_id: int, queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict, node: dict,
source_node_id: str, source_node_id: str,
@ -44,6 +45,7 @@ class EventServiceBase:
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node_id=node.get("id"), node_id=node.get("id"),
source_node_id=source_node_id, source_node_id=source_node_id,
@ -58,6 +60,7 @@ class EventServiceBase:
self, self,
queue_id: str, queue_id: str,
queue_item_id: int, queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
result: dict, result: dict,
node: dict, node: dict,
@ -69,6 +72,7 @@ class EventServiceBase:
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
@ -80,6 +84,7 @@ class EventServiceBase:
self, self,
queue_id: str, queue_id: str,
queue_item_id: int, queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict, node: dict,
source_node_id: str, source_node_id: str,
@ -92,6 +97,7 @@ class EventServiceBase:
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
@ -101,7 +107,13 @@ class EventServiceBase:
) )
def emit_invocation_started( def emit_invocation_started(
self, queue_id: str, queue_item_id: int, graph_execution_state_id: str, node: dict, source_node_id: str self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None: ) -> None:
"""Emitted when an invocation has started""" """Emitted when an invocation has started"""
self.__emit_queue_event( self.__emit_queue_event(
@ -109,19 +121,23 @@ class EventServiceBase:
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
), ),
) )
def emit_graph_execution_complete(self, queue_id: str, queue_item_id: int, graph_execution_state_id: str) -> None: def emit_graph_execution_complete(
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
"""Emitted when a session has completed all invocations""" """Emitted when a session has completed all invocations"""
self.__emit_queue_event( self.__emit_queue_event(
event_name="graph_execution_state_complete", event_name="graph_execution_state_complete",
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
), ),
) )
@ -130,6 +146,7 @@ class EventServiceBase:
self, self,
queue_id: str, queue_id: str,
queue_item_id: int, queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
@ -142,6 +159,7 @@ class EventServiceBase:
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -154,6 +172,7 @@ class EventServiceBase:
self, self,
queue_id: str, queue_id: str,
queue_item_id: int, queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
@ -167,6 +186,7 @@ class EventServiceBase:
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -182,6 +202,7 @@ class EventServiceBase:
self, self,
queue_id: str, queue_id: str,
queue_item_id: int, queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
error_type: str, error_type: str,
error: str, error: str,
@ -192,6 +213,7 @@ class EventServiceBase:
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
error_type=error_type, error_type=error_type,
error=error, error=error,
@ -202,6 +224,7 @@ class EventServiceBase:
self, self,
queue_id: str, queue_id: str,
queue_item_id: int, queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
node_id: str, node_id: str,
error_type: str, error_type: str,
@ -213,6 +236,7 @@ class EventServiceBase:
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node_id=node_id, node_id=node_id,
error_type=error_type, error_type=error_type,
@ -224,6 +248,7 @@ class EventServiceBase:
self, self,
queue_id: str, queue_id: str,
queue_item_id: int, queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
) -> None: ) -> None:
"""Emitted when a session is canceled""" """Emitted when a session is canceled"""
@ -232,6 +257,7 @@ class EventServiceBase:
payload=dict( payload=dict(
queue_id=queue_id, queue_id=queue_id,
queue_item_id=queue_item_id, queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
), ),
) )

View File

@ -15,6 +15,9 @@ class InvocationQueueItem(BaseModel):
session_queue_item_id: int = Field( session_queue_item_id: int = Field(
description="The ID of session queue item from which this invocation queue item came" description="The ID of session queue item from which this invocation queue item came"
) )
session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came"
)
invoke_all: bool = Field(default=False) invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time) timestamp: float = Field(default_factory=time.time)

View File

@ -18,7 +18,12 @@ class Invoker:
self._start() self._start()
def invoke( def invoke(
self, queue_id: str, queue_item_id: int, graph_execution_state: GraphExecutionState, invoke_all: bool = False self,
session_queue_id: str,
session_queue_item_id: int,
session_queue_batch_id: str,
graph_execution_state: GraphExecutionState,
invoke_all: bool = False,
) -> Optional[str]: ) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed. """Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@ -34,8 +39,9 @@ class Invoker:
# Queue the invocation # Queue the invocation
self.services.queue.put( self.services.queue.put(
InvocationQueueItem( InvocationQueueItem(
session_queue_item_id=queue_item_id, session_queue_id=session_queue_id,
session_queue_id=queue_id, session_queue_item_id=session_queue_item_id,
session_queue_batch_id=session_queue_batch_id,
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id, invocation_id=invocation.id,
invoke_all=invoke_all, invoke_all=invoke_all,

View File

@ -539,6 +539,7 @@ class ModelManagerService(ModelManagerServiceBase):
context.services.events.emit_model_load_completed( context.services.events.emit_model_load_completed(
queue_id=context.queue_id, queue_id=context.queue_id,
queue_item_id=context.queue_item_id, queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -550,6 +551,7 @@ class ModelManagerService(ModelManagerServiceBase):
context.services.events.emit_model_load_started( context.services.events.emit_model_load_started(
queue_id=context.queue_id, queue_id=context.queue_id,
queue_item_id=context.queue_item_id, queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,

View File

@ -57,6 +57,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except Exception as e: except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error( self.__invoker.services.events.emit_session_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id, queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id, queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id, graph_execution_state_id=queue_item.graph_execution_state_id,
@ -70,6 +71,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except Exception as e: except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error( self.__invoker.services.events.emit_invocation_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id, queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id, queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id, graph_execution_state_id=queue_item.graph_execution_state_id,
@ -84,6 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send starting event # Send starting event
self.__invoker.services.events.emit_invocation_started( self.__invoker.services.events.emit_invocation_started(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id, queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id, queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,
@ -106,6 +109,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,
queue_item_id=queue_item.session_queue_item_id, queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id, queue_id=queue_item.session_queue_id,
queue_batch_id=queue_item.session_queue_batch_id,
) )
) )
@ -121,6 +125,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event # Send complete event
self.__invoker.services.events.emit_invocation_complete( self.__invoker.services.events.emit_invocation_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id, queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id, queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,
@ -150,6 +155,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker.services.logger.error("Error while invoking:\n%s" % e) self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event # Send error event
self.__invoker.services.events.emit_invocation_error( self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id, queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id, queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,
@ -170,14 +176,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
if queue_item.invoke_all and not is_complete: if queue_item.invoke_all and not is_complete:
try: try:
self.__invoker.invoke( self.__invoker.invoke(
queue_item_id=queue_item.session_queue_item_id, session_queue_batch_id=queue_item.session_queue_batch_id,
queue_id=queue_item.session_queue_id, session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state, graph_execution_state=graph_execution_state,
invoke_all=True, invoke_all=True,
) )
except Exception as e: except Exception as e:
self.__invoker.services.logger.error("Error while invoking:\n%s" % e) self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error( self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id, queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id, queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,
@ -188,6 +196,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
) )
elif is_complete: elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete( self.__invoker.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id, queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id, queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,

View File

@ -102,8 +102,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.__queue_item = queue_item self.__queue_item = queue_item
self.__invoker.services.graph_execution_manager.set(queue_item.session) self.__invoker.services.graph_execution_manager.set(queue_item.session)
self.__invoker.invoke( self.__invoker.invoke(
queue_item_id=queue_item.item_id, session_queue_batch_id=queue_item.batch_id,
queue_id=queue_item.queue_id, session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session, graph_execution_state=queue_item.session,
invoke_all=True, invoke_all=True,
) )

View File

@ -562,6 +562,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker.services.events.emit_session_canceled( self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id, queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id, graph_execution_state_id=queue_item.session_id,
) )
self.__invoker.services.events.emit_queue_item_status_changed(queue_item) self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
@ -604,6 +605,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker.services.events.emit_session_canceled( self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id, queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id, queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id, graph_execution_state_id=current_queue_item.session_id,
) )
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item) self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
@ -649,6 +651,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker.services.events.emit_session_canceled( self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id, queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id, queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id, graph_execution_state_id=current_queue_item.session_id,
) )
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item) self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)

View File

@ -112,6 +112,7 @@ def stable_diffusion_step_callback(
context.services.events.emit_generator_progress( context.services.events.emit_generator_progress(
queue_id=context.queue_id, queue_id=context.queue_id,
queue_item_id=context.queue_item_id, queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,

View File

@ -1,7 +1,7 @@
import { isAnyOf } from '@reduxjs/toolkit'; import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { import {
canvasBatchesAndSessionsReset, canvasBatchIdsReset,
commitStagingAreaImage, commitStagingAreaImage,
discardStagedImages, discardStagedImages,
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
@ -38,7 +38,7 @@ export const addCommitStagingAreaImageListener = () => {
}) })
); );
} }
dispatch(canvasBatchesAndSessionsReset()); dispatch(canvasBatchIdsReset());
} catch { } catch {
log.error('Failed to cancel canvas batches'); log.error('Failed to cancel canvas batches');
dispatch( dispatch(

View File

@ -30,7 +30,7 @@ export const addInvocationCompleteEventListener = () => {
`Invocation complete (${action.payload.data.node.type})` `Invocation complete (${action.payload.data.node.type})`
); );
const { result, node, graph_execution_state_id } = data; const { result, node, queue_batch_id } = data;
// This complete event has an associated image output // This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
@ -43,7 +43,7 @@ export const addInvocationCompleteEventListener = () => {
// Add canvas images to the staging area // Add canvas images to the staging area
if ( if (
canvas.sessionIds.includes(graph_execution_state_id) && canvas.batchIds.includes(queue_batch_id) &&
[CANVAS_OUTPUT].includes(data.source_node_id) [CANVAS_OUTPUT].includes(data.source_node_id)
) { ) {
dispatch(addImageToStagingArea(imageDTO)); dispatch(addImageToStagingArea(imageDTO));

View File

@ -1,5 +1,4 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { canvasSessionIdAdded } from 'features/canvas/store/canvasSlice';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { import {
appSocketQueueItemStatusChanged, appSocketQueueItemStatusChanged,
@ -10,12 +9,11 @@ import { startAppListening } from '../..';
export const addSocketQueueItemStatusChangedEventListener = () => { export const addSocketQueueItemStatusChangedEventListener = () => {
startAppListening({ startAppListening({
actionCreator: socketQueueItemStatusChanged, actionCreator: socketQueueItemStatusChanged,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch }) => {
const log = logger('socketio'); const log = logger('socketio');
const { const {
queue_item_id: item_id, queue_item_id: item_id,
batch_id, queue_batch_id,
graph_execution_state_id,
status, status,
} = action.payload.data; } = action.payload.data;
log.debug( log.debug(
@ -36,11 +34,6 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
}) })
); );
const state = getState();
if (state.canvas.batchIds.includes(batch_id)) {
dispatch(canvasSessionIdAdded(graph_execution_state_id));
}
dispatch( dispatch(
queueApi.util.invalidateTags([ queueApi.util.invalidateTags([
'CurrentSessionQueueItem', 'CurrentSessionQueueItem',
@ -48,7 +41,7 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
'SessionQueueStatus', 'SessionQueueStatus',
{ type: 'SessionQueueItem', id: item_id }, { type: 'SessionQueueItem', id: item_id },
{ type: 'SessionQueueItemDTO', id: item_id }, { type: 'SessionQueueItemDTO', id: item_id },
{ type: 'BatchStatus', id: batch_id }, { type: 'BatchStatus', id: queue_batch_id },
]) ])
); );
}, },

View File

@ -11,12 +11,12 @@ const selector = createSelector(
({ system, canvas }) => { ({ system, canvas }) => {
const { denoiseProgress } = system; const { denoiseProgress } = system;
const { boundingBox } = canvas.layerState.stagingArea; const { boundingBox } = canvas.layerState.stagingArea;
const { sessionIds } = canvas; const { batchIds } = canvas;
return { return {
boundingBox, boundingBox,
progressImage: progressImage:
denoiseProgress && sessionIds.includes(denoiseProgress.session_id) denoiseProgress && batchIds.includes(denoiseProgress.batch_id)
? denoiseProgress.progress_image ? denoiseProgress.progress_image
: undefined, : undefined,
}; };

View File

@ -85,7 +85,6 @@ export const initialCanvasState: CanvasState = {
stageDimensions: { width: 0, height: 0 }, stageDimensions: { width: 0, height: 0 },
stageScale: 1, stageScale: 1,
tool: 'brush', tool: 'brush',
sessionIds: [],
batchIds: [], batchIds: [],
}; };
@ -302,11 +301,7 @@ export const canvasSlice = createSlice({
canvasBatchIdAdded: (state, action: PayloadAction<string>) => { canvasBatchIdAdded: (state, action: PayloadAction<string>) => {
state.batchIds.push(action.payload); state.batchIds.push(action.payload);
}, },
canvasSessionIdAdded: (state, action: PayloadAction<string>) => { canvasBatchIdsReset: (state) => {
state.sessionIds.push(action.payload);
},
canvasBatchesAndSessionsReset: (state) => {
state.sessionIds = [];
state.batchIds = []; state.batchIds = [];
}, },
stagingAreaInitialized: ( stagingAreaInitialized: (
@ -879,8 +874,7 @@ export const {
setShouldAntialias, setShouldAntialias,
canvasResized, canvasResized,
canvasBatchIdAdded, canvasBatchIdAdded,
canvasSessionIdAdded, canvasBatchIdsReset,
canvasBatchesAndSessionsReset,
} = canvasSlice.actions; } = canvasSlice.actions;
export default canvasSlice.reducer; export default canvasSlice.reducer;

View File

@ -166,7 +166,6 @@ export interface CanvasState {
tool: CanvasTool; tool: CanvasTool;
generationMode?: GenerationMode; generationMode?: GenerationMode;
batchIds: string[]; batchIds: string[];
sessionIds: string[];
} }
export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'; export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';

View File

@ -113,6 +113,7 @@ export const systemSlice = createSlice({
order, order,
progress_image, progress_image,
graph_execution_state_id: session_id, graph_execution_state_id: session_id,
queue_batch_id: batch_id,
} = action.payload.data; } = action.payload.data;
state.denoiseProgress = { state.denoiseProgress = {
@ -122,6 +123,7 @@ export const systemSlice = createSlice({
percentage: calculateStepPercentage(step, total_steps, order), percentage: calculateStepPercentage(step, total_steps, order),
progress_image, progress_image,
session_id, session_id,
batch_id,
}; };
state.status = 'PROCESSING'; state.status = 'PROCESSING';

View File

@ -12,6 +12,7 @@ export type SystemStatus =
export type DenoiseProgress = { export type DenoiseProgress = {
session_id: string; session_id: string;
batch_id: string;
progress_image: ProgressImage | null | undefined; progress_image: ProgressImage | null | undefined;
step: number; step: number;
total_steps: number; total_steps: number;

View File

@ -34,7 +34,8 @@ export type BaseNode = {
export type ModelLoadStartedEvent = { export type ModelLoadStartedEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
model_name: string; model_name: string;
base_model: BaseModelType; base_model: BaseModelType;
@ -44,7 +45,8 @@ export type ModelLoadStartedEvent = {
export type ModelLoadCompletedEvent = { export type ModelLoadCompletedEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
model_name: string; model_name: string;
base_model: BaseModelType; base_model: BaseModelType;
@ -62,7 +64,8 @@ export type ModelLoadCompletedEvent = {
*/ */
export type GeneratorProgressEvent = { export type GeneratorProgressEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
node_id: string; node_id: string;
source_node_id: string; source_node_id: string;
@ -81,7 +84,8 @@ export type GeneratorProgressEvent = {
*/ */
export type InvocationCompleteEvent = { export type InvocationCompleteEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
node: BaseNode; node: BaseNode;
source_node_id: string; source_node_id: string;
@ -95,7 +99,8 @@ export type InvocationCompleteEvent = {
*/ */
export type InvocationErrorEvent = { export type InvocationErrorEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
node: BaseNode; node: BaseNode;
source_node_id: string; source_node_id: string;
@ -110,7 +115,8 @@ export type InvocationErrorEvent = {
*/ */
export type InvocationStartedEvent = { export type InvocationStartedEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
node: BaseNode; node: BaseNode;
source_node_id: string; source_node_id: string;
@ -123,7 +129,8 @@ export type InvocationStartedEvent = {
*/ */
export type GraphExecutionStateCompleteEvent = { export type GraphExecutionStateCompleteEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
}; };
@ -134,7 +141,8 @@ export type GraphExecutionStateCompleteEvent = {
*/ */
export type SessionRetrievalErrorEvent = { export type SessionRetrievalErrorEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
error_type: string; error_type: string;
error: string; error: string;
@ -147,7 +155,8 @@ export type SessionRetrievalErrorEvent = {
*/ */
export type InvocationRetrievalErrorEvent = { export type InvocationRetrievalErrorEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
node_id: string; node_id: string;
error_type: string; error_type: string;
@ -161,8 +170,8 @@ export type InvocationRetrievalErrorEvent = {
*/ */
export type QueueItemStatusChangedEvent = { export type QueueItemStatusChangedEvent = {
queue_id: string; queue_id: string;
queue_item_id: string; queue_item_id: number;
batch_id: string; queue_batch_id: string;
session_id: string; session_id: string;
graph_execution_state_id: string; graph_execution_state_id: string;
status: components['schemas']['SessionQueueItemDTO']['status']; status: components['schemas']['SessionQueueItemDTO']['status'];

View File

@ -75,7 +75,13 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
print(f"invoking {n.id}: {type(n)}") print(f"invoking {n.id}: {type(n)}")
o = n.invoke( o = n.invoke(
InvocationContext(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, services=services, graph_execution_state_id="1") InvocationContext(
queue_batch_id="1",
queue_item_id=1,
queue_id=DEFAULT_QUEUE_ID,
services=services,
graph_execution_state_id="1",
)
) )
g.complete(n.id, o) g.complete(n.id, o)

View File

@ -102,7 +102,12 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") # @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph): def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph) g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g) invocation_id = mock_invoker.invoke(
session_queue_batch_id="1",
session_queue_item_id=1,
session_queue_id=DEFAULT_QUEUE_ID,
graph_execution_state=g,
)
assert invocation_id is not None assert invocation_id is not None
def has_executed_any(g: GraphExecutionState): def has_executed_any(g: GraphExecutionState):
@ -120,7 +125,11 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
def test_can_invoke_all(mock_invoker: Invoker, simple_graph): def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph) g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke( invocation_id = mock_invoker.invoke(
queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True session_queue_batch_id="1",
session_queue_item_id=1,
session_queue_id=DEFAULT_QUEUE_ID,
graph_execution_state=g,
invoke_all=True,
) )
assert invocation_id is not None assert invocation_id is not None
@ -140,7 +149,13 @@ def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state() g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id="1")) g.graph.add_node(ErrorInvocation(id="1"))
mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True) mock_invoker.invoke(
session_queue_batch_id="1",
session_queue_item_id=1,
session_queue_id=DEFAULT_QUEUE_ID,
graph_execution_state=g,
invoke_all=True,
)
def has_executed_all(g: GraphExecutionState): def has_executed_all(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)