diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 24bfa6aac5..97bd29ff17 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -425,12 +425,21 @@ class InvocationContext: graph_execution_state_id: str queue_id: str queue_item_id: int + queue_batch_id: str - def __init__(self, services: InvocationServices, queue_id: str, queue_item_id: int, graph_execution_state_id: str): + def __init__( + self, + services: InvocationServices, + queue_id: str, + queue_item_id: int, + queue_batch_id: str, + graph_execution_state_id: str, + ): self.services = services self.graph_execution_state_id = graph_execution_state_id self.queue_id = queue_id self.queue_item_id = queue_item_id + self.queue_batch_id = queue_batch_id class BaseInvocationOutput(BaseModel): diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 260ff4b173..3b36ffb917 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -30,6 +30,7 @@ class EventServiceBase: self, queue_id: str, queue_item_id: int, + queue_batch_id: str, graph_execution_state_id: str, node: dict, source_node_id: str, @@ -44,6 +45,7 @@ class EventServiceBase: payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, node_id=node.get("id"), source_node_id=source_node_id, @@ -58,6 +60,7 @@ class EventServiceBase: self, queue_id: str, queue_item_id: int, + queue_batch_id: str, graph_execution_state_id: str, result: dict, node: dict, @@ -69,6 +72,7 @@ class EventServiceBase: payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, @@ -80,6 +84,7 @@ class EventServiceBase: self, queue_id: str, queue_item_id: int, + queue_batch_id: str, graph_execution_state_id: str, node: dict, source_node_id: str, @@ -92,6 +97,7 @@ class EventServiceBase: payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, @@ -101,7 +107,13 @@ class EventServiceBase: ) def emit_invocation_started( - self, queue_id: str, queue_item_id: int, graph_execution_state_id: str, node: dict, source_node_id: str + self, + queue_id: str, + queue_item_id: int, + queue_batch_id: str, + graph_execution_state_id: str, + node: dict, + source_node_id: str, ) -> None: """Emitted when an invocation has started""" self.__emit_queue_event( @@ -109,19 +121,23 @@ class EventServiceBase: payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, ), ) - def emit_graph_execution_complete(self, queue_id: str, queue_item_id: int, graph_execution_state_id: str) -> None: + def emit_graph_execution_complete( + self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str + ) -> None: """Emitted when a session has completed all invocations""" self.__emit_queue_event( event_name="graph_execution_state_complete", payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, ), ) @@ -130,6 +146,7 @@ class EventServiceBase: self, queue_id: str, queue_item_id: int, + queue_batch_id: str, graph_execution_state_id: str, model_name: str, base_model: BaseModelType, @@ -142,6 +159,7 @@ class EventServiceBase: payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, model_name=model_name, base_model=base_model, @@ -154,6 +172,7 @@ class EventServiceBase: self, queue_id: str, queue_item_id: int, + queue_batch_id: str, graph_execution_state_id: str, model_name: str, base_model: BaseModelType, @@ -167,6 +186,7 @@ class EventServiceBase: payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, model_name=model_name, base_model=base_model, @@ -182,6 +202,7 @@ class EventServiceBase: self, queue_id: str, queue_item_id: int, + queue_batch_id: str, graph_execution_state_id: str, error_type: str, error: str, @@ -192,6 +213,7 @@ class EventServiceBase: payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, error_type=error_type, error=error, @@ -202,6 +224,7 @@ class EventServiceBase: self, queue_id: str, queue_item_id: int, + queue_batch_id: str, graph_execution_state_id: str, node_id: str, error_type: str, @@ -213,6 +236,7 @@ class EventServiceBase: payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, node_id=node_id, error_type=error_type, @@ -224,6 +248,7 @@ class EventServiceBase: self, queue_id: str, queue_item_id: int, + queue_batch_id: str, graph_execution_state_id: str, ) -> None: """Emitted when a session is canceled""" @@ -232,6 +257,7 @@ class EventServiceBase: payload=dict( queue_id=queue_id, queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, ), ) diff --git a/invokeai/app/services/invocation_queue.py b/invokeai/app/services/invocation_queue.py index 0819d90748..378a9d12cf 100644 --- a/invokeai/app/services/invocation_queue.py +++ b/invokeai/app/services/invocation_queue.py @@ -15,6 +15,9 @@ class InvocationQueueItem(BaseModel): session_queue_item_id: int = Field( description="The ID of session queue item from which this invocation queue item came" ) + session_queue_batch_id: str = Field( + description="The ID of the session batch from which this invocation queue item came" + ) invoke_all: bool = Field(default=False) timestamp: float = Field(default_factory=time.time) diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py index 1466fd0142..0c98fc285c 100644 --- a/invokeai/app/services/invoker.py +++ b/invokeai/app/services/invoker.py @@ -18,7 +18,12 @@ class Invoker: self._start() def invoke( - self, queue_id: str, queue_item_id: int, graph_execution_state: GraphExecutionState, invoke_all: bool = False + self, + session_queue_id: str, + session_queue_item_id: int, + session_queue_batch_id: str, + graph_execution_state: GraphExecutionState, + invoke_all: bool = False, ) -> Optional[str]: """Determines the next node to invoke and enqueues it, preparing if needed. Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" @@ -34,8 +39,9 @@ class Invoker: # Queue the invocation self.services.queue.put( InvocationQueueItem( - session_queue_item_id=queue_item_id, - session_queue_id=queue_id, + session_queue_id=session_queue_id, + session_queue_item_id=session_queue_item_id, + session_queue_batch_id=session_queue_batch_id, graph_execution_state_id=graph_execution_state.id, invocation_id=invocation.id, invoke_all=invoke_all, diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index d7d274aec0..143fa8f357 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -539,6 +539,7 @@ class ModelManagerService(ModelManagerServiceBase): context.services.events.emit_model_load_completed( queue_id=context.queue_id, queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, graph_execution_state_id=context.graph_execution_state_id, model_name=model_name, base_model=base_model, @@ -550,6 +551,7 @@ class ModelManagerService(ModelManagerServiceBase): context.services.events.emit_model_load_started( queue_id=context.queue_id, queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, graph_execution_state_id=context.graph_execution_state_id, model_name=model_name, base_model=base_model, diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index 54531be85c..b4c894c52d 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -57,6 +57,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): except Exception as e: self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) self.__invoker.services.events.emit_session_retrieval_error( + queue_batch_id=queue_item.session_queue_batch_id, queue_item_id=queue_item.session_queue_item_id, queue_id=queue_item.session_queue_id, graph_execution_state_id=queue_item.graph_execution_state_id, @@ -70,6 +71,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): except Exception as e: self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) self.__invoker.services.events.emit_invocation_retrieval_error( + queue_batch_id=queue_item.session_queue_batch_id, queue_item_id=queue_item.session_queue_item_id, queue_id=queue_item.session_queue_id, graph_execution_state_id=queue_item.graph_execution_state_id, @@ -84,6 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # Send starting event self.__invoker.services.events.emit_invocation_started( + queue_batch_id=queue_item.session_queue_batch_id, queue_item_id=queue_item.session_queue_item_id, queue_id=queue_item.session_queue_id, graph_execution_state_id=graph_execution_state.id, @@ -106,6 +109,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): graph_execution_state_id=graph_execution_state.id, queue_item_id=queue_item.session_queue_item_id, queue_id=queue_item.session_queue_id, + queue_batch_id=queue_item.session_queue_batch_id, ) ) @@ -121,6 +125,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # Send complete event self.__invoker.services.events.emit_invocation_complete( + queue_batch_id=queue_item.session_queue_batch_id, queue_item_id=queue_item.session_queue_item_id, queue_id=queue_item.session_queue_id, graph_execution_state_id=graph_execution_state.id, @@ -150,6 +155,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): self.__invoker.services.logger.error("Error while invoking:\n%s" % e) # Send error event self.__invoker.services.events.emit_invocation_error( + queue_batch_id=queue_item.session_queue_batch_id, queue_item_id=queue_item.session_queue_item_id, queue_id=queue_item.session_queue_id, graph_execution_state_id=graph_execution_state.id, @@ -170,14 +176,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC): if queue_item.invoke_all and not is_complete: try: self.__invoker.invoke( - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, + session_queue_batch_id=queue_item.session_queue_batch_id, + session_queue_item_id=queue_item.session_queue_item_id, + session_queue_id=queue_item.session_queue_id, graph_execution_state=graph_execution_state, invoke_all=True, ) except Exception as e: self.__invoker.services.logger.error("Error while invoking:\n%s" % e) self.__invoker.services.events.emit_invocation_error( + queue_batch_id=queue_item.session_queue_batch_id, queue_item_id=queue_item.session_queue_item_id, queue_id=queue_item.session_queue_id, graph_execution_state_id=graph_execution_state.id, @@ -188,6 +196,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): ) elif is_complete: self.__invoker.services.events.emit_graph_execution_complete( + queue_batch_id=queue_item.session_queue_batch_id, queue_item_id=queue_item.session_queue_item_id, queue_id=queue_item.session_queue_id, graph_execution_state_id=graph_execution_state.id, diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 92b68aeae7..b682c7e56c 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -102,8 +102,9 @@ class DefaultSessionProcessor(SessionProcessorBase): self.__queue_item = queue_item self.__invoker.services.graph_execution_manager.set(queue_item.session) self.__invoker.invoke( - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, + session_queue_batch_id=queue_item.batch_id, + session_queue_id=queue_item.queue_id, + session_queue_item_id=queue_item.item_id, graph_execution_state=queue_item.session, invoke_all=True, ) diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 4925170c48..e1701aa288 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -562,6 +562,7 @@ class SqliteSessionQueue(SessionQueueBase): self.__invoker.services.events.emit_session_canceled( queue_item_id=queue_item.item_id, queue_id=queue_item.queue_id, + queue_batch_id=queue_item.batch_id, graph_execution_state_id=queue_item.session_id, ) self.__invoker.services.events.emit_queue_item_status_changed(queue_item) @@ -604,6 +605,7 @@ class SqliteSessionQueue(SessionQueueBase): self.__invoker.services.events.emit_session_canceled( queue_item_id=current_queue_item.item_id, queue_id=current_queue_item.queue_id, + queue_batch_id=current_queue_item.batch_id, graph_execution_state_id=current_queue_item.session_id, ) self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item) @@ -649,6 +651,7 @@ class SqliteSessionQueue(SessionQueueBase): self.__invoker.services.events.emit_session_canceled( queue_item_id=current_queue_item.item_id, queue_id=current_queue_item.queue_id, + queue_batch_id=current_queue_item.batch_id, graph_execution_state_id=current_queue_item.session_id, ) self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item) diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 8edcc11f05..6d4a857491 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -112,6 +112,7 @@ def stable_diffusion_step_callback( context.services.events.emit_generator_progress( queue_id=context.queue_id, queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, graph_execution_state_id=context.graph_execution_state_id, node=node, source_node_id=source_node_id, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts index 4ee1bdb15a..d302d50255 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -1,7 +1,7 @@ import { isAnyOf } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; import { - canvasBatchesAndSessionsReset, + canvasBatchIdsReset, commitStagingAreaImage, discardStagedImages, } from 'features/canvas/store/canvasSlice'; @@ -38,7 +38,7 @@ export const addCommitStagingAreaImageListener = () => { }) ); } - dispatch(canvasBatchesAndSessionsReset()); + dispatch(canvasBatchIdsReset()); } catch { log.error('Failed to cancel canvas batches'); dispatch( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index d326a122c9..e2c97100af 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -30,7 +30,7 @@ export const addInvocationCompleteEventListener = () => { `Invocation complete (${action.payload.data.node.type})` ); - const { result, node, graph_execution_state_id } = data; + const { result, node, queue_batch_id } = data; // This complete event has an associated image output if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { @@ -43,7 +43,7 @@ export const addInvocationCompleteEventListener = () => { // Add canvas images to the staging area if ( - canvas.sessionIds.includes(graph_execution_state_id) && + canvas.batchIds.includes(queue_batch_id) && [CANVAS_OUTPUT].includes(data.source_node_id) ) { dispatch(addImageToStagingArea(imageDTO)); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts index a56fb4cc76..241e1da92c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts @@ -1,5 +1,4 @@ import { logger } from 'app/logging/logger'; -import { canvasSessionIdAdded } from 'features/canvas/store/canvasSlice'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { appSocketQueueItemStatusChanged, @@ -10,12 +9,11 @@ import { startAppListening } from '../..'; export const addSocketQueueItemStatusChangedEventListener = () => { startAppListening({ actionCreator: socketQueueItemStatusChanged, - effect: (action, { dispatch, getState }) => { + effect: (action, { dispatch }) => { const log = logger('socketio'); const { queue_item_id: item_id, - batch_id, - graph_execution_state_id, + queue_batch_id, status, } = action.payload.data; log.debug( @@ -36,11 +34,6 @@ export const addSocketQueueItemStatusChangedEventListener = () => { }) ); - const state = getState(); - if (state.canvas.batchIds.includes(batch_id)) { - dispatch(canvasSessionIdAdded(graph_execution_state_id)); - } - dispatch( queueApi.util.invalidateTags([ 'CurrentSessionQueueItem', @@ -48,7 +41,7 @@ export const addSocketQueueItemStatusChangedEventListener = () => { 'SessionQueueStatus', { type: 'SessionQueueItem', id: item_id }, { type: 'SessionQueueItemDTO', id: item_id }, - { type: 'BatchStatus', id: batch_id }, + { type: 'BatchStatus', id: queue_batch_id }, ]) ); }, diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx index 4a29e12859..0febf7fb21 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx @@ -11,12 +11,12 @@ const selector = createSelector( ({ system, canvas }) => { const { denoiseProgress } = system; const { boundingBox } = canvas.layerState.stagingArea; - const { sessionIds } = canvas; + const { batchIds } = canvas; return { boundingBox, progressImage: - denoiseProgress && sessionIds.includes(denoiseProgress.session_id) + denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.progress_image : undefined, }; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index 66b87a84e4..b726e757f6 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -85,7 +85,6 @@ export const initialCanvasState: CanvasState = { stageDimensions: { width: 0, height: 0 }, stageScale: 1, tool: 'brush', - sessionIds: [], batchIds: [], }; @@ -302,11 +301,7 @@ export const canvasSlice = createSlice({ canvasBatchIdAdded: (state, action: PayloadAction) => { state.batchIds.push(action.payload); }, - canvasSessionIdAdded: (state, action: PayloadAction) => { - state.sessionIds.push(action.payload); - }, - canvasBatchesAndSessionsReset: (state) => { - state.sessionIds = []; + canvasBatchIdsReset: (state) => { state.batchIds = []; }, stagingAreaInitialized: ( @@ -879,8 +874,7 @@ export const { setShouldAntialias, canvasResized, canvasBatchIdAdded, - canvasSessionIdAdded, - canvasBatchesAndSessionsReset, + canvasBatchIdsReset, } = canvasSlice.actions; export default canvasSlice.reducer; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts index 233d38dc80..875157d36a 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts @@ -166,7 +166,6 @@ export interface CanvasState { tool: CanvasTool; generationMode?: GenerationMode; batchIds: string[]; - sessionIds: string[]; } export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 78f1ce6e70..40e8c42145 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -113,6 +113,7 @@ export const systemSlice = createSlice({ order, progress_image, graph_execution_state_id: session_id, + queue_batch_id: batch_id, } = action.payload.data; state.denoiseProgress = { @@ -122,6 +123,7 @@ export const systemSlice = createSlice({ percentage: calculateStepPercentage(step, total_steps, order), progress_image, session_id, + batch_id, }; state.status = 'PROCESSING'; diff --git a/invokeai/frontend/web/src/features/system/store/types.ts b/invokeai/frontend/web/src/features/system/store/types.ts index 05b86f1e79..b81e292e36 100644 --- a/invokeai/frontend/web/src/features/system/store/types.ts +++ b/invokeai/frontend/web/src/features/system/store/types.ts @@ -12,6 +12,7 @@ export type SystemStatus = export type DenoiseProgress = { session_id: string; + batch_id: string; progress_image: ProgressImage | null | undefined; step: number; total_steps: number; diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 8087b88077..47a3d83eba 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -34,7 +34,8 @@ export type BaseNode = { export type ModelLoadStartedEvent = { queue_id: string; - queue_item_id: string; + queue_item_id: number; + queue_batch_id: string; graph_execution_state_id: string; model_name: string; base_model: BaseModelType; @@ -44,7 +45,8 @@ export type ModelLoadStartedEvent = { export type ModelLoadCompletedEvent = { queue_id: string; - queue_item_id: string; + queue_item_id: number; + queue_batch_id: string; graph_execution_state_id: string; model_name: string; base_model: BaseModelType; @@ -62,7 +64,8 @@ export type ModelLoadCompletedEvent = { */ export type GeneratorProgressEvent = { queue_id: string; - queue_item_id: string; + queue_item_id: number; + queue_batch_id: string; graph_execution_state_id: string; node_id: string; source_node_id: string; @@ -81,7 +84,8 @@ export type GeneratorProgressEvent = { */ export type InvocationCompleteEvent = { queue_id: string; - queue_item_id: string; + queue_item_id: number; + queue_batch_id: string; graph_execution_state_id: string; node: BaseNode; source_node_id: string; @@ -95,7 +99,8 @@ export type InvocationCompleteEvent = { */ export type InvocationErrorEvent = { queue_id: string; - queue_item_id: string; + queue_item_id: number; + queue_batch_id: string; graph_execution_state_id: string; node: BaseNode; source_node_id: string; @@ -110,7 +115,8 @@ export type InvocationErrorEvent = { */ export type InvocationStartedEvent = { queue_id: string; - queue_item_id: string; + queue_item_id: number; + queue_batch_id: string; graph_execution_state_id: string; node: BaseNode; source_node_id: string; @@ -123,7 +129,8 @@ export type InvocationStartedEvent = { */ export type GraphExecutionStateCompleteEvent = { queue_id: string; - queue_item_id: string; + queue_item_id: number; + queue_batch_id: string; graph_execution_state_id: string; }; @@ -134,7 +141,8 @@ export type GraphExecutionStateCompleteEvent = { */ export type SessionRetrievalErrorEvent = { queue_id: string; - queue_item_id: string; + queue_item_id: number; + queue_batch_id: string; graph_execution_state_id: string; error_type: string; error: string; @@ -147,7 +155,8 @@ export type SessionRetrievalErrorEvent = { */ export type InvocationRetrievalErrorEvent = { queue_id: string; - queue_item_id: string; + queue_item_id: number; + queue_batch_id: string; graph_execution_state_id: string; node_id: string; error_type: string; @@ -161,8 +170,8 @@ export type InvocationRetrievalErrorEvent = { */ export type QueueItemStatusChangedEvent = { queue_id: string; - queue_item_id: string; - batch_id: string; + queue_item_id: number; + queue_batch_id: string; session_id: string; graph_execution_state_id: string; status: components['schemas']['SessionQueueItemDTO']['status']; diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 41ca93551a..9009140134 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -75,7 +75,13 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B print(f"invoking {n.id}: {type(n)}") o = n.invoke( - InvocationContext(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, services=services, graph_execution_state_id="1") + InvocationContext( + queue_batch_id="1", + queue_item_id=1, + queue_id=DEFAULT_QUEUE_ID, + services=services, + graph_execution_state_id="1", + ) ) g.complete(n.id, o) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 7dc5cf57b3..119ac70498 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -102,7 +102,12 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): # @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") def test_can_invoke(mock_invoker: Invoker, simple_graph): g = mock_invoker.create_execution_state(graph=simple_graph) - invocation_id = mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g) + invocation_id = mock_invoker.invoke( + session_queue_batch_id="1", + session_queue_item_id=1, + session_queue_id=DEFAULT_QUEUE_ID, + graph_execution_state=g, + ) assert invocation_id is not None def has_executed_any(g: GraphExecutionState): @@ -120,7 +125,11 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph): def test_can_invoke_all(mock_invoker: Invoker, simple_graph): g = mock_invoker.create_execution_state(graph=simple_graph) invocation_id = mock_invoker.invoke( - queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True + session_queue_batch_id="1", + session_queue_item_id=1, + session_queue_id=DEFAULT_QUEUE_ID, + graph_execution_state=g, + invoke_all=True, ) assert invocation_id is not None @@ -140,7 +149,13 @@ def test_handles_errors(mock_invoker: Invoker): g = mock_invoker.create_execution_state() g.graph.add_node(ErrorInvocation(id="1")) - mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True) + mock_invoker.invoke( + session_queue_batch_id="1", + session_queue_item_id=1, + session_queue_id=DEFAULT_QUEUE_ID, + graph_execution_state=g, + invoke_all=True, + ) def has_executed_all(g: GraphExecutionState): g = mock_invoker.services.graph_execution_manager.get(g.id)