feat(ui): change intermediates handling

- Update the canvas graph generation to flag its uploaded init and mask images as `intermediate`.
- During canvas setup, hit the update route to associate the uploaded images with the session id.
- Organize the socketio and RTK listener middlware better. Needed to facilitate the updated canvas logic.
- Add a new action `sessionReadyToInvoke`. The `sessionInvoked` action is *only* ever run in response to this event. This lets us do whatever complicated setup (eg canvas) and explicitly invoking. Previously, invoking was tied to the socket subscribe events.
- Some minor tidying.
This commit is contained in:
psychedelicious 2023-05-25 23:47:57 +10:00 committed by Kent Keirsey
parent 5025f84627
commit a2de5c9963
25 changed files with 529 additions and 185 deletions

View File

@ -8,7 +8,6 @@ import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
import type { RootState, AppDispatch } from '../../store'; import type { RootState, AppDispatch } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addImageResultReceivedListener } from './listeners/invocationComplete';
import { addImageUploadedListener } from './listeners/imageUploaded'; import { addImageUploadedListener } from './listeners/imageUploaded';
import { addRequestedImageDeletionListener } from './listeners/imageDeleted'; import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
@ -19,6 +18,16 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged'; import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addGeneratorProgressListener } from './listeners/socketio/generatorProgress';
import { addGraphExecutionStateCompleteListener } from './listeners/socketio/graphExecutionStateComplete';
import { addInvocationCompleteListener } from './listeners/socketio/invocationComplete';
import { addInvocationErrorListener } from './listeners/socketio/invocationError';
import { addInvocationStartedListener } from './listeners/socketio/invocationStarted';
import { addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -40,15 +49,27 @@ export type AppListenerEffect = ListenerEffect<
addImageUploadedListener(); addImageUploadedListener();
addInitialImageSelectedListener(); addInitialImageSelectedListener();
addImageResultReceivedListener();
addRequestedImageDeletionListener(); addRequestedImageDeletionListener();
addUserInvokedCanvasListener(); addUserInvokedCanvasListener();
addUserInvokedNodesListener(); addUserInvokedNodesListener();
addUserInvokedTextToImageListener(); addUserInvokedTextToImageListener();
addUserInvokedImageToImageListener(); addUserInvokedImageToImageListener();
addSessionReadyToInvokeListener();
addCanvasSavedToGalleryListener(); addCanvasSavedToGalleryListener();
addCanvasDownloadedAsImageListener(); addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener(); addCanvasCopiedToClipboardListener();
addCanvasMergedListener(); addCanvasMergedListener();
// socketio
addGeneratorProgressListener();
addGraphExecutionStateCompleteListener();
addInvocationCompleteListener();
addInvocationErrorListener();
addInvocationStartedListener();
addSocketConnectedListener();
addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();

View File

@ -12,7 +12,7 @@ export const addImageUploadedListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> => predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) && imageUploaded.fulfilled.match(action) &&
action.payload.response.image_type !== 'intermediates', action.payload.response.is_intermediate === false,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const { response: image } = action.payload; const { response: image } = action.payload;

View File

@ -0,0 +1,19 @@
import { startAppListening } from '..';
import { sessionInvoked } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addSessionReadyToInvokeListener = () => {
startAppListening({
actionCreator: sessionReadyToInvoke,
effect: (action, { getState, dispatch }) => {
const { sessionId } = getState().system;
if (sessionId) {
moduleLog.info({ sessionId }, `Session invoked (${sessionId})})`);
dispatch(sessionInvoked({ sessionId }));
}
},
});
};

View File

@ -0,0 +1,28 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { generatorProgress } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addGeneratorProgressListener = () => {
startAppListening({
actionCreator: generatorProgress,
effect: (action, { dispatch, getState }) => {
if (
getState().system.canceledSession ===
action.payload.data.graph_execution_state_id
) {
moduleLog.trace(
action.payload,
'Ignored generator progress for canceled session'
);
return;
}
moduleLog.trace(
action.payload,
`Generator progress (${action.payload.data.node.type})`
);
},
});
};

View File

