mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): write separate nodes socket layer, txt2img generating and rendering w single node
This commit is contained in:
parent
4fe49718e0
commit
40b2d2b05b
11
invokeai/frontend/web/src/app/nodesSocketio/actions.ts
Normal file
11
invokeai/frontend/web/src/app/nodesSocketio/actions.ts
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* We can't use redux-toolkit's createSlice() to make these actions,
|
||||||
|
* because they have no associated reducer. They only exist to dispatch
|
||||||
|
* requests to the server via socketio. These actions will be handled
|
||||||
|
* by the middleware.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export const emitSubscribe = createAction<string>('socketio/subscribe');
|
||||||
|
export const emitUnsubscribe = createAction<string>('socketio/unsubscribe');
|
15
invokeai/frontend/web/src/app/nodesSocketio/emitters.ts
Normal file
15
invokeai/frontend/web/src/app/nodesSocketio/emitters.ts
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import { Socket } from 'socket.io-client';
|
||||||
|
|
||||||
|
const makeSocketIOEmitters = (socketio: Socket) => {
|
||||||
|
return {
|
||||||
|
emitSubscribe: (sessionId: string) => {
|
||||||
|
socketio.emit('subscribe', { session: sessionId });
|
||||||
|
},
|
||||||
|
|
||||||
|
emitUnsubscribe: (sessionId: string) => {
|
||||||
|
socketio.emit('unsubscribe', { session: sessionId });
|
||||||
|
},
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export default makeSocketIOEmitters;
|
158
invokeai/frontend/web/src/app/nodesSocketio/listeners.ts
Normal file
158
invokeai/frontend/web/src/app/nodesSocketio/listeners.ts
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||||
|
import dateFormat from 'dateformat';
|
||||||
|
import i18n from 'i18n';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
|
import {
|
||||||
|
addLogEntry,
|
||||||
|
errorOccurred,
|
||||||
|
setCurrentStatus,
|
||||||
|
setIsCancelable,
|
||||||
|
setIsConnected,
|
||||||
|
setIsProcessing,
|
||||||
|
} from 'features/system/store/systemSlice';
|
||||||
|
|
||||||
|
import {
|
||||||
|
addImage,
|
||||||
|
clearIntermediateImage,
|
||||||
|
} from 'features/gallery/store/gallerySlice';
|
||||||
|
|
||||||
|
import type { RootState } from 'app/store';
|
||||||
|
import {
|
||||||
|
GeneratorProgressEvent,
|
||||||
|
InvocationCompleteEvent,
|
||||||
|
InvocationErrorEvent,
|
||||||
|
InvocationStartedEvent,
|
||||||
|
} from 'services/events/types';
|
||||||
|
import {
|
||||||
|
setProgress,
|
||||||
|
setProgressImage,
|
||||||
|
setSessionId,
|
||||||
|
setStatus,
|
||||||
|
STATUS,
|
||||||
|
} from 'services/apiSlice';
|
||||||
|
import { emitUnsubscribe } from './actions';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an object containing listener callbacks
|
||||||
|
*/
|
||||||
|
const makeSocketIOListeners = (
|
||||||
|
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>
|
||||||
|
) => {
|
||||||
|
const { dispatch, getState } = store;
|
||||||
|
|
||||||
|
return {
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'connect' event.
|
||||||
|
*/
|
||||||
|
onConnect: () => {
|
||||||
|
try {
|
||||||
|
dispatch(setIsConnected(true));
|
||||||
|
dispatch(setCurrentStatus(i18n.t('common.statusConnected')));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'disconnect' event.
|
||||||
|
*/
|
||||||
|
onDisconnect: () => {
|
||||||
|
try {
|
||||||
|
dispatch(setIsConnected(false));
|
||||||
|
dispatch(setCurrentStatus(i18n.t('common.statusDisconnected')));
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Disconnected from server`,
|
||||||
|
level: 'warning',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onInvocationStarted: (data: InvocationStartedEvent) => {
|
||||||
|
console.log('invocation_started', data);
|
||||||
|
dispatch(setStatus(STATUS.busy));
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'generationResult' event.
|
||||||
|
*/
|
||||||
|
onInvocationComplete: (data: InvocationCompleteEvent) => {
|
||||||
|
console.log('invocation_complete', data);
|
||||||
|
try {
|
||||||
|
const sessionId = data.graph_execution_state_id;
|
||||||
|
if (data.result.type === 'image') {
|
||||||
|
const url = `api/v1/images/${data.result.image.image_type}/${data.result.image.image_name}`;
|
||||||
|
dispatch(
|
||||||
|
addImage({
|
||||||
|
category: 'result',
|
||||||
|
image: {
|
||||||
|
uuid: uuidv4(),
|
||||||
|
url,
|
||||||
|
thumbnail: '',
|
||||||
|
width: 512,
|
||||||
|
height: 512,
|
||||||
|
category: 'result',
|
||||||
|
mtime: new Date().getTime(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Generated: ${data.result.image.image_name}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
dispatch(setIsCancelable(false));
|
||||||
|
dispatch(emitUnsubscribe(sessionId));
|
||||||
|
dispatch(setSessionId(null));
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'progressUpdate' event.
|
||||||
|
* TODO: Add additional progress phases
|
||||||
|
*/
|
||||||
|
onGeneratorProgress: (data: GeneratorProgressEvent) => {
|
||||||
|
try {
|
||||||
|
console.log('generator_progress', data);
|
||||||
|
dispatch(setProgress(data.step / data.total_steps));
|
||||||
|
if (data.progress_image) {
|
||||||
|
dispatch(setProgressImage(data.progress_image));
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'progressUpdate' event.
|
||||||
|
*/
|
||||||
|
onInvocationError: (data: InvocationErrorEvent) => {
|
||||||
|
const { error } = data;
|
||||||
|
|
||||||
|
try {
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Server error: ${error}`,
|
||||||
|
level: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(errorOccurred());
|
||||||
|
dispatch(clearIntermediateImage());
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'galleryImages' event.
|
||||||
|
*/
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export default makeSocketIOListeners;
|
78
invokeai/frontend/web/src/app/nodesSocketio/middleware.ts
Normal file
78
invokeai/frontend/web/src/app/nodesSocketio/middleware.ts
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import { Middleware } from '@reduxjs/toolkit';
|
||||||
|
import { io } from 'socket.io-client';
|
||||||
|
|
||||||
|
import makeSocketIOEmitters from './emitters';
|
||||||
|
import makeSocketIOListeners from './listeners';
|
||||||
|
|
||||||
|
import {
|
||||||
|
GeneratorProgressEvent,
|
||||||
|
InvocationCompleteEvent,
|
||||||
|
InvocationErrorEvent,
|
||||||
|
InvocationStartedEvent,
|
||||||
|
} from 'services/events/types';
|
||||||
|
|
||||||
|
const socket_url = `ws://${window.location.host}`;
|
||||||
|
|
||||||
|
const socketio = io(socket_url, {
|
||||||
|
timeout: 60000,
|
||||||
|
path: '/ws/socket.io',
|
||||||
|
});
|
||||||
|
|
||||||
|
export const socketioMiddleware = () => {
|
||||||
|
let areListenersSet = false;
|
||||||
|
|
||||||
|
const middleware: Middleware = (store) => (next) => (action) => {
|
||||||
|
const { emitSubscribe, emitUnsubscribe } = makeSocketIOEmitters(socketio);
|
||||||
|
|
||||||
|
const {
|
||||||
|
onConnect,
|
||||||
|
onDisconnect,
|
||||||
|
onInvocationStarted,
|
||||||
|
onGeneratorProgress,
|
||||||
|
onInvocationError,
|
||||||
|
onInvocationComplete,
|
||||||
|
} = makeSocketIOListeners(store);
|
||||||
|
|
||||||
|
if (!areListenersSet) {
|
||||||
|
socketio.on('connect', () => onConnect());
|
||||||
|
socketio.on('disconnect', () => onDisconnect());
|
||||||
|
}
|
||||||
|
|
||||||
|
areListenersSet = true;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle redux actions caught by middleware.
|
||||||
|
*/
|
||||||
|
switch (action.type) {
|
||||||
|
case 'socketio/subscribe': {
|
||||||
|
emitSubscribe(action.payload);
|
||||||
|
|
||||||
|
socketio.on('invocation_started', (data: InvocationStartedEvent) =>
|
||||||
|
onInvocationStarted(data)
|
||||||
|
);
|
||||||
|
socketio.on('generator_progress', (data: GeneratorProgressEvent) =>
|
||||||
|
onGeneratorProgress(data)
|
||||||
|
);
|
||||||
|
socketio.on('invocation_error', (data: InvocationErrorEvent) =>
|
||||||
|
onInvocationError(data)
|
||||||
|
);
|
||||||
|
socketio.on('invocation_complete', (data: InvocationCompleteEvent) =>
|
||||||
|
onInvocationComplete(data)
|
||||||
|
);
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'socketio/unsubscribe': {
|
||||||
|
emitUnsubscribe(action.payload);
|
||||||
|
|
||||||
|
socketio.removeAllListeners();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
next(action);
|
||||||
|
};
|
||||||
|
|
||||||
|
return middleware;
|
||||||
|
};
|
@ -15,6 +15,7 @@ import uiReducer from 'features/ui/store/uiSlice';
|
|||||||
import apiReducer from 'services/apiSlice';
|
import apiReducer from 'services/apiSlice';
|
||||||
|
|
||||||
import { socketioMiddleware } from './socketio/middleware';
|
import { socketioMiddleware } from './socketio/middleware';
|
||||||
|
import { socketioMiddleware as nodesSocketioMiddleware } from './nodesSocketio/middleware';
|
||||||
import { invokeMiddleware } from 'services/invokeMiddleware';
|
import { invokeMiddleware } from 'services/invokeMiddleware';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -97,6 +98,14 @@ const rootPersistConfig = getPersistConfig({
|
|||||||
|
|
||||||
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
||||||
|
|
||||||
|
function buildMiddleware() {
|
||||||
|
if (import.meta.env.MODE === 'nodes') {
|
||||||
|
return [nodesSocketioMiddleware(), invokeMiddleware];
|
||||||
|
} else {
|
||||||
|
return [socketioMiddleware()];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Continue with store setup
|
// Continue with store setup
|
||||||
export const store = configureStore({
|
export const store = configureStore({
|
||||||
reducer: persistedReducer,
|
reducer: persistedReducer,
|
||||||
@ -104,7 +113,7 @@ export const store = configureStore({
|
|||||||
getDefaultMiddleware({
|
getDefaultMiddleware({
|
||||||
immutableCheck: false,
|
immutableCheck: false,
|
||||||
serializableCheck: false,
|
serializableCheck: false,
|
||||||
}).concat(socketioMiddleware(), invokeMiddleware),
|
}).concat(buildMiddleware()),
|
||||||
devTools: {
|
devTools: {
|
||||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||||
actionsDenylist: [
|
actionsDenylist: [
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { Middleware } from '@reduxjs/toolkit';
|
import { Middleware } from '@reduxjs/toolkit';
|
||||||
|
import { emitSubscribe } from 'app/nodesSocketio/actions';
|
||||||
import { setSessionId } from './apiSlice';
|
import { setSessionId } from './apiSlice';
|
||||||
import { invokeSession } from './thunks/session';
|
import { invokeSession } from './thunks/session';
|
||||||
|
|
||||||
@ -6,11 +7,13 @@ export const invokeMiddleware: Middleware = (store) => (next) => (action) => {
|
|||||||
const { dispatch } = store;
|
const { dispatch } = store;
|
||||||
|
|
||||||
if (action.type === 'api/createSession/fulfilled' && action?.payload?.id) {
|
if (action.type === 'api/createSession/fulfilled' && action?.payload?.id) {
|
||||||
|
const sessionId = action.payload.id;
|
||||||
console.log('createSession.fulfilled');
|
console.log('createSession.fulfilled');
|
||||||
|
|
||||||
dispatch(setSessionId(action.payload.id));
|
dispatch(setSessionId(sessionId));
|
||||||
|
dispatch(emitSubscribe(sessionId));
|
||||||
// types are wrong but this works
|
// types are wrong but this works
|
||||||
dispatch(invokeSession({ sessionId: action.payload.id }));
|
dispatch(invokeSession({ sessionId }));
|
||||||
} else {
|
} else {
|
||||||
next(action);
|
next(action);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user