feat(ui): persist socket session ids and re-sub on connect

This commit is contained in:
psychedelicious
2023-04-27 17:29:59 +10:00
parent a8cec4c7e6
commit 5d8728c7ef
4 changed files with 88 additions and 72 deletions

View File

@ -114,6 +114,7 @@ export const store = configureStore({
'canvas/setBoundingBoxDimensions', 'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing', 'canvas/setIsDrawing',
'canvas/addPointToCurrentLine', 'canvas/addPointToCurrentLine',
'socket/generatorProgress',
], ],
}, },
}); });

View File

@ -32,6 +32,7 @@ import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema'; import { receivedOpenAPISchema } from 'services/thunks/schema';
import { isImageOutput } from 'services/types/guards'; import { isImageOutput } from 'services/types/guards';
import { imageReceived, thumbnailReceived } from 'services/thunks/image'; import { imageReceived, thumbnailReceived } from 'services/thunks/image';
import { setEventListeners } from './util/setEventListeners';
export const socketMiddleware = () => { export const socketMiddleware = () => {
let areListenersSet = false; let areListenersSet = false;
@ -66,25 +67,25 @@ export const socketMiddleware = () => {
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => { (store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
const { dispatch, getState } = store; const { dispatch, getState } = store;
// Nothing dispatches `socketReset` actions yet, so this is a noop, but including anyways // Nothing dispatches `socketReset` actions yet
if (socketReset.match(action)) { // if (socketReset.match(action)) {
const { sessionId } = getState().system; // const { sessionId } = getState().system;
if (sessionId) { // if (sessionId) {
socket.emit('unsubscribe', { session: sessionId }); // socket.emit('unsubscribe', { session: sessionId });
dispatch( // dispatch(
socketUnsubscribed({ sessionId, timestamp: getTimestamp() }) // socketUnsubscribed({ sessionId, timestamp: getTimestamp() })
); // );
} // }
if (socket.connected) { // if (socket.connected) {
socket.disconnect(); // socket.disconnect();
dispatch(socketDisconnected({ timestamp: getTimestamp() })); // dispatch(socketDisconnected({ timestamp: getTimestamp() }));
} // }
socket.removeAllListeners(); // socket.removeAllListeners();
areListenersSet = false; // areListenersSet = false;
} // }
// Set listeners for `connect` and `disconnect` events once // Set listeners for `connect` and `disconnect` events once
// Must happen in middleware to get access to `dispatch` // Must happen in middleware to get access to `dispatch`
@ -92,7 +93,8 @@ export const socketMiddleware = () => {
socket.on('connect', () => { socket.on('connect', () => {
dispatch(socketConnected({ timestamp: getTimestamp() })); dispatch(socketConnected({ timestamp: getTimestamp() }));
const { results, uploads, models, nodes, config } = getState(); const { results, uploads, models, nodes, config, system } =
getState();
const { disabledTabs } = config; const { disabledTabs } = config;
@ -112,6 +114,18 @@ export const socketMiddleware = () => {
if (!nodes.schema && !disabledTabs.includes('nodes')) { if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema()); dispatch(receivedOpenAPISchema());
} }
if (system.sessionId) {
console.log(`Re-subscribing to session ${system.sessionId}`);
socket.emit('subscribe', { session: system.sessionId });
dispatch(
socketSubscribed({
sessionId: system.sessionId,
timestamp: getTimestamp(),
})
);
setEventListeners({ socket, store });
}
}); });
socket.on('disconnect', () => { socket.on('disconnect', () => {
@ -128,9 +142,6 @@ export const socketMiddleware = () => {
if (isFulfilledSessionCreatedAction(action)) { if (isFulfilledSessionCreatedAction(action)) {
const oldSessionId = getState().system.sessionId; const oldSessionId = getState().system.sessionId;
// temp disable event subscription
const shouldHandleEvent = (id: string): boolean => true;
// const subscribedNodeIds = getState().system.subscribedNodeIds; // const subscribedNodeIds = getState().system.subscribedNodeIds;
// const shouldHandleEvent = (id: string): boolean => { // const shouldHandleEvent = (id: string): boolean => {
// if (subscribedNodeIds.length === 1 && subscribedNodeIds[0] === '*') { // if (subscribedNodeIds.length === 1 && subscribedNodeIds[0] === '*') {
@ -152,7 +163,6 @@ export const socketMiddleware = () => {
timestamp: getTimestamp(), timestamp: getTimestamp(),
}) })
); );
const listenersToRemove: (keyof ServerToClientEvents)[] = [ const listenersToRemove: (keyof ServerToClientEvents)[] = [
'invocation_started', 'invocation_started',
'generator_progress', 'generator_progress',
@ -168,57 +178,14 @@ export const socketMiddleware = () => {
const sessionId = action.payload.id; const sessionId = action.payload.id;
// After a session is created, we immediately subscribe to events and then invoke the session
socket.emit('subscribe', { session: sessionId }); socket.emit('subscribe', { session: sessionId });
// Always dispatch the event actions for other consumers who want to know when we subscribed
dispatch( dispatch(
socketSubscribed({ socketSubscribed({
sessionId, sessionId: sessionId,
timestamp: getTimestamp(), timestamp: getTimestamp(),
}) })
); );
setEventListeners({ socket, store });
// Set up listeners for the present subscription
socket.on('invocation_started', (data) => {
if (shouldHandleEvent(data.node.id)) {
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
}
});
socket.on('generator_progress', (data) => {
if (shouldHandleEvent(data.node.id)) {
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
}
});
socket.on('invocation_error', (data) => {
if (shouldHandleEvent(data.node.id)) {
dispatch(invocationError({ data, timestamp: getTimestamp() }));
}
});
socket.on('invocation_complete', (data) => {
if (shouldHandleEvent(data.node.id)) {
const sessionId = data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system;
const { shouldFetchImages } = getState().config;
// Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
dispatch(sessionCanceled({ sessionId }));
}
dispatch(
invocationComplete({
data,
timestamp: getTimestamp(),
shouldFetchImages,
})
);
}
});
// Finally we actually invoke the session, starting processing // Finally we actually invoke the session, starting processing
dispatch(sessionInvoked({ sessionId })); dispatch(sessionInvoked({ sessionId }));

View File

@ -15,12 +15,6 @@ export type AnyInvocationType = NonNullable<
export type AnyInvocation = NonNullable<Graph['nodes']>[string]; export type AnyInvocation = NonNullable<Graph['nodes']>[string];
// export type AnyInvocation = {
// id: string;
// type: AnyInvocationType | string;
// [key: string]: any;
// };
export type AnyResult = GraphExecutionState['results'][string]; export type AnyResult = GraphExecutionState['results'][string];
/** /**

View File

@ -0,0 +1,54 @@
import { MiddlewareAPI } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from 'app/store';
import { getTimestamp } from 'common/util/getTimestamp';
import { sessionCanceled } from 'services/thunks/session';
import { Socket } from 'socket.io-client';
import {
generatorProgress,
invocationComplete,
invocationError,
invocationStarted,
} from '../actions';
import { ClientToServerEvents, ServerToClientEvents } from '../types';
type SetEventListenersArg = {
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
store: MiddlewareAPI<AppDispatch, RootState>;
};
export const setEventListeners = (arg: SetEventListenersArg) => {
const { socket, store } = arg;
const { dispatch, getState } = store;
// Set up listeners for the present subscription
socket.on('invocation_started', (data) => {
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
});
socket.on('generator_progress', (data) => {
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
});
socket.on('invocation_error', (data) => {
dispatch(invocationError({ data, timestamp: getTimestamp() }));
});
socket.on('invocation_complete', (data) => {
const sessionId = data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system;
const { shouldFetchImages } = getState().config;
// Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
dispatch(sessionCanceled({ sessionId }));
}
dispatch(
invocationComplete({
data,
timestamp: getTimestamp(),
shouldFetchImages,
})
);
});
};