@ -0,0 +1,17 @@
import { log } from 'app/logging/useLogger';
import { graphExecutionStateComplete } from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addGraphExecutionStateCompleteListener = () => {
startAppListening({
actionCreator: graphExecutionStateComplete,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Graph execution state complete (${action.payload.data.graph_execution_state_id})`
);
},
});
};

View File

@ -1,40 +1,49 @@
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { invocationComplete } from 'services/events/actions'; import { invocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import { import {
imageMetadataReceived, imageMetadataReceived,
imageUrlsReceived, imageUrlsReceived,
} from 'services/thunks/image'; } from 'services/thunks/image';
import { startAppListening } from '..'; import { sessionCanceled } from 'services/thunks/session';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; import { isImageOutput } from 'services/types/guards';
const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image']; const nodeDenylist = ['dataURL_image'];
export const addImageResultReceivedListener = () => { export const addInvocationCompleteListener = () => {
startAppListening({ startAppListening({
predicate: (action) => { actionCreator: invocationComplete,
if ( effect: async (action, { dispatch, getState, take }) => {
invocationComplete.match(action) && moduleLog.info(
isImageOutput(action.payload.data.result) action.payload,
) { `Invocation complete (${action.payload.data.node.type})`
return true; );
}
return false; const sessionId = action.payload.data.graph_execution_state_id;
},
effect: async (action, { getState, dispatch, take }) => { const { cancelType, isCancelScheduled } = getState().system;
if (!invocationComplete.match(action)) {
return; // Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
dispatch(sessionCanceled({ sessionId }));
} }
const { data } = action.payload; const { data } = action.payload;
const { result, node, graph_execution_state_id } = data; const { result, node, graph_execution_state_id } = data;
// This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_type } = result.image; const { image_name, image_type } = result.image;
// Get its URLS
// TODO: is this extraneous? I think so...
dispatch( dispatch(
imageUrlsReceived({ imageName: image_name, imageType: image_type }) imageUrlsReceived({ imageName: image_name, imageType: image_type })
); );
// Get its metadata
dispatch( dispatch(
imageMetadataReceived({ imageMetadataReceived({
imageName: image_name, imageName: image_name,

View File

@ -0,0 +1,17 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { invocationError } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addInvocationErrorListener = () => {
startAppListening({
actionCreator: invocationError,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Invocation error (${action.payload.data.node.type})`
);
},
});
};

View File

@ -0,0 +1,28 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { invocationStarted } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addInvocationStartedListener = () => {
startAppListening({
actionCreator: invocationStarted,
effect: (action, { dispatch, getState }) => {
if (
getState().system.canceledSession ===
action.payload.data.graph_execution_state_id
) {
moduleLog.trace(
action.payload,
'Ignored invocation started for canceled session'
);
return;
}
moduleLog.info(
action.payload,
`Invocation started (${action.payload.data.node.type})`
);
},
});
};

View File

@ -0,0 +1,43 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketConnected } from 'services/events/actions';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketConnectedListener = () => {
startAppListening({
actionCreator: socketConnected,
effect: (action, { dispatch, getState }) => {
const { timestamp } = action.payload;
moduleLog.debug({ timestamp }, 'Connected');
const { results, uploads, models, nodes, config } = getState();
const { disabledTabs } = config;
// These thunks need to be dispatch in middleware; cannot handle in a reducer
if (!results.ids.length) {
dispatch(receivedResultImagesPage());
}
if (!uploads.ids.length) {
dispatch(receivedUploadImagesPage());
}
if (!models.ids.length) {
dispatch(receivedModels());
}
if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema());
}
},
});
};

View File

@ -0,0 +1,14 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketDisconnected } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketDisconnectedListener = () => {
startAppListening({
actionCreator: socketDisconnected,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(action.payload, 'Disconnected');
},
});
};

View File

@ -0,0 +1,17 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketSubscribed } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketSubscribedListener = () => {
startAppListening({
actionCreator: socketSubscribed,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Subscribed (${action.payload.sessionId}))`
);
},
});
};

View File

@ -0,0 +1,17 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketUnsubscribed } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketUnsubscribedListener = () => {
startAppListening({
actionCreator: socketUnsubscribed,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Unsubscribed (${action.payload.sessionId})`
);
},
});
};

