mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): persist socket session ids and re-sub on connect
This commit is contained in:
@ -114,6 +114,7 @@ export const store = configureStore({
|
|||||||
'canvas/setBoundingBoxDimensions',
|
'canvas/setBoundingBoxDimensions',
|
||||||
'canvas/setIsDrawing',
|
'canvas/setIsDrawing',
|
||||||
'canvas/addPointToCurrentLine',
|
'canvas/addPointToCurrentLine',
|
||||||
|
'socket/generatorProgress',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -32,6 +32,7 @@ import { receivedModels } from 'services/thunks/model';
|
|||||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||||
import { isImageOutput } from 'services/types/guards';
|
import { isImageOutput } from 'services/types/guards';
|
||||||
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
||||||
|
import { setEventListeners } from './util/setEventListeners';
|
||||||
|
|
||||||
export const socketMiddleware = () => {
|
export const socketMiddleware = () => {
|
||||||
let areListenersSet = false;
|
let areListenersSet = false;
|
||||||
@ -66,25 +67,25 @@ export const socketMiddleware = () => {
|
|||||||
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
|
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
|
||||||
const { dispatch, getState } = store;
|
const { dispatch, getState } = store;
|
||||||
|
|
||||||
// Nothing dispatches `socketReset` actions yet, so this is a noop, but including anyways
|
// Nothing dispatches `socketReset` actions yet
|
||||||
if (socketReset.match(action)) {
|
// if (socketReset.match(action)) {
|
||||||
const { sessionId } = getState().system;
|
// const { sessionId } = getState().system;
|
||||||
|
|
||||||
if (sessionId) {
|
// if (sessionId) {
|
||||||
socket.emit('unsubscribe', { session: sessionId });
|
// socket.emit('unsubscribe', { session: sessionId });
|
||||||
dispatch(
|
// dispatch(
|
||||||
socketUnsubscribed({ sessionId, timestamp: getTimestamp() })
|
// socketUnsubscribed({ sessionId, timestamp: getTimestamp() })
|
||||||
);
|
// );
|
||||||
}
|
// }
|
||||||
|
|
||||||
if (socket.connected) {
|
// if (socket.connected) {
|
||||||
socket.disconnect();
|
// socket.disconnect();
|
||||||
dispatch(socketDisconnected({ timestamp: getTimestamp() }));
|
// dispatch(socketDisconnected({ timestamp: getTimestamp() }));
|
||||||
}
|
// }
|
||||||
|
|
||||||
socket.removeAllListeners();
|
// socket.removeAllListeners();
|
||||||
areListenersSet = false;
|
// areListenersSet = false;
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Set listeners for `connect` and `disconnect` events once
|
// Set listeners for `connect` and `disconnect` events once
|
||||||
// Must happen in middleware to get access to `dispatch`
|
// Must happen in middleware to get access to `dispatch`
|
||||||
@ -92,7 +93,8 @@ export const socketMiddleware = () => {
|
|||||||
socket.on('connect', () => {
|
socket.on('connect', () => {
|
||||||
dispatch(socketConnected({ timestamp: getTimestamp() }));
|
dispatch(socketConnected({ timestamp: getTimestamp() }));
|
||||||
|
|
||||||
const { results, uploads, models, nodes, config } = getState();
|
const { results, uploads, models, nodes, config, system } =
|
||||||
|
getState();
|
||||||
|
|
||||||
const { disabledTabs } = config;
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
@ -112,6 +114,18 @@ export const socketMiddleware = () => {
|
|||||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||||
dispatch(receivedOpenAPISchema());
|
dispatch(receivedOpenAPISchema());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (system.sessionId) {
|
||||||
|
console.log(`Re-subscribing to session ${system.sessionId}`);
|
||||||
|
socket.emit('subscribe', { session: system.sessionId });
|
||||||
|
dispatch(
|
||||||
|
socketSubscribed({
|
||||||
|
sessionId: system.sessionId,
|
||||||
|
timestamp: getTimestamp(),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
setEventListeners({ socket, store });
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
socket.on('disconnect', () => {
|
socket.on('disconnect', () => {
|
||||||
@ -128,9 +142,6 @@ export const socketMiddleware = () => {
|
|||||||
if (isFulfilledSessionCreatedAction(action)) {
|
if (isFulfilledSessionCreatedAction(action)) {
|
||||||
const oldSessionId = getState().system.sessionId;
|
const oldSessionId = getState().system.sessionId;
|
||||||
|
|
||||||
// temp disable event subscription
|
|
||||||
const shouldHandleEvent = (id: string): boolean => true;
|
|
||||||
|
|
||||||
// const subscribedNodeIds = getState().system.subscribedNodeIds;
|
// const subscribedNodeIds = getState().system.subscribedNodeIds;
|
||||||
// const shouldHandleEvent = (id: string): boolean => {
|
// const shouldHandleEvent = (id: string): boolean => {
|
||||||
// if (subscribedNodeIds.length === 1 && subscribedNodeIds[0] === '*') {
|
// if (subscribedNodeIds.length === 1 && subscribedNodeIds[0] === '*') {
|
||||||
@ -152,7 +163,6 @@ export const socketMiddleware = () => {
|
|||||||
timestamp: getTimestamp(),
|
timestamp: getTimestamp(),
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
const listenersToRemove: (keyof ServerToClientEvents)[] = [
|
const listenersToRemove: (keyof ServerToClientEvents)[] = [
|
||||||
'invocation_started',
|
'invocation_started',
|
||||||
'generator_progress',
|
'generator_progress',
|
||||||
@ -168,57 +178,14 @@ export const socketMiddleware = () => {
|
|||||||
|
|
||||||
const sessionId = action.payload.id;
|
const sessionId = action.payload.id;
|
||||||
|
|
||||||
// After a session is created, we immediately subscribe to events and then invoke the session
|
|
||||||
socket.emit('subscribe', { session: sessionId });
|
socket.emit('subscribe', { session: sessionId });
|
||||||
|
|
||||||
// Always dispatch the event actions for other consumers who want to know when we subscribed
|
|
||||||
dispatch(
|
dispatch(
|
||||||
socketSubscribed({
|
socketSubscribed({
|
||||||
sessionId,
|
sessionId: sessionId,
|
||||||
timestamp: getTimestamp(),
|
timestamp: getTimestamp(),
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
setEventListeners({ socket, store });
|
||||||
// Set up listeners for the present subscription
|
|
||||||
socket.on('invocation_started', (data) => {
|
|
||||||
if (shouldHandleEvent(data.node.id)) {
|
|
||||||
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
socket.on('generator_progress', (data) => {
|
|
||||||
if (shouldHandleEvent(data.node.id)) {
|
|
||||||
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
socket.on('invocation_error', (data) => {
|
|
||||||
if (shouldHandleEvent(data.node.id)) {
|
|
||||||
dispatch(invocationError({ data, timestamp: getTimestamp() }));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
socket.on('invocation_complete', (data) => {
|
|
||||||
if (shouldHandleEvent(data.node.id)) {
|
|
||||||
const sessionId = data.graph_execution_state_id;
|
|
||||||
|
|
||||||
const { cancelType, isCancelScheduled } = getState().system;
|
|
||||||
const { shouldFetchImages } = getState().config;
|
|
||||||
|
|
||||||
// Handle scheduled cancelation
|
|
||||||
if (cancelType === 'scheduled' && isCancelScheduled) {
|
|
||||||
dispatch(sessionCanceled({ sessionId }));
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
invocationComplete({
|
|
||||||
data,
|
|
||||||
timestamp: getTimestamp(),
|
|
||||||
shouldFetchImages,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Finally we actually invoke the session, starting processing
|
// Finally we actually invoke the session, starting processing
|
||||||
dispatch(sessionInvoked({ sessionId }));
|
dispatch(sessionInvoked({ sessionId }));
|
||||||
|
@ -15,12 +15,6 @@ export type AnyInvocationType = NonNullable<
|
|||||||
|
|
||||||
export type AnyInvocation = NonNullable<Graph['nodes']>[string];
|
export type AnyInvocation = NonNullable<Graph['nodes']>[string];
|
||||||
|
|
||||||
// export type AnyInvocation = {
|
|
||||||
// id: string;
|
|
||||||
// type: AnyInvocationType | string;
|
|
||||||
// [key: string]: any;
|
|
||||||
// };
|
|
||||||
|
|
||||||
export type AnyResult = GraphExecutionState['results'][string];
|
export type AnyResult = GraphExecutionState['results'][string];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
import { MiddlewareAPI } from '@reduxjs/toolkit';
|
||||||
|
import { AppDispatch, RootState } from 'app/store';
|
||||||
|
import { getTimestamp } from 'common/util/getTimestamp';
|
||||||
|
import { sessionCanceled } from 'services/thunks/session';
|
||||||
|
import { Socket } from 'socket.io-client';
|
||||||
|
import {
|
||||||
|
generatorProgress,
|
||||||
|
invocationComplete,
|
||||||
|
invocationError,
|
||||||
|
invocationStarted,
|
||||||
|
} from '../actions';
|
||||||
|
import { ClientToServerEvents, ServerToClientEvents } from '../types';
|
||||||
|
|
||||||
|
type SetEventListenersArg = {
|
||||||
|
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
|
||||||
|
store: MiddlewareAPI<AppDispatch, RootState>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const setEventListeners = (arg: SetEventListenersArg) => {
|
||||||
|
const { socket, store } = arg;
|
||||||
|
const { dispatch, getState } = store;
|
||||||
|
// Set up listeners for the present subscription
|
||||||
|
socket.on('invocation_started', (data) => {
|
||||||
|
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
|
||||||
|
});
|
||||||
|
|
||||||
|
socket.on('generator_progress', (data) => {
|
||||||
|
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
|
||||||
|
});
|
||||||
|
|
||||||
|
socket.on('invocation_error', (data) => {
|
||||||
|
dispatch(invocationError({ data, timestamp: getTimestamp() }));
|
||||||
|
});
|
||||||
|
|
||||||
|
socket.on('invocation_complete', (data) => {
|
||||||
|
const sessionId = data.graph_execution_state_id;
|
||||||
|
|
||||||
|
const { cancelType, isCancelScheduled } = getState().system;
|
||||||
|
const { shouldFetchImages } = getState().config;
|
||||||
|
|
||||||
|
// Handle scheduled cancelation
|
||||||
|
if (cancelType === 'scheduled' && isCancelScheduled) {
|
||||||
|
dispatch(sessionCanceled({ sessionId }));
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
invocationComplete({
|
||||||
|
data,
|
||||||
|
timestamp: getTimestamp(),
|
||||||
|
shouldFetchImages,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
});
|
||||||
|
};
|
Reference in New Issue
Block a user