feat(ui): more nodes api prototyping

This commit is contained in:
psychedelicious 2023-03-15 23:34:13 +11:00
parent b49338b464
commit 07428769df
7 changed files with 317 additions and 75 deletions

View File

@ -1,14 +1,25 @@
import { Flex, Heading, Text } from '@chakra-ui/react'; import { Flex, Image, Text } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import { useEffect, useState } from 'react'; import {
import { SessionsService } from 'services/api'; setProgress,
setProgressImage,
setSessionId,
setStatus,
} from 'services/apiSlice';
import { useEffect } from 'react';
import { STATUS, ProgressImage } from 'services/apiSliceTypes';
import { getImage } from 'services/thunks/image';
import { createSession, invokeSession } from 'services/thunks/session';
import { io } from 'socket.io-client'; import { io } from 'socket.io-client';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { RootState } from './store';
type GeneratorProgress = { type GeneratorProgress = {
session_id: string; session_id: string;
invocation_id: string; invocation_id: string;
progress_image: ProgressImage;
step: number; step: number;
percent: number; total_steps: number;
}; };
const socket_url = `ws://${window.location.host}`; const socket_url = `ws://${window.location.host}`;
@ -16,98 +27,102 @@ const socket = io(socket_url, {
path: '/ws/socket.io', path: '/ws/socket.io',
}); });
enum STATUS {
waiting = 'WAITING',
ready = 'READY',
preparing = 'PREPARING',
generating = 'GENERATING',
finished = 'FINISHED',
}
const NodeAPITest = () => { const NodeAPITest = () => {
const [invocationProgress, setInvocationProgress] = useState<number>(); const dispatch = useAppDispatch();
const [status, setStatus] = useState<STATUS>(STATUS.waiting); const { sessionId, status, progress, progressImage } = useAppSelector(
const [sessionId, setSessionId] = useState<string | null>(null); (state: RootState) => state.api
);
const handleCreateSession = async () => { const handleCreateSession = () => {
// create a session with a simple graph dispatch(
const payload = await SessionsService.createSession({ createSession({
nodes: { requestBody: {
a: { nodes: {
id: 'a', a: {
type: 'txt2img', id: 'a',
prompt: 'pizza', type: 'txt2img',
steps: 10, prompt: 'pizza',
steps: 10,
},
b: {
id: 'b',
type: 'show_image',
},
},
edges: [
{
source: { node_id: 'a', field: 'image' },
destination: { node_id: 'b', field: 'image' },
},
],
}, },
b: { })
id: 'b', );
type: 'show_image',
},
},
edges: [
{
source: { node_id: 'a', field: 'image' },
destination: { node_id: 'b', field: 'image' },
},
],
});
// the generated types have `id` as optional but i'm pretty sure we always get the id
setSessionId(payload.id!);
setStatus(STATUS.ready);
console.log('payload', payload);
// subscribe to this session
socket.emit('subscribe', { session: payload.id });
console.log('subscribe', { session: payload.id });
}; };
const handleInvokeSession = async () => { const handleInvokeSession = () => {
dispatch(invokeSession());
};
useEffect(() => {
if (!sessionId) { if (!sessionId) {
return; return;
} }
setStatus(STATUS.preparing); // set up socket.io listeners
// invoke the session, the resultant image should open in your platform's native image viewer when completed
await SessionsService.invokeSession(sessionId, true);
};
useEffect(() => { // TODO: suppose this should be handled in the socket.io middleware
// TODO: write types for the socket.io payloads, haven't found a generator for them yet...
// subscribe to the current session
socket.emit('subscribe', { session: sessionId });
console.log('subscribe', { session: sessionId });
// received on each generation step
socket.on('generator_progress', (data: GeneratorProgress) => { socket.on('generator_progress', (data: GeneratorProgress) => {
// this is broken on the backend, the nodes web server does not get `step` or `steps`, so we don't get a percentage
// see https://github.com/invoke-ai/InvokeAI/issues/2951
console.log('generator_progress', data); console.log('generator_progress', data);
setInvocationProgress(data.percent); dispatch(setProgress(data.step / data.total_steps));
dispatch(setProgressImage(data.progress_image));
}); });
// received after invokeSession called
socket.on('invocation_started', (data) => { socket.on('invocation_started', (data) => {
console.log('invocation_started', data); console.log('invocation_started', data);
setStatus(STATUS.generating); dispatch(setStatus(STATUS.busy));
}); });
// received when generation complete
socket.on('invocation_complete', (data) => { socket.on('invocation_complete', (data) => {
// for now, just unsubscribe from the session when we finish a generation // for now, just unsubscribe from the session when we finish a generation
// in the future we will want to continue building the graph and executing etc // in the future we will want to continue building the graph and executing etc
setStatus(STATUS.finished);
console.log('invocation_complete', data); console.log('invocation_complete', data);
socket.emit('unsubscribe', { session: data.session_id }); dispatch(setProgress(null));
console.log('unsubscribe', { session: data.session_id }); dispatch(setSessionId(null));
setTimeout(() => { dispatch(setStatus(STATUS.idle));
setSessionId(null);
setStatus(STATUS.waiting); // think this gets a blob...
}, 2000); // dispatch(
// getImage({
// imageType: data.result.image.image_type,
// imageName: data.result.image.image_name,
// })
// );
socket.emit('unsubscribe', { session: sessionId });
console.log('unsubscribe', { session: sessionId });
}); });
// not sure when we get this?
socket.on('session_complete', (data) => { socket.on('session_complete', (data) => {
console.log('session_complete', data); // console.log('session_complete', data);
socket.emit('unsubscribe', { session: data.session_id });
console.log('unsubscribe', { session: data.session_id });
setSessionId(null);
setStatus(STATUS.waiting);
}); });
() => { () => {
// cleanup
socket.emit('unsubscribe', { session: sessionId });
socket.removeAllListeners(); socket.removeAllListeners();
socket.disconnect(); socket.disconnect();
}; };
}, []); }, [dispatch, sessionId]);
return ( return (
<Flex <Flex
@ -119,28 +134,33 @@ const NodeAPITest = () => {
borderRadius: 'base', borderRadius: 'base',
}} }}
> >
<Heading size="lg">Status: {status}</Heading>
<Text>Session: {sessionId ? sessionId : '...'}</Text> <Text>Session: {sessionId ? sessionId : '...'}</Text>
<IAIButton <IAIButton
onClick={handleCreateSession} onClick={handleCreateSession}
isDisabled={!!sessionId} isDisabled={status === STATUS.busy || Boolean(sessionId)}
colorScheme="accent" colorScheme="accent"
> >
Create Session Create Session
</IAIButton> </IAIButton>
<IAIButton <IAIButton
onClick={handleInvokeSession} onClick={handleInvokeSession}
isDisabled={!sessionId || status !== STATUS.ready} isDisabled={status === STATUS.busy}
isLoading={[STATUS.preparing, STATUS.generating].includes(status)} isLoading={status === STATUS.busy}
loadingText={`Invoking ${ loadingText={`Invoking ${
invocationProgress === undefined progress === null ? '...' : `${Math.round(progress * 100)}%`
? '...'
: `${Math.round(invocationProgress * 100)}%`
}`} }`}
colorScheme="accent" colorScheme="accent"
> >
Invoke Invoke
</IAIButton> </IAIButton>
<Image
src={progressImage?.dataURL}
width={progressImage?.width}
height={progressImage?.height}
sx={{
imageRendering: 'pixelated',
}}
/>
</Flex> </Flex>
); );
}; };