View File

@ -1,9 +1,9 @@
import { startAppListening } from '..'; import { startAppListening } from '..';
import { sessionCreated, sessionInvoked } from 'services/thunks/session'; import { nodeUpdated, sessionCreated } from 'services/thunks/session';
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { canvasGraphBuilt } from 'features/nodes/store/actions'; import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { imageUploaded } from 'services/thunks/image'; import { imageUpdated, imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api'; import { Graph } from 'services/api';
import { import {
@ -15,12 +15,25 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' }); const moduleLog = log.child({ namespace: 'invoke' });
/** /**
* This listener is responsible for building the canvas graph and blobs when the user invokes the canvas. * This listener is responsible invoking the canvas. This involved a number of steps:
* It is also responsible for uploading the base and mask layers to the server. *
* 1. Generate image blobs from the canvas layers
* 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
* 3. Build the canvas graph
* 4. Create the session
* 5. Upload the init image if necessary, then update the graph to refer to it (needs a separate request)
* 6. Upload the mask image if necessary, then update the graph to refer to it (needs a separate request)
* 7. Initialize the staging area if not yet initialized
* 8. Finally, dispatch the sessionReadyToInvoke action to invoke the session
*
* We have to do the uploads after creating the session:
* - We need to associate these particular uploads to a session, and flag them as intermediates
* - To do this, we need to associa
*/ */
export const addUserInvokedCanvasListener = () => { export const addUserInvokedCanvasListener = () => {
startAppListening({ startAppListening({
@ -70,63 +83,7 @@ export const addUserInvokedCanvasListener = () => {
const { rangeNode, iterateNode, baseNode, edges } = graphComponents; const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
// Upload the base layer, to be used as init image // Assemble! Note that this graph *does not have the init or mask image set yet!*
const baseFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
})
);
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const [{ payload: basePayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
const { image_name: baseName, image_type: baseType } =
basePayload.response;
baseNode.image = {
image_name: baseName,
image_type: baseType,
};
}
// Upload the mask layer image
const maskFilename = `${uuidv4()}.png`;
if (baseNode.type === 'inpaint') {
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
})
);
const [{ payload: maskPayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
const { image_name: maskName, image_type: maskType } =
maskPayload.response;
baseNode.mask = {
image_name: maskName,
image_type: maskType,
};
}
// Assemble!
const nodes: Graph['nodes'] = { const nodes: Graph['nodes'] = {
[rangeNode.id]: rangeNode, [rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode, [iterateNode.id]: iterateNode,
@ -136,15 +93,96 @@ export const addUserInvokedCanvasListener = () => {
const graph = { nodes, edges }; const graph = { nodes, edges };
dispatch(canvasGraphBuilt(graph)); dispatch(canvasGraphBuilt(graph));
moduleLog({ data: graph }, 'Canvas graph built'); moduleLog({ data: graph }, 'Canvas graph built');
// Actually create the session // If we are generating img2img or inpaint, we need to upload the init images
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const baseFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
isIntermediate: true,
})
);
// Wait for the image to be uploaded
const [{ payload: basePayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
// Update the base node with the image name and type
const { image_name: baseName, image_type: baseType } =
basePayload.response;
baseNode.image = {
image_name: baseName,
image_type: baseType,
};
}
// For inpaint, we also need to upload the mask layer
if (baseNode.type === 'inpaint') {
const maskFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
isIntermediate: true,
})
);
// Wait for the mask to be uploaded
const [{ payload: maskPayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
// Update the base node with the image name and type
const { image_name: maskName, image_type: maskType } =
maskPayload.response;
baseNode.mask = {
image_name: maskName,
image_type: maskType,
};
}
// Create the session and wait for response
dispatch(sessionCreated({ graph })); dispatch(sessionCreated({ graph }));
const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
const sessionId = sessionCreatedAction.payload.id;
// Wait for the session to be invoked (this is just the HTTP request to start processing) // Associate the init image with the session, now that we have the session ID
const [{ meta }] = await take(sessionInvoked.fulfilled.match); if (
(baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
baseNode.image
) {
dispatch(
imageUpdated({
imageName: baseNode.image.image_name,
imageType: baseNode.image.image_type,
requestBody: { session_id: sessionId },
})
);
}
const { sessionId } = meta.arg; // Associate the mask image with the session, now that we have the session ID
if (baseNode.type === 'inpaint' && baseNode.mask) {
dispatch(
imageUpdated({
imageName: baseNode.mask.image_name,
imageType: baseNode.mask.image_type,
requestBody: { session_id: sessionId },
})
);
}
if (!state.canvas.layerState.stagingArea.boundingBox) { if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch( dispatch(
@ -158,7 +196,11 @@ export const addUserInvokedCanvasListener = () => {
); );
} }
// Flag the session with the canvas session ID
dispatch(canvasSessionIdChanged(sessionId)); dispatch(canvasSessionIdChanged(sessionId));
// We are ready to invoke the session!
dispatch(sessionReadyToInvoke());
}, },
}); });
}; };

View File

@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions'; import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' }); const moduleLog = log.child({ namespace: 'invoke' });
@ -11,7 +12,7 @@ export const addUserInvokedImageToImageListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> => predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'img2img', userInvoked.match(action) && action.payload === 'img2img',
effect: (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch, take }) => {
const state = getState(); const state = getState();
const graph = buildImageToImageGraph(state); const graph = buildImageToImageGraph(state);
@ -19,6 +20,10 @@ export const addUserInvokedImageToImageListener = () => {
moduleLog({ data: graph }, 'Image to Image graph built'); moduleLog({ data: graph }, 'Image to Image graph built');
dispatch(sessionCreated({ graph })); dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
}, },
}); });
}; };

