From 40b2d2b05b0133840adbb19a0261adf0d0a4d268 Mon Sep 17 00:00:00 2001 From: maryhipp Date: Thu, 30 Mar 2023 09:19:09 -0700 Subject: [PATCH] feat(ui): write separate nodes socket layer, txt2img generating and rendering w single node --- .../web/src/app/nodesSocketio/actions.ts | 11 ++ .../web/src/app/nodesSocketio/emitters.ts | 15 ++ .../web/src/app/nodesSocketio/listeners.ts | 158 ++++++++++++++++++ .../web/src/app/nodesSocketio/middleware.ts | 78 +++++++++ invokeai/frontend/web/src/app/store.ts | 11 +- .../web/src/services/invokeMiddleware.ts | 7 +- 6 files changed, 277 insertions(+), 3 deletions(-) create mode 100644 invokeai/frontend/web/src/app/nodesSocketio/actions.ts create mode 100644 invokeai/frontend/web/src/app/nodesSocketio/emitters.ts create mode 100644 invokeai/frontend/web/src/app/nodesSocketio/listeners.ts create mode 100644 invokeai/frontend/web/src/app/nodesSocketio/middleware.ts diff --git a/invokeai/frontend/web/src/app/nodesSocketio/actions.ts b/invokeai/frontend/web/src/app/nodesSocketio/actions.ts new file mode 100644 index 0000000000..35eac2e81e --- /dev/null +++ b/invokeai/frontend/web/src/app/nodesSocketio/actions.ts @@ -0,0 +1,11 @@ +import { createAction } from '@reduxjs/toolkit'; + +/** + * We can't use redux-toolkit's createSlice() to make these actions, + * because they have no associated reducer. They only exist to dispatch + * requests to the server via socketio. These actions will be handled + * by the middleware. + */ + +export const emitSubscribe = createAction('socketio/subscribe'); +export const emitUnsubscribe = createAction('socketio/unsubscribe'); diff --git a/invokeai/frontend/web/src/app/nodesSocketio/emitters.ts b/invokeai/frontend/web/src/app/nodesSocketio/emitters.ts new file mode 100644 index 0000000000..7155b5ae22 --- /dev/null +++ b/invokeai/frontend/web/src/app/nodesSocketio/emitters.ts @@ -0,0 +1,15 @@ +import { Socket } from 'socket.io-client'; + +const makeSocketIOEmitters = (socketio: Socket) => { + return { + emitSubscribe: (sessionId: string) => { + socketio.emit('subscribe', { session: sessionId }); + }, + + emitUnsubscribe: (sessionId: string) => { + socketio.emit('unsubscribe', { session: sessionId }); + }, + }; +}; + +export default makeSocketIOEmitters; diff --git a/invokeai/frontend/web/src/app/nodesSocketio/listeners.ts b/invokeai/frontend/web/src/app/nodesSocketio/listeners.ts new file mode 100644 index 0000000000..7a47f6a448 --- /dev/null +++ b/invokeai/frontend/web/src/app/nodesSocketio/listeners.ts @@ -0,0 +1,158 @@ +import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit'; +import dateFormat from 'dateformat'; +import i18n from 'i18n'; +import { v4 as uuidv4 } from 'uuid'; + +import { + addLogEntry, + errorOccurred, + setCurrentStatus, + setIsCancelable, + setIsConnected, + setIsProcessing, +} from 'features/system/store/systemSlice'; + +import { + addImage, + clearIntermediateImage, +} from 'features/gallery/store/gallerySlice'; + +import type { RootState } from 'app/store'; +import { + GeneratorProgressEvent, + InvocationCompleteEvent, + InvocationErrorEvent, + InvocationStartedEvent, +} from 'services/events/types'; +import { + setProgress, + setProgressImage, + setSessionId, + setStatus, + STATUS, +} from 'services/apiSlice'; +import { emitUnsubscribe } from './actions'; + +/** + * Returns an object containing listener callbacks + */ +const makeSocketIOListeners = ( + store: MiddlewareAPI, RootState> +) => { + const { dispatch, getState } = store; + + return { + /** + * Callback to run when we receive a 'connect' event. + */ + onConnect: () => { + try { + dispatch(setIsConnected(true)); + dispatch(setCurrentStatus(i18n.t('common.statusConnected'))); + } catch (e) { + console.error(e); + } + }, + /** + * Callback to run when we receive a 'disconnect' event. + */ + onDisconnect: () => { + try { + dispatch(setIsConnected(false)); + dispatch(setCurrentStatus(i18n.t('common.statusDisconnected'))); + + dispatch( + addLogEntry({ + timestamp: dateFormat(new Date(), 'isoDateTime'), + message: `Disconnected from server`, + level: 'warning', + }) + ); + } catch (e) { + console.error(e); + } + }, + onInvocationStarted: (data: InvocationStartedEvent) => { + console.log('invocation_started', data); + dispatch(setStatus(STATUS.busy)); + }, + /** + * Callback to run when we receive a 'generationResult' event. + */ + onInvocationComplete: (data: InvocationCompleteEvent) => { + console.log('invocation_complete', data); + try { + const sessionId = data.graph_execution_state_id; + if (data.result.type === 'image') { + const url = `api/v1/images/${data.result.image.image_type}/${data.result.image.image_name}`; + dispatch( + addImage({ + category: 'result', + image: { + uuid: uuidv4(), + url, + thumbnail: '', + width: 512, + height: 512, + category: 'result', + mtime: new Date().getTime(), + }, + }) + ); + dispatch( + addLogEntry({ + timestamp: dateFormat(new Date(), 'isoDateTime'), + message: `Generated: ${data.result.image.image_name}`, + }) + ); + dispatch(setIsProcessing(false)); + dispatch(setIsCancelable(false)); + dispatch(emitUnsubscribe(sessionId)); + dispatch(setSessionId(null)); + } + } catch (e) { + console.error(e); + } + }, + /** + * Callback to run when we receive a 'progressUpdate' event. + * TODO: Add additional progress phases + */ + onGeneratorProgress: (data: GeneratorProgressEvent) => { + try { + console.log('generator_progress', data); + dispatch(setProgress(data.step / data.total_steps)); + if (data.progress_image) { + dispatch(setProgressImage(data.progress_image)); + } + } catch (e) { + console.error(e); + } + }, + /** + * Callback to run when we receive a 'progressUpdate' event. + */ + onInvocationError: (data: InvocationErrorEvent) => { + const { error } = data; + + try { + dispatch( + addLogEntry({ + timestamp: dateFormat(new Date(), 'isoDateTime'), + message: `Server error: ${error}`, + level: 'error', + }) + ); + dispatch(errorOccurred()); + dispatch(clearIntermediateImage()); + } catch (e) { + console.error(e); + } + }, + /** + * Callback to run when we receive a 'galleryImages' event. + */ + }; +}; + +export default makeSocketIOListeners; diff --git a/invokeai/frontend/web/src/app/nodesSocketio/middleware.ts b/invokeai/frontend/web/src/app/nodesSocketio/middleware.ts new file mode 100644 index 0000000000..dcc6f2c078 --- /dev/null +++ b/invokeai/frontend/web/src/app/nodesSocketio/middleware.ts @@ -0,0 +1,78 @@ +import { Middleware } from '@reduxjs/toolkit'; +import { io } from 'socket.io-client'; + +import makeSocketIOEmitters from './emitters'; +import makeSocketIOListeners from './listeners'; + +import { + GeneratorProgressEvent, + InvocationCompleteEvent, + InvocationErrorEvent, + InvocationStartedEvent, +} from 'services/events/types'; + +const socket_url = `ws://${window.location.host}`; + +const socketio = io(socket_url, { + timeout: 60000, + path: '/ws/socket.io', +}); + +export const socketioMiddleware = () => { + let areListenersSet = false; + + const middleware: Middleware = (store) => (next) => (action) => { + const { emitSubscribe, emitUnsubscribe } = makeSocketIOEmitters(socketio); + + const { + onConnect, + onDisconnect, + onInvocationStarted, + onGeneratorProgress, + onInvocationError, + onInvocationComplete, + } = makeSocketIOListeners(store); + + if (!areListenersSet) { + socketio.on('connect', () => onConnect()); + socketio.on('disconnect', () => onDisconnect()); + } + + areListenersSet = true; + + /** + * Handle redux actions caught by middleware. + */ + switch (action.type) { + case 'socketio/subscribe': { + emitSubscribe(action.payload); + + socketio.on('invocation_started', (data: InvocationStartedEvent) => + onInvocationStarted(data) + ); + socketio.on('generator_progress', (data: GeneratorProgressEvent) => + onGeneratorProgress(data) + ); + socketio.on('invocation_error', (data: InvocationErrorEvent) => + onInvocationError(data) + ); + socketio.on('invocation_complete', (data: InvocationCompleteEvent) => + onInvocationComplete(data) + ); + + break; + } + + case 'socketio/unsubscribe': { + emitUnsubscribe(action.payload); + + socketio.removeAllListeners(); + break; + } + } + + next(action); + }; + + return middleware; +}; diff --git a/invokeai/frontend/web/src/app/store.ts b/invokeai/frontend/web/src/app/store.ts index a8599641ca..d43f673bee 100644 --- a/invokeai/frontend/web/src/app/store.ts +++ b/invokeai/frontend/web/src/app/store.ts @@ -15,6 +15,7 @@ import uiReducer from 'features/ui/store/uiSlice'; import apiReducer from 'services/apiSlice'; import { socketioMiddleware } from './socketio/middleware'; +import { socketioMiddleware as nodesSocketioMiddleware } from './nodesSocketio/middleware'; import { invokeMiddleware } from 'services/invokeMiddleware'; /** @@ -97,6 +98,14 @@ const rootPersistConfig = getPersistConfig({ const persistedReducer = persistReducer(rootPersistConfig, rootReducer); +function buildMiddleware() { + if (import.meta.env.MODE === 'nodes') { + return [nodesSocketioMiddleware(), invokeMiddleware]; + } else { + return [socketioMiddleware()]; + } +} + // Continue with store setup export const store = configureStore({ reducer: persistedReducer, @@ -104,7 +113,7 @@ export const store = configureStore({ getDefaultMiddleware({ immutableCheck: false, serializableCheck: false, - }).concat(socketioMiddleware(), invokeMiddleware), + }).concat(buildMiddleware()), devTools: { // Uncommenting these very rapidly called actions makes the redux dev tools output much more readable actionsDenylist: [ diff --git a/invokeai/frontend/web/src/services/invokeMiddleware.ts b/invokeai/frontend/web/src/services/invokeMiddleware.ts index 8b5909f98b..29bf51bdf1 100644 --- a/invokeai/frontend/web/src/services/invokeMiddleware.ts +++ b/invokeai/frontend/web/src/services/invokeMiddleware.ts @@ -1,4 +1,5 @@ import { Middleware } from '@reduxjs/toolkit'; +import { emitSubscribe } from 'app/nodesSocketio/actions'; import { setSessionId } from './apiSlice'; import { invokeSession } from './thunks/session'; @@ -6,11 +7,13 @@ export const invokeMiddleware: Middleware = (store) => (next) => (action) => { const { dispatch } = store; if (action.type === 'api/createSession/fulfilled' && action?.payload?.id) { + const sessionId = action.payload.id; console.log('createSession.fulfilled'); - dispatch(setSessionId(action.payload.id)); + dispatch(setSessionId(sessionId)); + dispatch(emitSubscribe(sessionId)); // types are wrong but this works - dispatch(invokeSession({ sessionId: action.payload.id })); + dispatch(invokeSession({ sessionId })); } else { next(action); }