mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): clean up and simplify socketio middleware
This commit is contained in:
parent
2e4e9434c1
commit
2eb7c25bae
@ -7,15 +7,9 @@ import {
|
|||||||
} from 'services/events/types';
|
} from 'services/events/types';
|
||||||
import {
|
import {
|
||||||
invocationComplete,
|
invocationComplete,
|
||||||
socketConnected,
|
|
||||||
socketDisconnected,
|
|
||||||
socketSubscribed,
|
socketSubscribed,
|
||||||
socketUnsubscribed,
|
socketUnsubscribed,
|
||||||
} from './actions';
|
} from './actions';
|
||||||
import {
|
|
||||||
receivedResultImagesPage,
|
|
||||||
receivedUploadImagesPage,
|
|
||||||
} from 'services/thunks/gallery';
|
|
||||||
import { AppDispatch, RootState } from 'app/store/store';
|
import { AppDispatch, RootState } from 'app/store/store';
|
||||||
import { getTimestamp } from 'common/util/getTimestamp';
|
import { getTimestamp } from 'common/util/getTimestamp';
|
||||||
import {
|
import {
|
||||||
@ -23,14 +17,12 @@ import {
|
|||||||
isFulfilledSessionCreatedAction,
|
isFulfilledSessionCreatedAction,
|
||||||
} from 'services/thunks/session';
|
} from 'services/thunks/session';
|
||||||
import { OpenAPI } from 'services/api';
|
import { OpenAPI } from 'services/api';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
|
||||||
import { isImageOutput } from 'services/types/guards';
|
import { isImageOutput } from 'services/types/guards';
|
||||||
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
||||||
import { setEventListeners } from 'services/events/util/setEventListeners';
|
import { setEventListeners } from 'services/events/util/setEventListeners';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'socketio' });
|
const socketioLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
export const socketMiddleware = () => {
|
export const socketMiddleware = () => {
|
||||||
let areListenersSet = false;
|
let areListenersSet = false;
|
||||||
@ -65,106 +57,27 @@ export const socketMiddleware = () => {
|
|||||||
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
|
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
|
||||||
const { dispatch, getState } = store;
|
const { dispatch, getState } = store;
|
||||||
|
|
||||||
// Nothing dispatches `socketReset` actions yet
|
|
||||||
// if (socketReset.match(action)) {
|
|
||||||
// const { sessionId } = getState().system;
|
|
||||||
|
|
||||||
// if (sessionId) {
|
|
||||||
// socket.emit('unsubscribe', { session: sessionId });
|
|
||||||
// dispatch(
|
|
||||||
// socketUnsubscribed({ sessionId, timestamp: getTimestamp() })
|
|
||||||
// );
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if (socket.connected) {
|
|
||||||
// socket.disconnect();
|
|
||||||
// dispatch(socketDisconnected({ timestamp: getTimestamp() }));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// socket.removeAllListeners();
|
|
||||||
// areListenersSet = false;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Set listeners for `connect` and `disconnect` events once
|
// Set listeners for `connect` and `disconnect` events once
|
||||||
// Must happen in middleware to get access to `dispatch`
|
// Must happen in middleware to get access to `dispatch`
|
||||||
if (!areListenersSet) {
|
if (!areListenersSet) {
|
||||||
socket.on('connect', () => {
|
setEventListeners({ store, socket, log: socketioLog });
|
||||||
moduleLog.debug('Connected');
|
|
||||||
|
|
||||||
dispatch(socketConnected({ timestamp: getTimestamp() }));
|
|
||||||
|
|
||||||
const { results, uploads, models, nodes, config, system } =
|
|
||||||
getState();
|
|
||||||
|
|
||||||
const { disabledTabs } = config;
|
|
||||||
|
|
||||||
// These thunks need to be dispatch in middleware; cannot handle in a reducer
|
|
||||||
if (!results.ids.length) {
|
|
||||||
dispatch(receivedResultImagesPage());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!uploads.ids.length) {
|
|
||||||
dispatch(receivedUploadImagesPage());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!models.ids.length) {
|
|
||||||
dispatch(receivedModels());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
|
||||||
dispatch(receivedOpenAPISchema());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (system.sessionId) {
|
|
||||||
const sessionLog = moduleLog.child({ sessionId: system.sessionId });
|
|
||||||
|
|
||||||
sessionLog.debug(
|
|
||||||
`Subscribed to existing session (${system.sessionId})`
|
|
||||||
);
|
|
||||||
|
|
||||||
socket.emit('subscribe', { session: system.sessionId });
|
|
||||||
dispatch(
|
|
||||||
socketSubscribed({
|
|
||||||
sessionId: system.sessionId,
|
|
||||||
timestamp: getTimestamp(),
|
|
||||||
})
|
|
||||||
);
|
|
||||||
setEventListeners({ socket, store, sessionLog });
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
socket.on('disconnect', () => {
|
|
||||||
moduleLog.debug('Disconnected');
|
|
||||||
dispatch(socketDisconnected({ timestamp: getTimestamp() }));
|
|
||||||
});
|
|
||||||
|
|
||||||
areListenersSet = true;
|
areListenersSet = true;
|
||||||
|
|
||||||
// must manually connect
|
|
||||||
socket.connect();
|
socket.connect();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Everything else only happens once we have created a session
|
|
||||||
if (isFulfilledSessionCreatedAction(action)) {
|
if (isFulfilledSessionCreatedAction(action)) {
|
||||||
const sessionId = action.payload.id;
|
const sessionId = action.payload.id;
|
||||||
const sessionLog = moduleLog.child({ sessionId });
|
const sessionLog = socketioLog.child({ sessionId });
|
||||||
const oldSessionId = getState().system.sessionId;
|
const oldSessionId = getState().system.sessionId;
|
||||||
|
|
||||||
// const subscribedNodeIds = getState().system.subscribedNodeIds;
|
|
||||||
// const shouldHandleEvent = (id: string): boolean => {
|
|
||||||
// if (subscribedNodeIds.length === 1 && subscribedNodeIds[0] === '*') {
|
|
||||||
// return true;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return subscribedNodeIds.includes(id);
|
|
||||||
// };
|
|
||||||
|
|
||||||
if (oldSessionId) {
|
if (oldSessionId) {
|
||||||
sessionLog.debug(
|
sessionLog.debug(
|
||||||
{ oldSessionId },
|
{ oldSessionId },
|
||||||
`Unsubscribed from old session (${oldSessionId})`
|
`Unsubscribed from old session (${oldSessionId})`
|
||||||
);
|
);
|
||||||
// Unsubscribe when invocations complete
|
|
||||||
socket.emit('unsubscribe', {
|
socket.emit('unsubscribe', {
|
||||||
session: oldSessionId,
|
session: oldSessionId,
|
||||||
});
|
});
|
||||||
@ -175,28 +88,18 @@ export const socketMiddleware = () => {
|
|||||||
timestamp: getTimestamp(),
|
timestamp: getTimestamp(),
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
const listenersToRemove: (keyof ServerToClientEvents)[] = [
|
|
||||||
'invocation_started',
|
|
||||||
'generator_progress',
|
|
||||||
'invocation_error',
|
|
||||||
'invocation_complete',
|
|
||||||
];
|
|
||||||
|
|
||||||
// Remove listeners for these events; we need to set them up fresh whenever we subscribe
|
|
||||||
listenersToRemove.forEach((event: keyof ServerToClientEvents) => {
|
|
||||||
socket.removeAllListeners(event);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionLog.debug(`Subscribe to new session (${sessionId})`);
|
sessionLog.debug(`Subscribe to new session (${sessionId})`);
|
||||||
|
|
||||||
socket.emit('subscribe', { session: sessionId });
|
socket.emit('subscribe', { session: sessionId });
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
socketSubscribed({
|
socketSubscribed({
|
||||||
sessionId: sessionId,
|
sessionId: sessionId,
|
||||||
timestamp: getTimestamp(),
|
timestamp: getTimestamp(),
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
setEventListeners({ socket, store, sessionLog });
|
|
||||||
|
|
||||||
// Finally we actually invoke the session, starting processing
|
// Finally we actually invoke the session, starting processing
|
||||||
dispatch(sessionInvoked({ sessionId }));
|
dispatch(sessionInvoked({ sessionId }));
|
||||||
@ -222,7 +125,6 @@ export const socketMiddleware = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always pass the action on so other middleware and reducers can handle it
|
|
||||||
next(action);
|
next(action);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -9,38 +9,124 @@ import {
|
|||||||
invocationComplete,
|
invocationComplete,
|
||||||
invocationError,
|
invocationError,
|
||||||
invocationStarted,
|
invocationStarted,
|
||||||
|
socketConnected,
|
||||||
|
socketDisconnected,
|
||||||
|
socketSubscribed,
|
||||||
} from '../actions';
|
} from '../actions';
|
||||||
import { ClientToServerEvents, ServerToClientEvents } from '../types';
|
import { ClientToServerEvents, ServerToClientEvents } from '../types';
|
||||||
import { Logger } from 'roarr';
|
import { Logger } from 'roarr';
|
||||||
import { JsonObject } from 'roarr/dist/types';
|
import { JsonObject } from 'roarr/dist/types';
|
||||||
|
import {
|
||||||
|
receivedResultImagesPage,
|
||||||
|
receivedUploadImagesPage,
|
||||||
|
} from 'services/thunks/gallery';
|
||||||
|
import { receivedModels } from 'services/thunks/model';
|
||||||
|
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||||
|
|
||||||
type SetEventListenersArg = {
|
type SetEventListenersArg = {
|
||||||
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
|
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
|
||||||
store: MiddlewareAPI<AppDispatch, RootState>;
|
store: MiddlewareAPI<AppDispatch, RootState>;
|
||||||
sessionLog: Logger<JsonObject>;
|
log: Logger<JsonObject>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const setEventListeners = (arg: SetEventListenersArg) => {
|
export const setEventListeners = (arg: SetEventListenersArg) => {
|
||||||
const { socket, store, sessionLog } = arg;
|
const { socket, store, log } = arg;
|
||||||
const { dispatch, getState } = store;
|
const { dispatch, getState } = store;
|
||||||
// Set up listeners for the present subscription
|
|
||||||
|
/**
|
||||||
|
* Connect
|
||||||
|
*/
|
||||||
|
socket.on('connect', () => {
|
||||||
|
log.debug('Connected');
|
||||||
|
|
||||||
|
dispatch(socketConnected({ timestamp: getTimestamp() }));
|
||||||
|
|
||||||
|
const { results, uploads, models, nodes, config, system } = getState();
|
||||||
|
|
||||||
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
|
// These thunks need to be dispatch in middleware; cannot handle in a reducer
|
||||||
|
if (!results.ids.length) {
|
||||||
|
dispatch(receivedResultImagesPage());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!uploads.ids.length) {
|
||||||
|
dispatch(receivedUploadImagesPage());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!models.ids.length) {
|
||||||
|
dispatch(receivedModels());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||||
|
dispatch(receivedOpenAPISchema());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (system.sessionId) {
|
||||||
|
log.debug(
|
||||||
|
{ sessionId: system.sessionId },
|
||||||
|
`Subscribed to existing session (${system.sessionId})`
|
||||||
|
);
|
||||||
|
|
||||||
|
socket.emit('subscribe', { session: system.sessionId });
|
||||||
|
dispatch(
|
||||||
|
socketSubscribed({
|
||||||
|
sessionId: system.sessionId,
|
||||||
|
timestamp: getTimestamp(),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Disconnect
|
||||||
|
*/
|
||||||
|
socket.on('disconnect', () => {
|
||||||
|
log.debug('Disconnected');
|
||||||
|
dispatch(socketDisconnected({ timestamp: getTimestamp() }));
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invocation started
|
||||||
|
*/
|
||||||
socket.on('invocation_started', (data) => {
|
socket.on('invocation_started', (data) => {
|
||||||
sessionLog.child({ data }).info(`Invocation started (${data.node.type})`);
|
log.info(
|
||||||
|
{ data, sessionId: data.graph_execution_state_id },
|
||||||
|
`Invocation started (${data.node.type})`
|
||||||
|
);
|
||||||
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
|
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generator progress
|
||||||
|
*/
|
||||||
socket.on('generator_progress', (data) => {
|
socket.on('generator_progress', (data) => {
|
||||||
sessionLog.child({ data }).trace(`Generator progress (${data.node.type})`);
|
log.trace(
|
||||||
|
{ data, sessionId: data.graph_execution_state_id },
|
||||||
|
`Generator progress (${data.node.type})`
|
||||||
|
);
|
||||||
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
|
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invocation error
|
||||||
|
*/
|
||||||
socket.on('invocation_error', (data) => {
|
socket.on('invocation_error', (data) => {
|
||||||
sessionLog.child({ data }).error(`Invocation error (${data.node.type})`);
|
log.error(
|
||||||
|
{ data, sessionId: data.graph_execution_state_id },
|
||||||
|
`Invocation error (${data.node.type})`
|
||||||
|
);
|
||||||
dispatch(invocationError({ data, timestamp: getTimestamp() }));
|
dispatch(invocationError({ data, timestamp: getTimestamp() }));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invocation complete
|
||||||
|
*/
|
||||||
socket.on('invocation_complete', (data) => {
|
socket.on('invocation_complete', (data) => {
|
||||||
sessionLog.child({ data }).info(`Invocation complete (${data.node.type})`);
|
log.info(
|
||||||
|
{ data, sessionId: data.graph_execution_state_id },
|
||||||
|
`Invocation complete (${data.node.type})`
|
||||||
|
);
|
||||||
const sessionId = data.graph_execution_state_id;
|
const sessionId = data.graph_execution_state_id;
|
||||||
|
|
||||||
const { cancelType, isCancelScheduled } = getState().system;
|
const { cancelType, isCancelScheduled } = getState().system;
|
||||||
@ -60,10 +146,12 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Graph complete
|
||||||
|
*/
|
||||||
socket.on('graph_execution_state_complete', (data) => {
|
socket.on('graph_execution_state_complete', (data) => {
|
||||||
sessionLog
|
log.info(
|
||||||
.child({ data })
|
{ data, sessionId: data.graph_execution_state_id },
|
||||||
.info(
|
|
||||||
`Graph execution state complete (${data.graph_execution_state_id})`
|
`Graph execution state complete (${data.graph_execution_state_id})`
|
||||||
);
|
);
|
||||||
dispatch(graphExecutionStateComplete({ data, timestamp: getTimestamp() }));
|
dispatch(graphExecutionStateComplete({ data, timestamp: getTimestamp() }));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user