From a2de5c9963483ce87ca5aad69c6ecc21cbf6673d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 25 May 2023 23:47:57 +1000 Subject: [PATCH] feat(ui): change intermediates handling - Update the canvas graph generation to flag its uploaded init and mask images as `intermediate`. - During canvas setup, hit the update route to associate the uploaded images with the session id. - Organize the socketio and RTK listener middlware better. Needed to facilitate the updated canvas logic. - Add a new action `sessionReadyToInvoke`. The `sessionInvoked` action is *only* ever run in response to this event. This lets us do whatever complicated setup (eg canvas) and explicitly invoking. Previously, invoking was tied to the socket subscribe events. - Some minor tidying. --- .../middleware/listenerMiddleware/index.ts | 25 ++- .../listeners/imageUploaded.ts | 2 +- .../listeners/sessionReadyToInvoke.ts | 19 ++ .../listeners/socketio/generatorProgress.ts | 28 +++ .../socketio/graphExecutionStateComplete.ts | 17 ++ .../{ => socketio}/invocationComplete.ts | 41 +++-- .../listeners/socketio/invocationError.ts | 17 ++ .../listeners/socketio/invocationStarted.ts | 28 +++ .../listeners/socketio/socketConnected.ts | 43 +++++ .../listeners/socketio/socketDisconnected.ts | 14 ++ .../listeners/socketio/socketSubscribed.ts | 17 ++ .../listeners/socketio/socketUnsubscribed.ts | 17 ++ .../listeners/userInvokedCanvas.ts | 172 +++++++++++------- .../listeners/userInvokedImageToImage.ts | 7 +- .../listeners/userInvokedNodes.ts | 7 +- .../listeners/userInvokedTextToImage.ts | 9 +- invokeai/frontend/web/src/app/store/store.ts | 2 + .../src/features/nodes/util/parseSchema.ts | 6 +- .../web/src/features/system/store/actions.ts | 3 + .../src/features/system/store/sessionSlice.ts | 62 +++++++ .../web/src/services/events/middleware.ts | 22 +-- .../services/events/util/setEventListeners.ts | 85 +-------- .../web/src/services/thunks/gallery.ts | 10 +- .../frontend/web/src/services/thunks/image.ts | 16 ++ .../web/src/services/thunks/session.ts | 45 +++++ 25 files changed, 529 insertions(+), 185 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/{ => socketio}/invocationComplete.ts (59%) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts create mode 100644 invokeai/frontend/web/src/features/system/store/actions.ts create mode 100644 invokeai/frontend/web/src/features/system/store/sessionSlice.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index f23e83a191..c04a3943f3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -8,7 +8,6 @@ import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit'; import type { RootState, AppDispatch } from '../../store'; import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; -import { addImageResultReceivedListener } from './listeners/invocationComplete'; import { addImageUploadedListener } from './listeners/imageUploaded'; import { addRequestedImageDeletionListener } from './listeners/imageDeleted'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; @@ -19,6 +18,16 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasMergedListener } from './listeners/canvasMerged'; +import { addGeneratorProgressListener } from './listeners/socketio/generatorProgress'; +import { addGraphExecutionStateCompleteListener } from './listeners/socketio/graphExecutionStateComplete'; +import { addInvocationCompleteListener } from './listeners/socketio/invocationComplete'; +import { addInvocationErrorListener } from './listeners/socketio/invocationError'; +import { addInvocationStartedListener } from './listeners/socketio/invocationStarted'; +import { addSocketConnectedListener } from './listeners/socketio/socketConnected'; +import { addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; +import { addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; +import { addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed'; +import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke'; export const listenerMiddleware = createListenerMiddleware(); @@ -40,15 +49,27 @@ export type AppListenerEffect = ListenerEffect< addImageUploadedListener(); addInitialImageSelectedListener(); -addImageResultReceivedListener(); addRequestedImageDeletionListener(); addUserInvokedCanvasListener(); addUserInvokedNodesListener(); addUserInvokedTextToImageListener(); addUserInvokedImageToImageListener(); +addSessionReadyToInvokeListener(); addCanvasSavedToGalleryListener(); addCanvasDownloadedAsImageListener(); addCanvasCopiedToClipboardListener(); addCanvasMergedListener(); + +// socketio + +addGeneratorProgressListener(); +addGraphExecutionStateCompleteListener(); +addInvocationCompleteListener(); +addInvocationErrorListener(); +addInvocationStartedListener(); +addSocketConnectedListener(); +addSocketDisconnectedListener(); +addSocketSubscribedListener(); +addSocketUnsubscribedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index 1d66166c12..b37cd3d139 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -12,7 +12,7 @@ export const addImageUploadedListener = () => { startAppListening({ predicate: (action): action is ReturnType => imageUploaded.fulfilled.match(action) && - action.payload.response.image_type !== 'intermediates', + action.payload.response.is_intermediate === false, effect: (action, { dispatch, getState }) => { const { response: image } = action.payload; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts new file mode 100644 index 0000000000..eb65017a25 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts @@ -0,0 +1,19 @@ +import { startAppListening } from '..'; +import { sessionInvoked } from 'services/thunks/session'; +import { log } from 'app/logging/useLogger'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; + +const moduleLog = log.child({ namespace: 'invoke' }); + +export const addSessionReadyToInvokeListener = () => { + startAppListening({ + actionCreator: sessionReadyToInvoke, + effect: (action, { getState, dispatch }) => { + const { sessionId } = getState().system; + if (sessionId) { + moduleLog.info({ sessionId }, `Session invoked (${sessionId})})`); + dispatch(sessionInvoked({ sessionId })); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts new file mode 100644 index 0000000000..341b5e46d3 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts @@ -0,0 +1,28 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { generatorProgress } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addGeneratorProgressListener = () => { + startAppListening({ + actionCreator: generatorProgress, + effect: (action, { dispatch, getState }) => { + if ( + getState().system.canceledSession === + action.payload.data.graph_execution_state_id + ) { + moduleLog.trace( + action.payload, + 'Ignored generator progress for canceled session' + ); + return; + } + + moduleLog.trace( + action.payload, + `Generator progress (${action.payload.data.node.type})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts new file mode 100644 index 0000000000..c8ac46f6f1 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts @@ -0,0 +1,17 @@ +import { log } from 'app/logging/useLogger'; +import { graphExecutionStateComplete } from 'services/events/actions'; +import { startAppListening } from '../..'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addGraphExecutionStateCompleteListener = () => { + startAppListening({ + actionCreator: graphExecutionStateComplete, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Graph execution state complete (${action.payload.data.graph_execution_state_id})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts similarity index 59% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts index 0222eea93c..76ae46c4a2 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts @@ -1,40 +1,49 @@ +import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; import { invocationComplete } from 'services/events/actions'; -import { isImageOutput } from 'services/types/guards'; import { imageMetadataReceived, imageUrlsReceived, } from 'services/thunks/image'; -import { startAppListening } from '..'; -import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; +import { sessionCanceled } from 'services/thunks/session'; +import { isImageOutput } from 'services/types/guards'; +const moduleLog = log.child({ namespace: 'socketio' }); const nodeDenylist = ['dataURL_image']; -export const addImageResultReceivedListener = () => { +export const addInvocationCompleteListener = () => { startAppListening({ - predicate: (action) => { - if ( - invocationComplete.match(action) && - isImageOutput(action.payload.data.result) - ) { - return true; - } - return false; - }, - effect: async (action, { getState, dispatch, take }) => { - if (!invocationComplete.match(action)) { - return; + actionCreator: invocationComplete, + effect: async (action, { dispatch, getState, take }) => { + moduleLog.info( + action.payload, + `Invocation complete (${action.payload.data.node.type})` + ); + + const sessionId = action.payload.data.graph_execution_state_id; + + const { cancelType, isCancelScheduled } = getState().system; + + // Handle scheduled cancelation + if (cancelType === 'scheduled' && isCancelScheduled) { + dispatch(sessionCanceled({ sessionId })); } const { data } = action.payload; const { result, node, graph_execution_state_id } = data; + // This complete event has an associated image output if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { const { image_name, image_type } = result.image; + // Get its URLS + // TODO: is this extraneous? I think so... dispatch( imageUrlsReceived({ imageName: image_name, imageType: image_type }) ); + // Get its metadata dispatch( imageMetadataReceived({ imageName: image_name, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts new file mode 100644 index 0000000000..d0e4d975be --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts @@ -0,0 +1,17 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { invocationError } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addInvocationErrorListener = () => { + startAppListening({ + actionCreator: invocationError, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Invocation error (${action.payload.data.node.type})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts new file mode 100644 index 0000000000..373802fa16 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts @@ -0,0 +1,28 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { invocationStarted } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addInvocationStartedListener = () => { + startAppListening({ + actionCreator: invocationStarted, + effect: (action, { dispatch, getState }) => { + if ( + getState().system.canceledSession === + action.payload.data.graph_execution_state_id + ) { + moduleLog.trace( + action.payload, + 'Ignored invocation started for canceled session' + ); + return; + } + + moduleLog.info( + action.payload, + `Invocation started (${action.payload.data.node.type})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts new file mode 100644 index 0000000000..bc9ecbec1e --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -0,0 +1,43 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { socketConnected } from 'services/events/actions'; +import { + receivedResultImagesPage, + receivedUploadImagesPage, +} from 'services/thunks/gallery'; +import { receivedModels } from 'services/thunks/model'; +import { receivedOpenAPISchema } from 'services/thunks/schema'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketConnectedListener = () => { + startAppListening({ + actionCreator: socketConnected, + effect: (action, { dispatch, getState }) => { + const { timestamp } = action.payload; + + moduleLog.debug({ timestamp }, 'Connected'); + + const { results, uploads, models, nodes, config } = getState(); + + const { disabledTabs } = config; + + // These thunks need to be dispatch in middleware; cannot handle in a reducer + if (!results.ids.length) { + dispatch(receivedResultImagesPage()); + } + + if (!uploads.ids.length) { + dispatch(receivedUploadImagesPage()); + } + + if (!models.ids.length) { + dispatch(receivedModels()); + } + + if (!nodes.schema && !disabledTabs.includes('nodes')) { + dispatch(receivedOpenAPISchema()); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts new file mode 100644 index 0000000000..131c3ba18f --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts @@ -0,0 +1,14 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { socketDisconnected } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketDisconnectedListener = () => { + startAppListening({ + actionCreator: socketDisconnected, + effect: (action, { dispatch, getState }) => { + moduleLog.debug(action.payload, 'Disconnected'); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts new file mode 100644 index 0000000000..400f8a1689 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts @@ -0,0 +1,17 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { socketSubscribed } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketSubscribedListener = () => { + startAppListening({ + actionCreator: socketSubscribed, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Subscribed (${action.payload.sessionId}))` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts new file mode 100644 index 0000000000..af15c55d42 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts @@ -0,0 +1,17 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { socketUnsubscribed } from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketUnsubscribedListener = () => { + startAppListening({ + actionCreator: socketUnsubscribed, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Unsubscribed (${action.payload.sessionId})` + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index 2ebd3684e9..b90197f7cb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -1,9 +1,9 @@ import { startAppListening } from '..'; -import { sessionCreated, sessionInvoked } from 'services/thunks/session'; +import { nodeUpdated, sessionCreated } from 'services/thunks/session'; import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; import { log } from 'app/logging/useLogger'; import { canvasGraphBuilt } from 'features/nodes/store/actions'; -import { imageUploaded } from 'services/thunks/image'; +import { imageUpdated, imageUploaded } from 'services/thunks/image'; import { v4 as uuidv4 } from 'uuid'; import { Graph } from 'services/api'; import { @@ -15,12 +15,25 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); /** - * This listener is responsible for building the canvas graph and blobs when the user invokes the canvas. - * It is also responsible for uploading the base and mask layers to the server. + * This listener is responsible invoking the canvas. This involved a number of steps: + * + * 1. Generate image blobs from the canvas layers + * 2. Determine the generation mode from the layers (txt2img, img2img, inpaint) + * 3. Build the canvas graph + * 4. Create the session + * 5. Upload the init image if necessary, then update the graph to refer to it (needs a separate request) + * 6. Upload the mask image if necessary, then update the graph to refer to it (needs a separate request) + * 7. Initialize the staging area if not yet initialized + * 8. Finally, dispatch the sessionReadyToInvoke action to invoke the session + * + * We have to do the uploads after creating the session: + * - We need to associate these particular uploads to a session, and flag them as intermediates + * - To do this, we need to associa */ export const addUserInvokedCanvasListener = () => { startAppListening({ @@ -70,63 +83,7 @@ export const addUserInvokedCanvasListener = () => { const { rangeNode, iterateNode, baseNode, edges } = graphComponents; - // Upload the base layer, to be used as init image - const baseFilename = `${uuidv4()}.png`; - - dispatch( - imageUploaded({ - imageType: 'intermediates', - formData: { - file: new File([baseBlob], baseFilename, { type: 'image/png' }), - }, - }) - ); - - if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') { - const [{ payload: basePayload }] = await take( - (action): action is ReturnType => - imageUploaded.fulfilled.match(action) && - action.meta.arg.formData.file.name === baseFilename - ); - - const { image_name: baseName, image_type: baseType } = - basePayload.response; - - baseNode.image = { - image_name: baseName, - image_type: baseType, - }; - } - - // Upload the mask layer image - const maskFilename = `${uuidv4()}.png`; - - if (baseNode.type === 'inpaint') { - dispatch( - imageUploaded({ - imageType: 'intermediates', - formData: { - file: new File([maskBlob], maskFilename, { type: 'image/png' }), - }, - }) - ); - - const [{ payload: maskPayload }] = await take( - (action): action is ReturnType => - imageUploaded.fulfilled.match(action) && - action.meta.arg.formData.file.name === maskFilename - ); - - const { image_name: maskName, image_type: maskType } = - maskPayload.response; - - baseNode.mask = { - image_name: maskName, - image_type: maskType, - }; - } - - // Assemble! + // Assemble! Note that this graph *does not have the init or mask image set yet!* const nodes: Graph['nodes'] = { [rangeNode.id]: rangeNode, [iterateNode.id]: iterateNode, @@ -136,15 +93,96 @@ export const addUserInvokedCanvasListener = () => { const graph = { nodes, edges }; dispatch(canvasGraphBuilt(graph)); + moduleLog({ data: graph }, 'Canvas graph built'); - // Actually create the session + // If we are generating img2img or inpaint, we need to upload the init images + if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') { + const baseFilename = `${uuidv4()}.png`; + dispatch( + imageUploaded({ + formData: { + file: new File([baseBlob], baseFilename, { type: 'image/png' }), + }, + isIntermediate: true, + }) + ); + + // Wait for the image to be uploaded + const [{ payload: basePayload }] = await take( + (action): action is ReturnType => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === baseFilename + ); + + // Update the base node with the image name and type + const { image_name: baseName, image_type: baseType } = + basePayload.response; + + baseNode.image = { + image_name: baseName, + image_type: baseType, + }; + } + + // For inpaint, we also need to upload the mask layer + if (baseNode.type === 'inpaint') { + const maskFilename = `${uuidv4()}.png`; + dispatch( + imageUploaded({ + formData: { + file: new File([maskBlob], maskFilename, { type: 'image/png' }), + }, + isIntermediate: true, + }) + ); + + // Wait for the mask to be uploaded + const [{ payload: maskPayload }] = await take( + (action): action is ReturnType => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === maskFilename + ); + + // Update the base node with the image name and type + const { image_name: maskName, image_type: maskType } = + maskPayload.response; + + baseNode.mask = { + image_name: maskName, + image_type: maskType, + }; + } + + // Create the session and wait for response dispatch(sessionCreated({ graph })); + const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match); + const sessionId = sessionCreatedAction.payload.id; - // Wait for the session to be invoked (this is just the HTTP request to start processing) - const [{ meta }] = await take(sessionInvoked.fulfilled.match); + // Associate the init image with the session, now that we have the session ID + if ( + (baseNode.type === 'img2img' || baseNode.type === 'inpaint') && + baseNode.image + ) { + dispatch( + imageUpdated({ + imageName: baseNode.image.image_name, + imageType: baseNode.image.image_type, + requestBody: { session_id: sessionId }, + }) + ); + } - const { sessionId } = meta.arg; + // Associate the mask image with the session, now that we have the session ID + if (baseNode.type === 'inpaint' && baseNode.mask) { + dispatch( + imageUpdated({ + imageName: baseNode.mask.image_name, + imageType: baseNode.mask.image_type, + requestBody: { session_id: sessionId }, + }) + ); + } if (!state.canvas.layerState.stagingArea.boundingBox) { dispatch( @@ -158,7 +196,11 @@ export const addUserInvokedCanvasListener = () => { ); } + // Flag the session with the canvas session ID dispatch(canvasSessionIdChanged(sessionId)); + + // We are ready to invoke the session! + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts index e747aefa08..8940237782 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts @@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session'; import { log } from 'app/logging/useLogger'; import { imageToImageGraphBuilt } from 'features/nodes/store/actions'; import { userInvoked } from 'app/store/actions'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); @@ -11,7 +12,7 @@ export const addUserInvokedImageToImageListener = () => { startAppListening({ predicate: (action): action is ReturnType => userInvoked.match(action) && action.payload === 'img2img', - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch, take }) => { const state = getState(); const graph = buildImageToImageGraph(state); @@ -19,6 +20,10 @@ export const addUserInvokedImageToImageListener = () => { moduleLog({ data: graph }, 'Image to Image graph built'); dispatch(sessionCreated({ graph })); + + await take(sessionCreated.fulfilled.match); + + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts index 01e532d5ff..45dcf7b0b2 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts @@ -4,6 +4,7 @@ import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGra import { log } from 'app/logging/useLogger'; import { nodesGraphBuilt } from 'features/nodes/store/actions'; import { userInvoked } from 'app/store/actions'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); @@ -11,7 +12,7 @@ export const addUserInvokedNodesListener = () => { startAppListening({ predicate: (action): action is ReturnType => userInvoked.match(action) && action.payload === 'nodes', - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch, take }) => { const state = getState(); const graph = buildNodesGraph(state); @@ -19,6 +20,10 @@ export const addUserInvokedNodesListener = () => { moduleLog({ data: graph }, 'Nodes graph built'); dispatch(sessionCreated({ graph })); + + await take(sessionCreated.fulfilled.match); + + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts index e3eb5d0b38..f7245b9301 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts @@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session'; import { log } from 'app/logging/useLogger'; import { textToImageGraphBuilt } from 'features/nodes/store/actions'; import { userInvoked } from 'app/store/actions'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); @@ -11,14 +12,20 @@ export const addUserInvokedTextToImageListener = () => { startAppListening({ predicate: (action): action is ReturnType => userInvoked.match(action) && action.payload === 'txt2img', - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch, take }) => { const state = getState(); const graph = buildTextToImageGraph(state); + dispatch(textToImageGraphBuilt(graph)); + moduleLog({ data: graph }, 'Text to Image graph built'); dispatch(sessionCreated({ graph })); + + await take(sessionCreated.fulfilled.match); + + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index b89615b2c0..4e9c154f3a 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -16,6 +16,7 @@ import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import systemReducer from 'features/system/store/systemSlice'; +// import sessionReducer from 'features/system/store/sessionSlice'; import configReducer from 'features/system/store/configSlice'; import uiReducer from 'features/ui/store/uiSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; @@ -46,6 +47,7 @@ const allReducers = { ui: uiReducer, uploads: uploadsReducer, hotkeys: hotkeysReducer, + // session: sessionReducer, }; const rootReducer = combineReducers(allReducers); diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index ddd19b8749..631552414d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -13,7 +13,9 @@ import { buildOutputFieldTemplates, } from './fieldTemplateBuilders'; -const invocationDenylist = ['Graph']; +const RESERVED_FIELD_NAMES = ['id', 'type', 'meta']; + +const invocationDenylist = ['Graph', 'InvocationMeta']; export const parseSchema = (openAPI: OpenAPIV3.Document) => { // filter out non-invocation schemas, plus some tricky invocations for now @@ -73,7 +75,7 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => { (inputsAccumulator, property, propertyName) => { if ( // `type` and `id` are not valid inputs/outputs - !['type', 'id'].includes(propertyName) && + !RESERVED_FIELD_NAMES.includes(propertyName) && isSchemaObject(property) ) { const field: InputFieldTemplate | undefined = diff --git a/invokeai/frontend/web/src/features/system/store/actions.ts b/invokeai/frontend/web/src/features/system/store/actions.ts new file mode 100644 index 0000000000..66181bc803 --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/actions.ts @@ -0,0 +1,3 @@ +import { createAction } from '@reduxjs/toolkit'; + +export const sessionReadyToInvoke = createAction('system/sessionReadyToInvoke'); diff --git a/invokeai/frontend/web/src/features/system/store/sessionSlice.ts b/invokeai/frontend/web/src/features/system/store/sessionSlice.ts new file mode 100644 index 0000000000..40d59c7baa --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/sessionSlice.ts @@ -0,0 +1,62 @@ +// TODO: split system slice inot this + +// import type { PayloadAction } from '@reduxjs/toolkit'; +// import { createSlice } from '@reduxjs/toolkit'; +// import { socketSubscribed, socketUnsubscribed } from 'services/events/actions'; + +// export type SessionState = { +// /** +// * The current socket session id +// */ +// sessionId: string; +// /** +// * Whether the current session is a canvas session. Needed to manage the staging area. +// */ +// isCanvasSession: boolean; +// /** +// * When a session is canceled, its ID is stored here until a new session is created. +// */ +// canceledSessionId: string; +// }; + +// export const initialSessionState: SessionState = { +// sessionId: '', +// isCanvasSession: false, +// canceledSessionId: '', +// }; + +// export const sessionSlice = createSlice({ +// name: 'session', +// initialState: initialSessionState, +// reducers: { +// sessionIdChanged: (state, action: PayloadAction) => { +// state.sessionId = action.payload; +// }, +// isCanvasSessionChanged: (state, action: PayloadAction) => { +// state.isCanvasSession = action.payload; +// }, +// }, +// extraReducers: (builder) => { +// /** +// * Socket Subscribed +// */ +// builder.addCase(socketSubscribed, (state, action) => { +// state.sessionId = action.payload.sessionId; +// state.canceledSessionId = ''; +// }); + +// /** +// * Socket Unsubscribed +// */ +// builder.addCase(socketUnsubscribed, (state) => { +// state.sessionId = ''; +// }); +// }, +// }); + +// export const { sessionIdChanged, isCanvasSessionChanged } = +// sessionSlice.actions; + +// export default sessionSlice.reducer; + +export default {}; diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts index bd1d60099a..a78e0de97b 100644 --- a/invokeai/frontend/web/src/services/events/middleware.ts +++ b/invokeai/frontend/web/src/services/events/middleware.ts @@ -8,7 +8,11 @@ import { import { socketSubscribed, socketUnsubscribed } from './actions'; import { AppThunkDispatch, RootState } from 'app/store/store'; import { getTimestamp } from 'common/util/getTimestamp'; -import { sessionInvoked, sessionCreated } from 'services/thunks/session'; +import { + sessionInvoked, + sessionCreated, + sessionWithoutGraphCreated, +} from 'services/thunks/session'; import { OpenAPI } from 'services/api'; import { setEventListeners } from 'services/events/util/setEventListeners'; import { log } from 'app/logging/useLogger'; @@ -62,17 +66,14 @@ export const socketMiddleware = () => { socket.connect(); } - if (sessionCreated.fulfilled.match(action)) { + if ( + sessionCreated.fulfilled.match(action) || + sessionWithoutGraphCreated.fulfilled.match(action) + ) { const sessionId = action.payload.id; - const sessionLog = socketioLog.child({ sessionId }); const oldSessionId = getState().system.sessionId; if (oldSessionId) { - sessionLog.debug( - { oldSessionId }, - `Unsubscribed from old session (${oldSessionId})` - ); - socket.emit('unsubscribe', { session: oldSessionId, }); @@ -85,8 +86,6 @@ export const socketMiddleware = () => { ); } - sessionLog.debug(`Subscribe to new session (${sessionId})`); - socket.emit('subscribe', { session: sessionId }); dispatch( @@ -95,9 +94,6 @@ export const socketMiddleware = () => { timestamp: getTimestamp(), }) ); - - // Finally we actually invoke the session, starting processing - dispatch(sessionInvoked({ sessionId })); } next(action); diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index 4431a9fd8b..5262b26d1e 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -1,7 +1,6 @@ import { MiddlewareAPI } from '@reduxjs/toolkit'; import { AppDispatch, RootState } from 'app/store/store'; import { getTimestamp } from 'common/util/getTimestamp'; -import { sessionCanceled } from 'services/thunks/session'; import { Socket } from 'socket.io-client'; import { generatorProgress, @@ -16,12 +15,6 @@ import { import { ClientToServerEvents, ServerToClientEvents } from '../types'; import { Logger } from 'roarr'; import { JsonObject } from 'roarr/dist/types'; -import { - receivedResultImagesPage, - receivedUploadImagesPage, -} from 'services/thunks/gallery'; -import { receivedModels } from 'services/thunks/model'; -import { receivedOpenAPISchema } from 'services/thunks/schema'; import { makeToast } from '../../../app/components/Toaster'; import { addToast } from '../../../features/system/store/systemSlice'; @@ -43,37 +36,13 @@ export const setEventListeners = (arg: SetEventListenersArg) => { dispatch(socketConnected({ timestamp: getTimestamp() })); - const { results, uploads, models, nodes, config, system } = getState(); + const { sessionId } = getState().system; - const { disabledTabs } = config; - - // These thunks need to be dispatch in middleware; cannot handle in a reducer - if (!results.ids.length) { - dispatch(receivedResultImagesPage()); - } - - if (!uploads.ids.length) { - dispatch(receivedUploadImagesPage()); - } - - if (!models.ids.length) { - dispatch(receivedModels()); - } - - if (!nodes.schema && !disabledTabs.includes('nodes')) { - dispatch(receivedOpenAPISchema()); - } - - if (system.sessionId) { - log.debug( - { sessionId: system.sessionId }, - `Subscribed to existing session (${system.sessionId})` - ); - - socket.emit('subscribe', { session: system.sessionId }); + if (sessionId) { + socket.emit('subscribe', { session: sessionId }); dispatch( socketSubscribed({ - sessionId: system.sessionId, + sessionId, timestamp: getTimestamp(), }) ); @@ -101,7 +70,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Disconnect */ socket.on('disconnect', () => { - log.debug('Disconnected'); dispatch(socketDisconnected({ timestamp: getTimestamp() })); }); @@ -109,18 +77,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Invocation started */ socket.on('invocation_started', (data) => { - if (getState().system.canceledSession === data.graph_execution_state_id) { - log.trace( - { data, sessionId: data.graph_execution_state_id }, - `Ignored invocation started (${data.node.type}) for canceled session (${data.graph_execution_state_id})` - ); - return; - } - - log.info( - { data, sessionId: data.graph_execution_state_id }, - `Invocation started (${data.node.type})` - ); dispatch(invocationStarted({ data, timestamp: getTimestamp() })); }); @@ -128,18 +84,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Generator progress */ socket.on('generator_progress', (data) => { - if (getState().system.canceledSession === data.graph_execution_state_id) { - log.trace( - { data, sessionId: data.graph_execution_state_id }, - `Ignored generator progress (${data.node.type}) for canceled session (${data.graph_execution_state_id})` - ); - return; - } - - log.trace( - { data, sessionId: data.graph_execution_state_id }, - `Generator progress (${data.node.type})` - ); dispatch(generatorProgress({ data, timestamp: getTimestamp() })); }); @@ -147,10 +91,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Invocation error */ socket.on('invocation_error', (data) => { - log.error( - { data, sessionId: data.graph_execution_state_id }, - `Invocation error (${data.node.type})` - ); dispatch(invocationError({ data, timestamp: getTimestamp() })); }); @@ -158,19 +98,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Invocation complete */ socket.on('invocation_complete', (data) => { - log.info( - { data, sessionId: data.graph_execution_state_id }, - `Invocation complete (${data.node.type})` - ); - const sessionId = data.graph_execution_state_id; - - const { cancelType, isCancelScheduled } = getState().system; - - // Handle scheduled cancelation - if (cancelType === 'scheduled' && isCancelScheduled) { - dispatch(sessionCanceled({ sessionId })); - } - dispatch( invocationComplete({ data, @@ -183,10 +110,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Graph complete */ socket.on('graph_execution_state_complete', (data) => { - log.info( - { data, sessionId: data.graph_execution_state_id }, - `Graph execution state complete (${data.graph_execution_state_id})` - ); dispatch(graphExecutionStateComplete({ data, timestamp: getTimestamp() })); }); }; diff --git a/invokeai/frontend/web/src/services/thunks/gallery.ts b/invokeai/frontend/web/src/services/thunks/gallery.ts index 01e8a986b2..5321b7ca3e 100644 --- a/invokeai/frontend/web/src/services/thunks/gallery.ts +++ b/invokeai/frontend/web/src/services/thunks/gallery.ts @@ -12,7 +12,7 @@ export const receivedResultImagesPage = createAppAsyncThunk( const { page, pages, nextPage } = getState().results; if (nextPage === page) { - rejectWithValue([]); + return rejectWithValue([]); } const response = await ImagesService.listImagesWithMetadata({ @@ -30,7 +30,13 @@ export const receivedResultImagesPage = createAppAsyncThunk( export const receivedUploadImagesPage = createAppAsyncThunk( 'uploads/receivedUploadImagesPage', - async (_arg, { getState }) => { + async (_arg, { getState, rejectWithValue }) => { + const { page, pages, nextPage } = getState().uploads; + + if (nextPage === page) { + return rejectWithValue([]); + } + const response = await ImagesService.listImagesWithMetadata({ imageType: 'uploads', imageCategory: 'general', diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index 6831eb647d..34b369e3eb 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -76,3 +76,19 @@ export const imageDeleted = createAppAsyncThunk( return response; } ); + +type ImageUpdatedArg = Parameters<(typeof ImagesService)['updateImage']>[0]; + +/** + * `ImagesService.deleteImage()` thunk + */ +export const imageUpdated = createAppAsyncThunk( + 'api/imageUpdated', + async (arg: ImageUpdatedArg) => { + const response = await ImagesService.updateImage(arg); + + imagesLog.debug({ arg, response }, 'Image updated'); + + return response; + } +); diff --git a/invokeai/frontend/web/src/services/thunks/session.ts b/invokeai/frontend/web/src/services/thunks/session.ts index dca4134886..a1ee5a34ed 100644 --- a/invokeai/frontend/web/src/services/thunks/session.ts +++ b/invokeai/frontend/web/src/services/thunks/session.ts @@ -35,6 +35,28 @@ export const sessionCreated = createAppAsyncThunk( } ); +/** + * `SessionsService.createSession()` without graph thunk + */ +export const sessionWithoutGraphCreated = createAppAsyncThunk( + 'api/sessionWithoutGraphCreated', + async (_, { rejectWithValue }) => { + try { + const response = await SessionsService.createSession({}); + sessionLog.info({ response }, `Session created (${response.id})`); + return response; + } catch (err: any) { + sessionLog.error( + { + error: serializeError(err), + }, + 'Problem creating session' + ); + return rejectWithValue(err.message); + } + } +); + type NodeAddedArg = Parameters<(typeof SessionsService)['addNode']>[0]; /** @@ -57,6 +79,29 @@ export const nodeAdded = createAppAsyncThunk( } ); +type NodeUpdatedArg = Parameters<(typeof SessionsService)['updateNode']>[0]; + +/** + * `SessionsService.addNode()` thunk + */ +export const nodeUpdated = createAppAsyncThunk( + 'api/nodeUpdated', + async ( + arg: { node: NodeUpdatedArg['requestBody']; sessionId: string }, + _thunkApi + ) => { + const response = await SessionsService.updateNode({ + requestBody: arg.node, + sessionId: arg.sessionId, + nodePath: arg.node.id, + }); + + sessionLog.info({ arg, response }, `Node updated (${response})`); + + return response; + } +); + /** * `SessionsService.invokeSession()` thunk */