View File

@ -4,6 +4,7 @@ import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGra
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { nodesGraphBuilt } from 'features/nodes/store/actions'; import { nodesGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' }); const moduleLog = log.child({ namespace: 'invoke' });
@ -11,7 +12,7 @@ export const addUserInvokedNodesListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> => predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'nodes', userInvoked.match(action) && action.payload === 'nodes',
effect: (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch, take }) => {
const state = getState(); const state = getState();
const graph = buildNodesGraph(state); const graph = buildNodesGraph(state);
@ -19,6 +20,10 @@ export const addUserInvokedNodesListener = () => {
moduleLog({ data: graph }, 'Nodes graph built'); moduleLog({ data: graph }, 'Nodes graph built');
dispatch(sessionCreated({ graph })); dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
}, },
}); });
}; };

View File

@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { textToImageGraphBuilt } from 'features/nodes/store/actions'; import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' }); const moduleLog = log.child({ namespace: 'invoke' });
@ -11,14 +12,20 @@ export const addUserInvokedTextToImageListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> => predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'txt2img', userInvoked.match(action) && action.payload === 'txt2img',
effect: (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch, take }) => {
const state = getState(); const state = getState();
const graph = buildTextToImageGraph(state); const graph = buildTextToImageGraph(state);
dispatch(textToImageGraphBuilt(graph)); dispatch(textToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Text to Image graph built'); moduleLog({ data: graph }, 'Text to Image graph built');
dispatch(sessionCreated({ graph })); dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
}, },
}); });
}; };

View File

@ -16,6 +16,7 @@ import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
@ -46,6 +47,7 @@ const allReducers = {
ui: uiReducer, ui: uiReducer,
uploads: uploadsReducer, uploads: uploadsReducer,
hotkeys: hotkeysReducer, hotkeys: hotkeysReducer,
// session: sessionReducer,
}; };
const rootReducer = combineReducers(allReducers); const rootReducer = combineReducers(allReducers);

View File

