feat(ui): wip nodes

- extract api client method arg types instead of manually declaring them
- update example to display images
- general tidy up
This commit is contained in:
psychedelicious 2023-03-26 15:42:39 +11:00
parent 999c3a443b
commit ca41a52174
6 changed files with 225 additions and 182 deletions

View File

@ -3,12 +3,16 @@ import IAIButton from 'common/components/IAIButton';
import { import {
setProgress, setProgress,
setProgressImage, setProgressImage,
setSessionId,
setStatus, setStatus,
STATUS,
} from 'services/apiSlice'; } from 'services/apiSlice';
import { useEffect } from 'react'; import { useCallback, useEffect, useState } from 'react';
import { STATUS, ProgressImage } from 'services/apiSliceTypes'; import {
import { getImage } from 'services/thunks/image'; GeneratorProgressEvent,
GraphExecutionStateCompleteEvent,
InvocationCompleteEvent,
InvocationStartedEvent,
} from 'services/events/types';
import { import {
cancelProcessing, cancelProcessing,
createSession, createSession,
@ -18,14 +22,6 @@ import { io } from 'socket.io-client';
import { useAppDispatch, useAppSelector } from './storeHooks'; import { useAppDispatch, useAppSelector } from './storeHooks';
import { RootState } from './store'; import { RootState } from './store';
type GeneratorProgress = {
session_id: string;
invocation_id: string;
progress_image: ProgressImage;
step: number;
total_steps: number;
};
const socket_url = `ws://${window.location.host}`; const socket_url = `ws://${window.location.host}`;
const socket = io(socket_url, { const socket = io(socket_url, {
path: '/ws/socket.io', path: '/ws/socket.io',
@ -33,37 +29,49 @@ const socket = io(socket_url, {
const NodeAPITest = () => { const NodeAPITest = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { sessionId, status, progress, progressImage } = useAppSelector( const { sessionId, progress, progressImage } = useAppSelector(
(state: RootState) => state.api (state: RootState) => state.api
); );
const [resultImages, setResultImages] = useState<string[]>([]);
const appendResultImage = useCallback(
(url: string) => {
setResultImages([...resultImages, url]);
},
[resultImages]
);
const handleCreateSession = () => { const handleCreateSession = () => {
dispatch( dispatch(
createSession({ createSession({
requestBody: {
nodes: { nodes: {
a: { a: {
id: 'a', id: 'a',
type: 'txt2img', type: 'txt2img',
prompt: 'pizza', prompt: 'pizza',
steps: 50, steps: 30,
seed: 123,
}, },
b: { b: {
id: 'b', id: 'b',
type: 'img2img', type: 'img2img',
prompt: 'dog', prompt: 'dog',
steps: 50, steps: 30,
seed: 123, strength: 0.75,
strength: 0.9,
}, },
c: { c: {
id: 'c', id: 'c',
type: 'img2img', type: 'img2img',
prompt: 'cat', prompt: 'cat',
steps: 50, steps: 30,
seed: 123, strength: 0.75,
strength: 0.9, },
d: {
id: 'd',
type: 'img2img',
prompt: 'jalapeno',
steps: 30,
strength: 0.75,
}, },
}, },
edges: [ edges: [
@ -75,18 +83,30 @@ const NodeAPITest = () => {
source: { node_id: 'b', field: 'image' }, source: { node_id: 'b', field: 'image' },
destination: { node_id: 'c', field: 'image' }, destination: { node_id: 'c', field: 'image' },
}, },
], {
source: { node_id: 'c', field: 'image' },
destination: { node_id: 'd', field: 'image' },
}, },
],
}) })
); );
}; };
const handleInvokeSession = () => { const handleInvokeSession = () => {
dispatch(invokeSession()); if (!sessionId) {
return;
}
dispatch(invokeSession({ sessionId }));
setResultImages([]);
}; };
const handleCancelProcessing = () => { const handleCancelProcessing = () => {
dispatch(cancelProcessing()); if (!sessionId) {
return;
}
dispatch(cancelProcessing({ sessionId }));
}; };
useEffect(() => { useEffect(() => {
@ -96,51 +116,12 @@ const NodeAPITest = () => {
// set up socket.io listeners // set up socket.io listeners
// TODO: suppose this should be handled in the socket.io middleware // TODO: suppose this should be handled in the socket.io middleware?
// TODO: write types for the socket.io payloads, haven't found a generator for them yet...
// subscribe to the current session // subscribe to the current session
socket.emit('subscribe', { session: sessionId }); socket.emit('subscribe', { session: sessionId });
console.log('subscribe', { session: sessionId }); console.log('subscribe', { session: sessionId });
// received on each generation step
socket.on('generator_progress', (data: GeneratorProgress) => {
console.log('generator_progress', data);
dispatch(setProgress(data.step / data.total_steps));
dispatch(setProgressImage(data.progress_image));
});
// received after invokeSession called
socket.on('invocation_started', (data) => {
console.log('invocation_started', data);
dispatch(setStatus(STATUS.busy));
});
// received when generation complete
socket.on('invocation_complete', (data) => {
// for now, just unsubscribe from the session when we finish a generation
// in the future we will want to continue building the graph and executing etc
console.log('invocation_complete', data);
dispatch(setProgress(null));
// dispatch(setSessionId(null));
dispatch(setStatus(STATUS.idle));
// think this gets a blob...
// dispatch(
// getImage({
// imageType: data.result.image.image_type,
// imageName: data.result.image.image_name,
// })
// );
});
// not sure when we get this?
socket.on('session_complete', (data) => {
// socket.emit('unsubscribe', { session: sessionId });
// console.log('unsubscribe', { session: sessionId });
// console.log('session_complete', data);
});
() => { () => {
// cleanup // cleanup
socket.emit('unsubscribe', { session: sessionId }); socket.emit('unsubscribe', { session: sessionId });
@ -149,6 +130,52 @@ const NodeAPITest = () => {
}; };
}, [dispatch, sessionId]); }, [dispatch, sessionId]);
useEffect(() => {
/**
* `invocation_started`
*/
socket.on('invocation_started', (data: InvocationStartedEvent) => {
console.log('invocation_started', data);
dispatch(setStatus(STATUS.busy));
});
/**
* `generator_progress`
*/
socket.on('generator_progress', (data: GeneratorProgressEvent) => {
console.log('generator_progress', data);
dispatch(setProgress(data.step / data.total_steps));
if (data.progress_image) {
dispatch(setProgressImage(data.progress_image));
}
});
/**
* `invocation_complete`
*/
socket.on('invocation_complete', (data: InvocationCompleteEvent) => {
if (data.result.type === 'image') {
const url = `api/v1/images/${data.result.image.image_type}/${data.result.image.image_name}`;
appendResultImage(url);
}
console.log('invocation_complete', data);
dispatch(setProgress(null));
dispatch(setStatus(STATUS.idle));
console.log(data);
});
/**
* `graph_execution_state_complete`
*/
socket.on(
'graph_execution_state_complete',
(data: GraphExecutionStateCompleteEvent) => {
console.log(data);
}
);
}, [dispatch, appendResultImage]);
return ( return (
<Flex <Flex
sx={{ sx={{
@ -160,24 +187,14 @@ const NodeAPITest = () => {
}} }}
> >
<Text>Session: {sessionId ? sessionId : '...'}</Text> <Text>Session: {sessionId ? sessionId : '...'}</Text>
<IAIButton <IAIButton onClick={handleCancelProcessing} colorScheme="error">
onClick={handleCancelProcessing}
// isDisabled={!sessionId}
colorScheme="error"
>
Cancel Processing Cancel Processing
</IAIButton> </IAIButton>
<IAIButton <IAIButton onClick={handleCreateSession} colorScheme="accent">
onClick={handleCreateSession}
// isDisabled={status === STATUS.busy || Boolean(sessionId)}
colorScheme="accent"
>
Create Session Create Session
</IAIButton> </IAIButton>
<IAIButton <IAIButton
onClick={handleInvokeSession} onClick={handleInvokeSession}
// isDisabled={status === STATUS.busy}
// isLoading={status === STATUS.busy}
loadingText={`Invoking ${ loadingText={`Invoking ${
progress === null ? '...' : `${Math.round(progress * 100)}%` progress === null ? '...' : `${Math.round(progress * 100)}%`
}`} }`}
@ -185,6 +202,7 @@ const NodeAPITest = () => {
> >
Invoke Invoke
</IAIButton> </IAIButton>
<Flex wrap="wrap" gap={4} overflow="scroll">
<Image <Image
src={progressImage?.dataURL} src={progressImage?.dataURL}
width={progressImage?.width} width={progressImage?.width}
@ -193,6 +211,10 @@ const NodeAPITest = () => {
imageRendering: 'pixelated', imageRendering: 'pixelated',
}} }}
/> />
{resultImages.map((url) => (
<Image key={url} src={url} />
))}
</Flex>
</Flex> </Flex>
); );
}; };

View File

@ -1,9 +1,28 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { APIState, STATUS } from './apiSliceTypes'; import { ProgressImage } from './events/types';
import { createSession, invokeSession } from 'services/thunks/session'; import { createSession, invokeSession } from 'services/thunks/session';
import { getImage } from './thunks/image'; import { getImage } from './thunks/image';
/**
* Just temp until we work out better statuses
*/
export enum STATUS {
idle = 'IDLE',
busy = 'BUSY',
error = 'ERROR',
}
/**
* Type for the temp (?) API slice.
*/
export interface APIState {
sessionId: string | null;
progressImage: ProgressImage | null;
progress: number | null;
status: STATUS;
}
const initialSystemState: APIState = { const initialSystemState: APIState = {
sessionId: null, sessionId: null,
status: STATUS.idle, status: STATUS.idle,
@ -32,7 +51,10 @@ export const apiSlice = createSlice({
}, },
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(createSession.fulfilled, (state, { payload: { id } }) => { builder.addCase(createSession.fulfilled, (state, action) => {
const {
payload: { id },
} = action;
// HTTP 200 // HTTP 200
// state.networkStatus = 'idle' // state.networkStatus = 'idle'
state.sessionId = id; state.sessionId = id;
@ -47,6 +69,7 @@ export const apiSlice = createSlice({
// state.networkStatus = 'idle' // state.networkStatus = 'idle'
}); });
builder.addCase(invokeSession.fulfilled, (state, action) => { builder.addCase(invokeSession.fulfilled, (state, action) => {
console.log('invokeSession.fulfilled: ', action.payload);
// HTTP 200 // HTTP 200
// state.networkStatus = 'idle' // state.networkStatus = 'idle'
}); });

View File

@ -1,18 +0,0 @@
export enum STATUS {
idle = 'IDLE',
busy = 'BUSY',
error = 'ERROR',
}
export type ProgressImage = {
width: number;
height: number;
dataURL: string;
};
export interface APIState {
sessionId: string | null;
progressImage: ProgressImage | null;
progress: number | null;
status: STATUS;
}

View File

@ -1,28 +1,13 @@
import { createAppAsyncThunk } from 'app/storeUtils'; import { createAppAsyncThunk } from 'app/storeUtils';
import { ImagesService, ImageType } from 'services/api'; import { ImagesService } from 'services/api';
type GetImageArg = { type GetImageArg = Parameters<(typeof ImagesService)['getImage']>[0];
/**
* The type of image to get
*/
imageType: ImageType;
/**
* The name of the image to get
*/
imageName: string;
};
// createAppAsyncThunk provides typing for getState and dispatch // createAppAsyncThunk provides typing for getState and dispatch
export const getImage = createAppAsyncThunk( export const getImage = createAppAsyncThunk(
'api/getImage', 'api/getImage',
async (arg: GetImageArg, { getState, dispatch, ...moreThunkStuff }) => { async (arg: GetImageArg, _thunkApi) => {
const response = await ImagesService.getImage(arg); const response = await ImagesService.getImage(arg);
return response; return response;
},
{
condition: (arg, { getState }) => {
// we can get an image at any time
return true;
},
} }
); );

View File

@ -1,39 +1,57 @@
import { createAppAsyncThunk } from 'app/storeUtils'; import { createAppAsyncThunk } from 'app/storeUtils';
import { Graph, SessionsService } from 'services/api'; import { SessionsService } from 'services/api';
import { STATUS } from 'services/apiSliceTypes';
/** /**
* createSession * createSession thunk
*/ */
type CreateSessionArg = { requestBody?: Graph }; /**
* Extract the type of the requestBody from the generated API client.
*
* Would really like for this to be generated but it's easy enough to extract it.
*/
type CreateSessionRequestBody = Parameters<
(typeof SessionsService)['createSession']
>[0]['requestBody'];
// createAppAsyncThunk provides typing for getState and dispatch
export const createSession = createAppAsyncThunk( export const createSession = createAppAsyncThunk(
'api/createSession', 'api/createSession',
async (arg: CreateSessionArg, { getState, dispatch, ...moreThunkStuff }) => { async (arg: CreateSessionRequestBody, _thunkApi) => {
const response = await SessionsService.createSession(arg); const response = await SessionsService.createSession({ requestBody: arg });
return response; return response;
} }
); );
/** /**
* invokeSession * addNode thunk
*/
type AddNodeRequestBody = Parameters<
(typeof SessionsService)['addNode']
>[0]['requestBody'];
export const addNode = createAppAsyncThunk(
'api/addNode',
async (arg: { node: AddNodeRequestBody; sessionId: string }, _thunkApi) => {
const response = await SessionsService.addNode({
requestBody: arg.node,
sessionId: arg.sessionId,
});
return response;
}
);
/**
* invokeSession thunk
*/ */
export const invokeSession = createAppAsyncThunk( export const invokeSession = createAppAsyncThunk(
'api/invokeSession', 'api/invokeSession',
async (_arg, { getState }) => { async (arg: { sessionId: string }, _thunkApi) => {
const { const { sessionId } = arg;
api: { sessionId },
} = getState();
// i'd really like for the typing on the condition callback below to tell this
// function here that sessionId will never be empty, but guess we do not get
// that luxury
if (!sessionId) {
return;
}
const response = await SessionsService.invokeSession({ const response = await SessionsService.invokeSession({
sessionId, sessionId,
@ -43,26 +61,15 @@ export const invokeSession = createAppAsyncThunk(
return response; return response;
} }
); );
/** /**
* invokeSession * invokeSession thunk
*/ */
export const cancelProcessing = createAppAsyncThunk( export const cancelProcessing = createAppAsyncThunk(
'api/cancelProcessing', 'api/cancelProcessing',
async (_arg, { getState }) => { async (arg: { sessionId: string }, _thunkApi) => {
console.log('before canceling'); const { sessionId } = arg;
const {
api: { sessionId },
} = getState();
// i'd really like for the typing on the condition callback below to tell this
// function here that sessionId will never be empty, but guess we do not get
// that luxury
if (!sessionId) {
return;
}
console.log('canceling');
const response = await SessionsService.cancelSessionInvoke({ const response = await SessionsService.cancelSessionInvoke({
sessionId, sessionId,

View File

@ -0,0 +1,24 @@
import { Graph, TextToImageInvocation } from './api';
/**
* Make a graph of however many images
*/
export const makeGraphOfXImages = (numberOfImages: string) =>
Array.from(Array(numberOfImages))
.map(
(_val, i): TextToImageInvocation => ({
id: i.toString(),
type: 'txt2img',
prompt: 'pizza',
steps: 50,
seed: 123,
sampler_name: 'ddim',
})
)
.reduce(
(acc, val: TextToImageInvocation) => {
acc.nodes![val.id] = val;
return acc;
},
{ nodes: {} } as Graph
);