feat(ui): wip nodes

- extract api client method arg types instead of manually declaring them
- update example to display images
- general tidy up
This commit is contained in:
psychedelicious 2023-03-26 15:42:39 +11:00
parent 999c3a443b
commit ca41a52174
6 changed files with 225 additions and 182 deletions

View File

@ -3,12 +3,16 @@ import IAIButton from 'common/components/IAIButton';
import {
setProgress,
setProgressImage,
setSessionId,
setStatus,
STATUS,
} from 'services/apiSlice';
import { useEffect } from 'react';
import { STATUS, ProgressImage } from 'services/apiSliceTypes';
import { getImage } from 'services/thunks/image';
import { useCallback, useEffect, useState } from 'react';
import {
GeneratorProgressEvent,
GraphExecutionStateCompleteEvent,
InvocationCompleteEvent,
InvocationStartedEvent,
} from 'services/events/types';
import {
cancelProcessing,
createSession,
@ -18,14 +22,6 @@ import { io } from 'socket.io-client';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { RootState } from './store';
type GeneratorProgress = {
session_id: string;
invocation_id: string;
progress_image: ProgressImage;
step: number;
total_steps: number;
};
const socket_url = `ws://${window.location.host}`;
const socket = io(socket_url, {
path: '/ws/socket.io',
@ -33,60 +29,84 @@ const socket = io(socket_url, {
const NodeAPITest = () => {
const dispatch = useAppDispatch();
const { sessionId, status, progress, progressImage } = useAppSelector(
const { sessionId, progress, progressImage } = useAppSelector(
(state: RootState) => state.api
);
const [resultImages, setResultImages] = useState<string[]>([]);
const appendResultImage = useCallback(
(url: string) => {
setResultImages([...resultImages, url]);
},
[resultImages]
);
const handleCreateSession = () => {
dispatch(
createSession({
requestBody: {
nodes: {
a: {
id: 'a',
type: 'txt2img',
prompt: 'pizza',
steps: 50,
seed: 123,
},
b: {
id: 'b',
type: 'img2img',
prompt: 'dog',
steps: 50,
seed: 123,
strength: 0.9,
},
c: {
id: 'c',
type: 'img2img',
prompt: 'cat',
steps: 50,
seed: 123,
strength: 0.9,
},
nodes: {
a: {
id: 'a',
type: 'txt2img',
prompt: 'pizza',
steps: 30,
},
b: {
id: 'b',
type: 'img2img',
prompt: 'dog',
steps: 30,
strength: 0.75,
},
c: {
id: 'c',
type: 'img2img',
prompt: 'cat',
steps: 30,
strength: 0.75,
},
d: {
id: 'd',
type: 'img2img',
prompt: 'jalapeno',
steps: 30,
strength: 0.75,
},
edges: [
{
source: { node_id: 'a', field: 'image' },
destination: { node_id: 'b', field: 'image' },
},
{
source: { node_id: 'b', field: 'image' },
destination: { node_id: 'c', field: 'image' },
},
],
},
edges: [
{
source: { node_id: 'a', field: 'image' },
destination: { node_id: 'b', field: 'image' },
},
{
source: { node_id: 'b', field: 'image' },
destination: { node_id: 'c', field: 'image' },
},
{
source: { node_id: 'c', field: 'image' },
destination: { node_id: 'd', field: 'image' },
},
],
})
);
};
const handleInvokeSession = () => {
dispatch(invokeSession());
if (!sessionId) {
return;
}
dispatch(invokeSession({ sessionId }));
setResultImages([]);
};
const handleCancelProcessing = () => {
dispatch(cancelProcessing());
if (!sessionId) {
return;
}
dispatch(cancelProcessing({ sessionId }));
};
useEffect(() => {
@ -96,51 +116,12 @@ const NodeAPITest = () => {
// set up socket.io listeners
// TODO: suppose this should be handled in the socket.io middleware
// TODO: write types for the socket.io payloads, haven't found a generator for them yet...
// TODO: suppose this should be handled in the socket.io middleware?
// subscribe to the current session
socket.emit('subscribe', { session: sessionId });
console.log('subscribe', { session: sessionId });
// received on each generation step
socket.on('generator_progress', (data: GeneratorProgress) => {
console.log('generator_progress', data);
dispatch(setProgress(data.step / data.total_steps));
dispatch(setProgressImage(data.progress_image));
});
// received after invokeSession called
socket.on('invocation_started', (data) => {
console.log('invocation_started', data);
dispatch(setStatus(STATUS.busy));
});
// received when generation complete
socket.on('invocation_complete', (data) => {
// for now, just unsubscribe from the session when we finish a generation
// in the future we will want to continue building the graph and executing etc
console.log('invocation_complete', data);
dispatch(setProgress(null));
// dispatch(setSessionId(null));
dispatch(setStatus(STATUS.idle));
// think this gets a blob...
// dispatch(
// getImage({
// imageType: data.result.image.image_type,
// imageName: data.result.image.image_name,
// })
// );
});
// not sure when we get this?
socket.on('session_complete', (data) => {
// socket.emit('unsubscribe', { session: sessionId });
// console.log('unsubscribe', { session: sessionId });
// console.log('session_complete', data);
});
() => {
// cleanup
socket.emit('unsubscribe', { session: sessionId });
@ -149,6 +130,52 @@ const NodeAPITest = () => {
};
}, [dispatch, sessionId]);
useEffect(() => {
/**
* `invocation_started`
*/
socket.on('invocation_started', (data: InvocationStartedEvent) => {
console.log('invocation_started', data);
dispatch(setStatus(STATUS.busy));
});
/**
* `generator_progress`
*/
socket.on('generator_progress', (data: GeneratorProgressEvent) => {
console.log('generator_progress', data);
dispatch(setProgress(data.step / data.total_steps));
if (data.progress_image) {
dispatch(setProgressImage(data.progress_image));
}
});
/**
* `invocation_complete`
*/
socket.on('invocation_complete', (data: InvocationCompleteEvent) => {
if (data.result.type === 'image') {
const url = `api/v1/images/${data.result.image.image_type}/${data.result.image.image_name}`;
appendResultImage(url);
}
console.log('invocation_complete', data);
dispatch(setProgress(null));
dispatch(setStatus(STATUS.idle));
console.log(data);
});
/**
* `graph_execution_state_complete`
*/
socket.on(
'graph_execution_state_complete',
(data: GraphExecutionStateCompleteEvent) => {
console.log(data);
}
);
}, [dispatch, appendResultImage]);
return (
<Flex
sx={{
@ -160,24 +187,14 @@ const NodeAPITest = () => {
}}
>
<Text>Session: {sessionId ? sessionId : '...'}</Text>
<IAIButton
onClick={handleCancelProcessing}
// isDisabled={!sessionId}
colorScheme="error"
>
<IAIButton onClick={handleCancelProcessing} colorScheme="error">
Cancel Processing
</IAIButton>
<IAIButton
onClick={handleCreateSession}
// isDisabled={status === STATUS.busy || Boolean(sessionId)}
colorScheme="accent"
>
<IAIButton onClick={handleCreateSession} colorScheme="accent">
Create Session
</IAIButton>
<IAIButton
onClick={handleInvokeSession}
// isDisabled={status === STATUS.busy}
// isLoading={status === STATUS.busy}
loadingText={`Invoking ${
progress === null ? '...' : `${Math.round(progress * 100)}%`
}`}
@ -185,14 +202,19 @@ const NodeAPITest = () => {
>
Invoke
</IAIButton>
<Image
src={progressImage?.dataURL}
width={progressImage?.width}
height={progressImage?.height}
sx={{
imageRendering: 'pixelated',
}}
/>
<Flex wrap="wrap" gap={4} overflow="scroll">
<Image
src={progressImage?.dataURL}
width={progressImage?.width}
height={progressImage?.height}
sx={{
imageRendering: 'pixelated',
}}
/>
{resultImages.map((url) => (
<Image key={url} src={url} />
))}
</Flex>
</Flex>
);
};

View File

@ -1,9 +1,28 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { APIState, STATUS } from './apiSliceTypes';
import { ProgressImage } from './events/types';
import { createSession, invokeSession } from 'services/thunks/session';
import { getImage } from './thunks/image';
/**
* Just temp until we work out better statuses
*/
export enum STATUS {
idle = 'IDLE',
busy = 'BUSY',
error = 'ERROR',
}
/**
* Type for the temp (?) API slice.
*/
export interface APIState {
sessionId: string | null;
progressImage: ProgressImage | null;
progress: number | null;
status: STATUS;
}
const initialSystemState: APIState = {
sessionId: null,
status: STATUS.idle,
@ -32,7 +51,10 @@ export const apiSlice = createSlice({
},
},
extraReducers: (builder) => {
builder.addCase(createSession.fulfilled, (state, { payload: { id } }) => {
builder.addCase(createSession.fulfilled, (state, action) => {
const {
payload: { id },
} = action;
// HTTP 200
// state.networkStatus = 'idle'
state.sessionId = id;
@ -47,6 +69,7 @@ export const apiSlice = createSlice({
// state.networkStatus = 'idle'
});
builder.addCase(invokeSession.fulfilled, (state, action) => {
console.log('invokeSession.fulfilled: ', action.payload);
// HTTP 200
// state.networkStatus = 'idle'
});

View File

@ -1,18 +0,0 @@
export enum STATUS {
idle = 'IDLE',
busy = 'BUSY',
error = 'ERROR',
}
export type ProgressImage = {
width: number;
height: number;
dataURL: string;
};
export interface APIState {
sessionId: string | null;
progressImage: ProgressImage | null;
progress: number | null;
status: STATUS;
}

View File

@ -1,28 +1,13 @@
import { createAppAsyncThunk } from 'app/storeUtils';
import { ImagesService, ImageType } from 'services/api';
import { ImagesService } from 'services/api';
type GetImageArg = {
/**
* The type of image to get
*/
imageType: ImageType;
/**
* The name of the image to get
*/
imageName: string;
};
type GetImageArg = Parameters<(typeof ImagesService)['getImage']>[0];
// createAppAsyncThunk provides typing for getState and dispatch
export const getImage = createAppAsyncThunk(
'api/getImage',
async (arg: GetImageArg, { getState, dispatch, ...moreThunkStuff }) => {
async (arg: GetImageArg, _thunkApi) => {
const response = await ImagesService.getImage(arg);
return response;
},
{
condition: (arg, { getState }) => {
// we can get an image at any time
return true;
},
}
);

View File

@ -1,39 +1,57 @@
import { createAppAsyncThunk } from 'app/storeUtils';
import { Graph, SessionsService } from 'services/api';
import { STATUS } from 'services/apiSliceTypes';
import { SessionsService } from 'services/api';
/**
* createSession
* createSession thunk
*/
type CreateSessionArg = { requestBody?: Graph };
/**
* Extract the type of the requestBody from the generated API client.
*
* Would really like for this to be generated but it's easy enough to extract it.
*/
type CreateSessionRequestBody = Parameters<
(typeof SessionsService)['createSession']
>[0]['requestBody'];
// createAppAsyncThunk provides typing for getState and dispatch
export const createSession = createAppAsyncThunk(
'api/createSession',
async (arg: CreateSessionArg, { getState, dispatch, ...moreThunkStuff }) => {
const response = await SessionsService.createSession(arg);
async (arg: CreateSessionRequestBody, _thunkApi) => {
const response = await SessionsService.createSession({ requestBody: arg });
return response;
}
);
/**
* invokeSession
* addNode thunk
*/
type AddNodeRequestBody = Parameters<
(typeof SessionsService)['addNode']
>[0]['requestBody'];
export const addNode = createAppAsyncThunk(
'api/addNode',
async (arg: { node: AddNodeRequestBody; sessionId: string }, _thunkApi) => {
const response = await SessionsService.addNode({
requestBody: arg.node,
sessionId: arg.sessionId,
});
return response;
}
);
/**
* invokeSession thunk
*/
export const invokeSession = createAppAsyncThunk(
'api/invokeSession',
async (_arg, { getState }) => {
const {
api: { sessionId },
} = getState();
// i'd really like for the typing on the condition callback below to tell this
// function here that sessionId will never be empty, but guess we do not get
// that luxury
if (!sessionId) {
return;
}
async (arg: { sessionId: string }, _thunkApi) => {
const { sessionId } = arg;
const response = await SessionsService.invokeSession({
sessionId,
@ -43,26 +61,15 @@ export const invokeSession = createAppAsyncThunk(
return response;
}
);
/**
* invokeSession
* invokeSession thunk
*/
export const cancelProcessing = createAppAsyncThunk(
'api/cancelProcessing',
async (_arg, { getState }) => {
console.log('before canceling');
const {
api: { sessionId },
} = getState();
// i'd really like for the typing on the condition callback below to tell this
// function here that sessionId will never be empty, but guess we do not get
// that luxury
if (!sessionId) {
return;
}
console.log('canceling');
async (arg: { sessionId: string }, _thunkApi) => {
const { sessionId } = arg;
const response = await SessionsService.cancelSessionInvoke({
sessionId,

View File

@ -0,0 +1,24 @@
import { Graph, TextToImageInvocation } from './api';
/**
* Make a graph of however many images
*/
export const makeGraphOfXImages = (numberOfImages: string) =>
Array.from(Array(numberOfImages))
.map(
(_val, i): TextToImageInvocation => ({
id: i.toString(),
type: 'txt2img',
prompt: 'pizza',
steps: 50,
seed: 123,
sampler_name: 'ddim',
})
)
.reduce(
(acc, val: TextToImageInvocation) => {
acc.nodes![val.id] = val;
return acc;
},
{ nodes: {} } as Graph
);