@ -13,7 +13,9 @@ import {
buildOutputFieldTemplates, buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; } from './fieldTemplateBuilders';
const invocationDenylist = ['Graph']; const RESERVED_FIELD_NAMES = ['id', 'type', 'meta'];
const invocationDenylist = ['Graph', 'InvocationMeta'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => { export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now // filter out non-invocation schemas, plus some tricky invocations for now
@ -73,7 +75,7 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
(inputsAccumulator, property, propertyName) => { (inputsAccumulator, property, propertyName) => {
if ( if (
// `type` and `id` are not valid inputs/outputs // `type` and `id` are not valid inputs/outputs
!['type', 'id'].includes(propertyName) && !RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property) isSchemaObject(property)
) { ) {
const field: InputFieldTemplate | undefined = const field: InputFieldTemplate | undefined =

View File

@ -0,0 +1,3 @@
import { createAction } from '@reduxjs/toolkit';
export const sessionReadyToInvoke = createAction('system/sessionReadyToInvoke');

View File

@ -0,0 +1,62 @@
// TODO: split system slice inot this
// import type { PayloadAction } from '@reduxjs/toolkit';
// import { createSlice } from '@reduxjs/toolkit';
// import { socketSubscribed, socketUnsubscribed } from 'services/events/actions';
// export type SessionState = {
// /**
// * The current socket session id
// */
// sessionId: string;
// /**
// * Whether the current session is a canvas session. Needed to manage the staging area.
// */
// isCanvasSession: boolean;
// /**
// * When a session is canceled, its ID is stored here until a new session is created.
// */
// canceledSessionId: string;
// };
// export const initialSessionState: SessionState = {
// sessionId: '',
// isCanvasSession: false,
// canceledSessionId: '',
// };
// export const sessionSlice = createSlice({
// name: 'session',
// initialState: initialSessionState,
// reducers: {
// sessionIdChanged: (state, action: PayloadAction<string>) => {
// state.sessionId = action.payload;
// },
// isCanvasSessionChanged: (state, action: PayloadAction<boolean>) => {
// state.isCanvasSession = action.payload;
// },
// },
// extraReducers: (builder) => {
// /**
// * Socket Subscribed
// */
// builder.addCase(socketSubscribed, (state, action) => {
// state.sessionId = action.payload.sessionId;
// state.canceledSessionId = '';
// });
// /**
// * Socket Unsubscribed
// */
// builder.addCase(socketUnsubscribed, (state) => {
// state.sessionId = '';
// });
// },
// });
// export const { sessionIdChanged, isCanvasSessionChanged } =
// sessionSlice.actions;
// export default sessionSlice.reducer;
export default {};

View File

@ -8,7 +8,11 @@ import {
import { socketSubscribed, socketUnsubscribed } from './actions'; import { socketSubscribed, socketUnsubscribed } from './actions';
import { AppThunkDispatch, RootState } from 'app/store/store'; import { AppThunkDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp'; import { getTimestamp } from 'common/util/getTimestamp';
import { sessionInvoked, sessionCreated } from 'services/thunks/session'; import {
sessionInvoked,
sessionCreated,
sessionWithoutGraphCreated,
} from 'services/thunks/session';
import { OpenAPI } from 'services/api'; import { OpenAPI } from 'services/api';
import { setEventListeners } from 'services/events/util/setEventListeners'; import { setEventListeners } from 'services/events/util/setEventListeners';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
@ -62,17 +66,14 @@ export const socketMiddleware = () => {
socket.connect(); socket.connect();
} }
if (sessionCreated.fulfilled.match(action)) { if (
sessionCreated.fulfilled.match(action) ||
sessionWithoutGraphCreated.fulfilled.match(action)
) {
const sessionId = action.payload.id; const sessionId = action.payload.id;
const sessionLog = socketioLog.child({ sessionId });
const oldSessionId = getState().system.sessionId; const oldSessionId = getState().system.sessionId;
if (oldSessionId) { if (oldSessionId) {
sessionLog.debug(
{ oldSessionId },
`Unsubscribed from old session (${oldSessionId})`
);
socket.emit('unsubscribe', { socket.emit('unsubscribe', {
session: oldSessionId, session: oldSessionId,
}); });
@ -85,8 +86,6 @@ export const socketMiddleware = () => {
); );
} }
sessionLog.debug(`Subscribe to new session (${sessionId})`);
socket.emit('subscribe', { session: sessionId }); socket.emit('subscribe', { session: sessionId });
dispatch( dispatch(
@ -95,9 +94,6 @@ export const socketMiddleware = () => {
timestamp: getTimestamp(), timestamp: getTimestamp(),
}) })
); );
// Finally we actually invoke the session, starting processing
dispatch(sessionInvoked({ sessionId }));
} }
next(action); next(action);

View File

@ -1,7 +1,6 @@
import { MiddlewareAPI } from '@reduxjs/toolkit'; import { MiddlewareAPI } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from 'app/store/store'; import { AppDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp'; import { getTimestamp } from 'common/util/getTimestamp';
import { sessionCanceled } from 'services/thunks/session';
import { Socket } from 'socket.io-client'; import { Socket } from 'socket.io-client';
import { import {
generatorProgress, generatorProgress,
@ -16,12 +15,6 @@ import {
import { ClientToServerEvents, ServerToClientEvents } from '../types'; import { ClientToServerEvents, ServerToClientEvents } from '../types';
import { Logger } from 'roarr'; import { Logger } from 'roarr';
import { JsonObject } from 'roarr/dist/types'; import { JsonObject } from 'roarr/dist/types';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { makeToast } from '../../../app/components/Toaster'; import { makeToast } from '../../../app/components/Toaster';
import { addToast } from '../../../features/system/store/systemSlice'; import { addToast } from '../../../features/system/store/systemSlice';
@ -43,37 +36,13 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
dispatch(socketConnected({ timestamp: getTimestamp() })); dispatch(socketConnected({ timestamp: getTimestamp() }));
const { results, uploads, models, nodes, config, system } = getState(); const { sessionId } = getState().system;
const { disabledTabs } = config; if (sessionId) {
socket.emit('subscribe', { session: sessionId });
// These thunks need to be dispatch in middleware; cannot handle in a reducer
if (!results.ids.length) {
dispatch(receivedResultImagesPage());
}
if (!uploads.ids.length) {
dispatch(receivedUploadImagesPage());
}
if (!models.ids.length) {
dispatch(receivedModels());
}
if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema());
}
if (system.sessionId) {
log.debug(
{ sessionId: system.sessionId },
`Subscribed to existing session (${system.sessionId})`
);
socket.emit('subscribe', { session: system.sessionId });
dispatch( dispatch(
socketSubscribed({ socketSubscribed({
sessionId: system.sessionId, sessionId,
timestamp: getTimestamp(), timestamp: getTimestamp(),
}) })
); );
@ -101,7 +70,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
* Disconnect * Disconnect
*/ */
socket.on('disconnect', () => { socket.on('disconnect', () => {
log.debug('Disconnected');
dispatch(socketDisconnected({ timestamp: getTimestamp() })); dispatch(socketDisconnected({ timestamp: getTimestamp() }));
}); });
@ -109,18 +77,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
* Invocation started * Invocation started
*/ */
socket.on('invocation_started', (data) => { socket.on('invocation_started', (data) => {
if (getState().system.canceledSession === data.graph_execution_state_id) {
log.trace(
{ data, sessionId: data.graph_execution_state_id },
`Ignored invocation started (${data.node.type}) for canceled session (${data.graph_execution_state_id})`
);
return;
}
log.info(
{ data, sessionId: data.graph_execution_state_id },
`Invocation started (${data.node.type})`
);
dispatch(invocationStarted({ data, timestamp: getTimestamp() })); dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
}); });
@ -128,18 +84,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
* Generator progress * Generator progress
*/ */
socket.on('generator_progress', (data) => { socket.on('generator_progress', (data) => {
if (getState().system.canceledSession === data.graph_execution_state_id) {
log.trace(
{ data, sessionId: data.graph_execution_state_id },
`Ignored generator progress (${data.node.type}) for canceled session (${data.graph_execution_state_id})`
);
return;
}
log.trace(
{ data, sessionId: data.graph_execution_state_id },
`Generator progress (${data.node.type})`
);
dispatch(generatorProgress({ data, timestamp: getTimestamp() })); dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
}); });
@ -147,10 +91,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
* Invocation error * Invocation error
*/ */
socket.on('invocation_error', (data) => { socket.on('invocation_error', (data) => {
log.error(
{ data, sessionId: data.graph_execution_state_id },
`Invocation error (${data.node.type})`
);
dispatch(invocationError({ data, timestamp: getTimestamp() })); dispatch(invocationError({ data, timestamp: getTimestamp() }));
}); });
@ -158,19 +98,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
* Invocation complete * Invocation complete
*/ */
socket.on('invocation_complete', (data) => { socket.on('invocation_complete', (data) => {
log.info(
{ data, sessionId: data.graph_execution_state_id },
`Invocation complete (${data.node.type})`
);
const sessionId = data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system;
// Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
dispatch(sessionCanceled({ sessionId }));
}
dispatch( dispatch(
invocationComplete({ invocationComplete({
data, data,
@ -183,10 +110,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
* Graph complete * Graph complete
*/ */
socket.on('graph_execution_state_complete', (data) => { socket.on('graph_execution_state_complete', (data) => {
log.info(
{ data, sessionId: data.graph_execution_state_id },
`Graph execution state complete (${data.graph_execution_state_id})`
);
dispatch(graphExecutionStateComplete({ data, timestamp: getTimestamp() })); dispatch(graphExecutionStateComplete({ data, timestamp: getTimestamp() }));
}); });
}; };

View File

@ -12,7 +12,7 @@ export const receivedResultImagesPage = createAppAsyncThunk(
const { page, pages, nextPage } = getState().results; const { page, pages, nextPage } = getState().results;
if (nextPage === page) { if (nextPage === page) {
rejectWithValue([]); return rejectWithValue([]);
} }
const response = await ImagesService.listImagesWithMetadata({ const response = await ImagesService.listImagesWithMetadata({
@ -30,7 +30,13 @@ export const receivedResultImagesPage = createAppAsyncThunk(
export const receivedUploadImagesPage = createAppAsyncThunk( export const receivedUploadImagesPage = createAppAsyncThunk(
'uploads/receivedUploadImagesPage', 'uploads/receivedUploadImagesPage',
async (_arg, { getState }) => { async (_arg, { getState, rejectWithValue }) => {
const { page, pages, nextPage } = getState().uploads;
if (nextPage === page) {
return rejectWithValue([]);
}
const response = await ImagesService.listImagesWithMetadata({ const response = await ImagesService.listImagesWithMetadata({
imageType: 'uploads', imageType: 'uploads',
imageCategory: 'general', imageCategory: 'general',

View File

@ -76,3 +76,19 @@ export const imageDeleted = createAppAsyncThunk(
return response; return response;
} }
); );
type ImageUpdatedArg = Parameters<(typeof ImagesService)['updateImage']>[0];
/**
* `ImagesService.deleteImage()` thunk
*/
export const imageUpdated = createAppAsyncThunk(
'api/imageUpdated',
async (arg: ImageUpdatedArg) => {
const response = await ImagesService.updateImage(arg);
imagesLog.debug({ arg, response }, 'Image updated');
return response;
}
);

View File

@ -35,6 +35,28 @@ export const sessionCreated = createAppAsyncThunk(
} }
); );
/**
* `SessionsService.createSession()` without graph thunk
*/
export const sessionWithoutGraphCreated = createAppAsyncThunk(
'api/sessionWithoutGraphCreated',
async (_, { rejectWithValue }) => {
try {
const response = await SessionsService.createSession({});
sessionLog.info({ response }, `Session created (${response.id})`);
return response;
} catch (err: any) {
sessionLog.error(
{
error: serializeError(err),
},
'Problem creating session'
);
return rejectWithValue(err.message);
}
}
);
type NodeAddedArg = Parameters<(typeof SessionsService)['addNode']>[0]; type NodeAddedArg = Parameters<(typeof SessionsService)['addNode']>[0];
/** /**
@ -57,6 +79,29 @@ export const nodeAdded = createAppAsyncThunk(
} }
); );
type NodeUpdatedArg = Parameters<(typeof SessionsService)['updateNode']>[0];
/**
* `SessionsService.addNode()` thunk
*/
export const nodeUpdated = createAppAsyncThunk(
'api/nodeUpdated',
async (
arg: { node: NodeUpdatedArg['requestBody']; sessionId: string },
_thunkApi
) => {
const response = await SessionsService.updateNode({
requestBody: arg.node,
sessionId: arg.sessionId,
nodePath: arg.node.id,
});
sessionLog.info({ arg, response }, `Node updated (${response})`);
return response;
}
);
/** /**
* `SessionsService.invokeSession()` thunk * `SessionsService.invokeSession()` thunk
*/ */