View File

@ -12,6 +12,7 @@ import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import apiReducer from 'services/apiSlice';
import { socketioMiddleware } from './socketio/middleware'; import { socketioMiddleware } from './socketio/middleware';
@ -64,6 +65,10 @@ const lightboxBlacklist = ['isLightboxOpen'].map(
(blacklistItem) => `lightbox.${blacklistItem}` (blacklistItem) => `lightbox.${blacklistItem}`
); );
const apiBlacklist = ['sessionId', 'status', 'progress', 'progressImage'].map(
(blacklistItem) => `api.${blacklistItem}`
);
const rootReducer = combineReducers({ const rootReducer = combineReducers({
generation: generationReducer, generation: generationReducer,
postprocessing: postprocessingReducer, postprocessing: postprocessingReducer,
@ -72,6 +77,7 @@ const rootReducer = combineReducers({
canvas: canvasReducer, canvas: canvasReducer,
ui: uiReducer, ui: uiReducer,
lightbox: lightboxReducer, lightbox: lightboxReducer,
api: apiReducer,
}); });
const rootPersistConfig = getPersistConfig({ const rootPersistConfig = getPersistConfig({
@ -83,6 +89,7 @@ const rootPersistConfig = getPersistConfig({
...systemBlacklist, ...systemBlacklist,
...galleryBlacklist, ...galleryBlacklist,
...lightboxBlacklist, ...lightboxBlacklist,
...apiBlacklist,
], ],
debounce: 300, debounce: 300,
}); });

View File

@ -0,0 +1,8 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from './store';
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
state: RootState;
dispatch: AppDispatch;
}>();

View File

@ -0,0 +1,79 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { APIState, STATUS } from './apiSliceTypes';
import { createSession, invokeSession } from 'services/thunks/session';
import { getImage } from './thunks/image';
const initialSystemState: APIState = {
sessionId: null,
status: STATUS.idle,
progress: null,
progressImage: null,
};
export const apiSlice = createSlice({
name: 'api',
initialState: initialSystemState,
reducers: {
setSessionId: (state, action: PayloadAction<APIState['sessionId']>) => {
state.sessionId = action.payload;
},
setStatus: (state, action: PayloadAction<APIState['status']>) => {
state.status = action.payload;
},
setProgressImage: (
state,
action: PayloadAction<APIState['progressImage']>
) => {
state.progressImage = action.payload;
},
setProgress: (state, action: PayloadAction<APIState['progress']>) => {
state.progress = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(createSession.fulfilled, (state, { payload: { id } }) => {
// HTTP 200
// state.networkStatus = 'idle'
state.sessionId = id;
});
builder.addCase(createSession.pending, (state, action) => {
// HTTP request pending
// state.networkStatus = 'busy'
});
builder.addCase(createSession.rejected, (state, action) => {
// !HTTP 200
console.error('createSession rejected: ', action);
// state.networkStatus = 'idle'
});
builder.addCase(invokeSession.fulfilled, (state, action) => {
// HTTP 200
// state.networkStatus = 'idle'
});
builder.addCase(invokeSession.pending, (state, action) => {
// HTTP request pending
// state.networkStatus = 'busy'
});
builder.addCase(invokeSession.rejected, (state, action) => {
// state.networkStatus = 'idle'
});
builder.addCase(getImage.fulfilled, (state, action) => {
// !HTTP 200
console.log(action.payload);
// state.networkStatus = 'idle'
});
builder.addCase(getImage.pending, (state, action) => {
// HTTP request pending
// state.networkStatus = 'busy'
});
builder.addCase(getImage.rejected, (state, action) => {
// !HTTP 200
// state.networkStatus = 'idle'
});
},
});
export const { setSessionId, setStatus, setProgressImage, setProgress } =
apiSlice.actions;
export default apiSlice.reducer;

View File

@ -0,0 +1,18 @@
export enum STATUS {
idle = 'IDLE',
busy = 'BUSY',
error = 'ERROR',
}
export type ProgressImage = {
width: number;
height: number;
dataURL: string;
};
export interface APIState {
sessionId: string | null;
progressImage: ProgressImage | null;
progress: number | null;
status: STATUS;
}

View File

@ -0,0 +1,28 @@
import { createAppAsyncThunk } from 'app/storeUtils';
import { ImagesService, ImageType } from 'services/api';
type GetImageArg = {
/**
* The type of image to get
*/
imageType: ImageType;
/**
* The name of the image to get
*/
imageName: string;
};
// createAppAsyncThunk provides typing for getState and dispatch
export const getImage = createAppAsyncThunk(
'api/getImage',
async (arg: GetImageArg, { getState, dispatch, ...moreThunkStuff }) => {
const response = await ImagesService.getImage(arg);
return response;
},
{
condition: (arg, { getState }) => {
// we can get an image at any time
return true;
},
}
);

View File

@ -0,0 +1,82 @@
import { createAppAsyncThunk } from 'app/storeUtils';
import { Graph, SessionsService } from 'services/api';
import { STATUS } from 'services/apiSliceTypes';
/**
* createSession
*/
type CreateSessionArg = { requestBody?: Graph };
// createAppAsyncThunk provides typing for getState and dispatch
export const createSession = createAppAsyncThunk(
'api/createSession',
async (arg: CreateSessionArg, { getState, dispatch, ...moreThunkStuff }) => {
const response = await SessionsService.createSession(arg);
return response;
},
{
// if this returns false, the api call is skipped
// we can guard in many places, and maybe this isn't right for us,
// but just trying it here
condition: (arg, { getState }) => {
const {
api: { status, sessionId },
} = getState();
// don't create session if we are processing already
if (status === STATUS.busy) {
return false;
}
// don't create session if we have a sessionId
if (sessionId) {
return false;
}
},
}
);
/**
* invokeSession
*/
export const invokeSession = createAppAsyncThunk(
'api/invokeSession',
async (_arg, { getState }) => {
const {
api: { sessionId },
} = getState();
// i'd really like for the typing on the condition callback below to tell this
// function here that sessionId will never be empty, but guess we do not get
// that luxury
if (!sessionId) {
return;
}
const response = await SessionsService.invokeSession({
sessionId,
all: true,
});
return response;
},
{
condition: (arg, { getState }) => {
const {
api: { status, sessionId },
} = getState();
// don't invoke if we are processing already
if (status === STATUS.busy) {
return false;
}
// don't invoke if we don't have a sessionId
if (!sessionId) {
return false;
}
},
}
);