mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): more nodes api prototyping
This commit is contained in:
parent
b49338b464
commit
07428769df
@ -1,14 +1,25 @@
|
|||||||
import { Flex, Heading, Text } from '@chakra-ui/react';
|
import { Flex, Image, Text } from '@chakra-ui/react';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import { useEffect, useState } from 'react';
|
import {
|
||||||
import { SessionsService } from 'services/api';
|
setProgress,
|
||||||
|
setProgressImage,
|
||||||
|
setSessionId,
|
||||||
|
setStatus,
|
||||||
|
} from 'services/apiSlice';
|
||||||
|
import { useEffect } from 'react';
|
||||||
|
import { STATUS, ProgressImage } from 'services/apiSliceTypes';
|
||||||
|
import { getImage } from 'services/thunks/image';
|
||||||
|
import { createSession, invokeSession } from 'services/thunks/session';
|
||||||
import { io } from 'socket.io-client';
|
import { io } from 'socket.io-client';
|
||||||
|
import { useAppDispatch, useAppSelector } from './storeHooks';
|
||||||
|
import { RootState } from './store';
|
||||||
|
|
||||||
type GeneratorProgress = {
|
type GeneratorProgress = {
|
||||||
session_id: string;
|
session_id: string;
|
||||||
invocation_id: string;
|
invocation_id: string;
|
||||||
|
progress_image: ProgressImage;
|
||||||
step: number;
|
step: number;
|
||||||
percent: number;
|
total_steps: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
const socket_url = `ws://${window.location.host}`;
|
const socket_url = `ws://${window.location.host}`;
|
||||||
@ -16,98 +27,102 @@ const socket = io(socket_url, {
|
|||||||
path: '/ws/socket.io',
|
path: '/ws/socket.io',
|
||||||
});
|
});
|
||||||
|
|
||||||
enum STATUS {
|
|
||||||
waiting = 'WAITING',
|
|
||||||
ready = 'READY',
|
|
||||||
preparing = 'PREPARING',
|
|
||||||
generating = 'GENERATING',
|
|
||||||
finished = 'FINISHED',
|
|
||||||
}
|
|
||||||
|
|
||||||
const NodeAPITest = () => {
|
const NodeAPITest = () => {
|
||||||
const [invocationProgress, setInvocationProgress] = useState<number>();
|
const dispatch = useAppDispatch();
|
||||||
const [status, setStatus] = useState<STATUS>(STATUS.waiting);
|
const { sessionId, status, progress, progressImage } = useAppSelector(
|
||||||
const [sessionId, setSessionId] = useState<string | null>(null);
|
(state: RootState) => state.api
|
||||||
|
);
|
||||||
|
|
||||||
const handleCreateSession = async () => {
|
const handleCreateSession = () => {
|
||||||
// create a session with a simple graph
|
dispatch(
|
||||||
const payload = await SessionsService.createSession({
|
createSession({
|
||||||
nodes: {
|
requestBody: {
|
||||||
a: {
|
nodes: {
|
||||||
id: 'a',
|
a: {
|
||||||
type: 'txt2img',
|
id: 'a',
|
||||||
prompt: 'pizza',
|
type: 'txt2img',
|
||||||
steps: 10,
|
prompt: 'pizza',
|
||||||
|
steps: 10,
|
||||||
|
},
|
||||||
|
b: {
|
||||||
|
id: 'b',
|
||||||
|
type: 'show_image',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
source: { node_id: 'a', field: 'image' },
|
||||||
|
destination: { node_id: 'b', field: 'image' },
|
||||||
|
},
|
||||||
|
],
|
||||||
},
|
},
|
||||||
b: {
|
})
|
||||||
id: 'b',
|
);
|
||||||
type: 'show_image',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
edges: [
|
|
||||||
{
|
|
||||||
source: { node_id: 'a', field: 'image' },
|
|
||||||
destination: { node_id: 'b', field: 'image' },
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
// the generated types have `id` as optional but i'm pretty sure we always get the id
|
|
||||||
setSessionId(payload.id!);
|
|
||||||
setStatus(STATUS.ready);
|
|
||||||
console.log('payload', payload);
|
|
||||||
|
|
||||||
// subscribe to this session
|
|
||||||
socket.emit('subscribe', { session: payload.id });
|
|
||||||
console.log('subscribe', { session: payload.id });
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleInvokeSession = async () => {
|
const handleInvokeSession = () => {
|
||||||
|
dispatch(invokeSession());
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
if (!sessionId) {
|
if (!sessionId) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
setStatus(STATUS.preparing);
|
// set up socket.io listeners
|
||||||
// invoke the session, the resultant image should open in your platform's native image viewer when completed
|
|
||||||
await SessionsService.invokeSession(sessionId, true);
|
|
||||||
};
|
|
||||||
|
|
||||||
useEffect(() => {
|
// TODO: suppose this should be handled in the socket.io middleware
|
||||||
|
// TODO: write types for the socket.io payloads, haven't found a generator for them yet...
|
||||||
|
|
||||||
|
// subscribe to the current session
|
||||||
|
socket.emit('subscribe', { session: sessionId });
|
||||||
|
console.log('subscribe', { session: sessionId });
|
||||||
|
|
||||||
|
// received on each generation step
|
||||||
socket.on('generator_progress', (data: GeneratorProgress) => {
|
socket.on('generator_progress', (data: GeneratorProgress) => {
|
||||||
// this is broken on the backend, the nodes web server does not get `step` or `steps`, so we don't get a percentage
|
|
||||||
// see https://github.com/invoke-ai/InvokeAI/issues/2951
|
|
||||||
console.log('generator_progress', data);
|
console.log('generator_progress', data);
|
||||||
setInvocationProgress(data.percent);
|
dispatch(setProgress(data.step / data.total_steps));
|
||||||
|
dispatch(setProgressImage(data.progress_image));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// received after invokeSession called
|
||||||
socket.on('invocation_started', (data) => {
|
socket.on('invocation_started', (data) => {
|
||||||
console.log('invocation_started', data);
|
console.log('invocation_started', data);
|
||||||
setStatus(STATUS.generating);
|
dispatch(setStatus(STATUS.busy));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// received when generation complete
|
||||||
socket.on('invocation_complete', (data) => {
|
socket.on('invocation_complete', (data) => {
|
||||||
// for now, just unsubscribe from the session when we finish a generation
|
// for now, just unsubscribe from the session when we finish a generation
|
||||||
// in the future we will want to continue building the graph and executing etc
|
// in the future we will want to continue building the graph and executing etc
|
||||||
setStatus(STATUS.finished);
|
|
||||||
console.log('invocation_complete', data);
|
console.log('invocation_complete', data);
|
||||||
socket.emit('unsubscribe', { session: data.session_id });
|
dispatch(setProgress(null));
|
||||||
console.log('unsubscribe', { session: data.session_id });
|
dispatch(setSessionId(null));
|
||||||
setTimeout(() => {
|
dispatch(setStatus(STATUS.idle));
|
||||||
setSessionId(null);
|
|
||||||
setStatus(STATUS.waiting);
|
// think this gets a blob...
|
||||||
}, 2000);
|
// dispatch(
|
||||||
|
// getImage({
|
||||||
|
// imageType: data.result.image.image_type,
|
||||||
|
// imageName: data.result.image.image_name,
|
||||||
|
// })
|
||||||
|
// );
|
||||||
|
socket.emit('unsubscribe', { session: sessionId });
|
||||||
|
console.log('unsubscribe', { session: sessionId });
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// not sure when we get this?
|
||||||
socket.on('session_complete', (data) => {
|
socket.on('session_complete', (data) => {
|
||||||
console.log('session_complete', data);
|
// console.log('session_complete', data);
|
||||||
socket.emit('unsubscribe', { session: data.session_id });
|
|
||||||
console.log('unsubscribe', { session: data.session_id });
|
|
||||||
setSessionId(null);
|
|
||||||
setStatus(STATUS.waiting);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
() => {
|
() => {
|
||||||
|
// cleanup
|
||||||
|
socket.emit('unsubscribe', { session: sessionId });
|
||||||
socket.removeAllListeners();
|
socket.removeAllListeners();
|
||||||
socket.disconnect();
|
socket.disconnect();
|
||||||
};
|
};
|
||||||
}, []);
|
}, [dispatch, sessionId]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@ -119,28 +134,33 @@ const NodeAPITest = () => {
|
|||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Heading size="lg">Status: {status}</Heading>
|
|
||||||
<Text>Session: {sessionId ? sessionId : '...'}</Text>
|
<Text>Session: {sessionId ? sessionId : '...'}</Text>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
onClick={handleCreateSession}
|
onClick={handleCreateSession}
|
||||||
isDisabled={!!sessionId}
|
isDisabled={status === STATUS.busy || Boolean(sessionId)}
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
>
|
>
|
||||||
Create Session
|
Create Session
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
onClick={handleInvokeSession}
|
onClick={handleInvokeSession}
|
||||||
isDisabled={!sessionId || status !== STATUS.ready}
|
isDisabled={status === STATUS.busy}
|
||||||
isLoading={[STATUS.preparing, STATUS.generating].includes(status)}
|
isLoading={status === STATUS.busy}
|
||||||
loadingText={`Invoking ${
|
loadingText={`Invoking ${
|
||||||
invocationProgress === undefined
|
progress === null ? '...' : `${Math.round(progress * 100)}%`
|
||||||
? '...'
|
|
||||||
: `${Math.round(invocationProgress * 100)}%`
|
|
||||||
}`}
|
}`}
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
>
|
>
|
||||||
Invoke
|
Invoke
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
|
<Image
|
||||||
|
src={progressImage?.dataURL}
|
||||||
|
width={progressImage?.width}
|
||||||
|
height={progressImage?.height}
|
||||||
|
sx={{
|
||||||
|
imageRendering: 'pixelated',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -12,6 +12,7 @@ import generationReducer from 'features/parameters/store/generationSlice';
|
|||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
|
import apiReducer from 'services/apiSlice';
|
||||||
|
|
||||||
import { socketioMiddleware } from './socketio/middleware';
|
import { socketioMiddleware } from './socketio/middleware';
|
||||||
|
|
||||||
@ -64,6 +65,10 @@ const lightboxBlacklist = ['isLightboxOpen'].map(
|
|||||||
(blacklistItem) => `lightbox.${blacklistItem}`
|
(blacklistItem) => `lightbox.${blacklistItem}`
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const apiBlacklist = ['sessionId', 'status', 'progress', 'progressImage'].map(
|
||||||
|
(blacklistItem) => `api.${blacklistItem}`
|
||||||
|
);
|
||||||
|
|
||||||
const rootReducer = combineReducers({
|
const rootReducer = combineReducers({
|
||||||
generation: generationReducer,
|
generation: generationReducer,
|
||||||
postprocessing: postprocessingReducer,
|
postprocessing: postprocessingReducer,
|
||||||
@ -72,6 +77,7 @@ const rootReducer = combineReducers({
|
|||||||
canvas: canvasReducer,
|
canvas: canvasReducer,
|
||||||
ui: uiReducer,
|
ui: uiReducer,
|
||||||
lightbox: lightboxReducer,
|
lightbox: lightboxReducer,
|
||||||
|
api: apiReducer,
|
||||||
});
|
});
|
||||||
|
|
||||||
const rootPersistConfig = getPersistConfig({
|
const rootPersistConfig = getPersistConfig({
|
||||||
@ -83,6 +89,7 @@ const rootPersistConfig = getPersistConfig({
|
|||||||
...systemBlacklist,
|
...systemBlacklist,
|
||||||
...galleryBlacklist,
|
...galleryBlacklist,
|
||||||
...lightboxBlacklist,
|
...lightboxBlacklist,
|
||||||
|
...apiBlacklist,
|
||||||
],
|
],
|
||||||
debounce: 300,
|
debounce: 300,
|
||||||
});
|
});
|
||||||
|
8
invokeai/frontend/web/src/app/storeUtils.ts
Normal file
8
invokeai/frontend/web/src/app/storeUtils.ts
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
import { createAsyncThunk } from '@reduxjs/toolkit';
|
||||||
|
import { AppDispatch, RootState } from './store';
|
||||||
|
|
||||||
|
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
|
||||||
|
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
|
||||||
|
state: RootState;
|
||||||
|
dispatch: AppDispatch;
|
||||||
|
}>();
|
79
invokeai/frontend/web/src/services/apiSlice.ts
Normal file
79
invokeai/frontend/web/src/services/apiSlice.ts
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { APIState, STATUS } from './apiSliceTypes';
|
||||||
|
import { createSession, invokeSession } from 'services/thunks/session';
|
||||||
|
import { getImage } from './thunks/image';
|
||||||
|
|
||||||
|
const initialSystemState: APIState = {
|
||||||
|
sessionId: null,
|
||||||
|
status: STATUS.idle,
|
||||||
|
progress: null,
|
||||||
|
progressImage: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
export const apiSlice = createSlice({
|
||||||
|
name: 'api',
|
||||||
|
initialState: initialSystemState,
|
||||||
|
reducers: {
|
||||||
|
setSessionId: (state, action: PayloadAction<APIState['sessionId']>) => {
|
||||||
|
state.sessionId = action.payload;
|
||||||
|
},
|
||||||
|
setStatus: (state, action: PayloadAction<APIState['status']>) => {
|
||||||
|
state.status = action.payload;
|
||||||
|
},
|
||||||
|
setProgressImage: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<APIState['progressImage']>
|
||||||
|
) => {
|
||||||
|
state.progressImage = action.payload;
|
||||||
|
},
|
||||||
|
setProgress: (state, action: PayloadAction<APIState['progress']>) => {
|
||||||
|
state.progress = action.payload;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
extraReducers: (builder) => {
|
||||||
|
builder.addCase(createSession.fulfilled, (state, { payload: { id } }) => {
|
||||||
|
// HTTP 200
|
||||||
|
// state.networkStatus = 'idle'
|
||||||
|
state.sessionId = id;
|
||||||
|
});
|
||||||
|
builder.addCase(createSession.pending, (state, action) => {
|
||||||
|
// HTTP request pending
|
||||||
|
// state.networkStatus = 'busy'
|
||||||
|
});
|
||||||
|
builder.addCase(createSession.rejected, (state, action) => {
|
||||||
|
// !HTTP 200
|
||||||
|
console.error('createSession rejected: ', action);
|
||||||
|
// state.networkStatus = 'idle'
|
||||||
|
});
|
||||||
|
builder.addCase(invokeSession.fulfilled, (state, action) => {
|
||||||
|
// HTTP 200
|
||||||
|
// state.networkStatus = 'idle'
|
||||||
|
});
|
||||||
|
builder.addCase(invokeSession.pending, (state, action) => {
|
||||||
|
// HTTP request pending
|
||||||
|
// state.networkStatus = 'busy'
|
||||||
|
});
|
||||||
|
builder.addCase(invokeSession.rejected, (state, action) => {
|
||||||
|
// state.networkStatus = 'idle'
|
||||||
|
});
|
||||||
|
builder.addCase(getImage.fulfilled, (state, action) => {
|
||||||
|
// !HTTP 200
|
||||||
|
console.log(action.payload);
|
||||||
|
// state.networkStatus = 'idle'
|
||||||
|
});
|
||||||
|
builder.addCase(getImage.pending, (state, action) => {
|
||||||
|
// HTTP request pending
|
||||||
|
// state.networkStatus = 'busy'
|
||||||
|
});
|
||||||
|
builder.addCase(getImage.rejected, (state, action) => {
|
||||||
|
// !HTTP 200
|
||||||
|
// state.networkStatus = 'idle'
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const { setSessionId, setStatus, setProgressImage, setProgress } =
|
||||||
|
apiSlice.actions;
|
||||||
|
|
||||||
|
export default apiSlice.reducer;
|
18
invokeai/frontend/web/src/services/apiSliceTypes.ts
Normal file
18
invokeai/frontend/web/src/services/apiSliceTypes.ts
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
export enum STATUS {
|
||||||
|
idle = 'IDLE',
|
||||||
|
busy = 'BUSY',
|
||||||
|
error = 'ERROR',
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ProgressImage = {
|
||||||
|
width: number;
|
||||||
|
height: number;
|
||||||
|
dataURL: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export interface APIState {
|
||||||
|
sessionId: string | null;
|
||||||
|
progressImage: ProgressImage | null;
|
||||||
|
progress: number | null;
|
||||||
|
status: STATUS;
|
||||||
|
}
|
28
invokeai/frontend/web/src/services/thunks/image.ts
Normal file
28
invokeai/frontend/web/src/services/thunks/image.ts
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import { createAppAsyncThunk } from 'app/storeUtils';
|
||||||
|
import { ImagesService, ImageType } from 'services/api';
|
||||||
|
|
||||||
|
type GetImageArg = {
|
||||||
|
/**
|
||||||
|
* The type of image to get
|
||||||
|
*/
|
||||||
|
imageType: ImageType;
|
||||||
|
/**
|
||||||
|
* The name of the image to get
|
||||||
|
*/
|
||||||
|
imageName: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
// createAppAsyncThunk provides typing for getState and dispatch
|
||||||
|
export const getImage = createAppAsyncThunk(
|
||||||
|
'api/getImage',
|
||||||
|
async (arg: GetImageArg, { getState, dispatch, ...moreThunkStuff }) => {
|
||||||
|
const response = await ImagesService.getImage(arg);
|
||||||
|
return response;
|
||||||
|
},
|
||||||
|
{
|
||||||
|
condition: (arg, { getState }) => {
|
||||||
|
// we can get an image at any time
|
||||||
|
return true;
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
82
invokeai/frontend/web/src/services/thunks/session.ts
Normal file
82
invokeai/frontend/web/src/services/thunks/session.ts
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import { createAppAsyncThunk } from 'app/storeUtils';
|
||||||
|
import { Graph, SessionsService } from 'services/api';
|
||||||
|
import { STATUS } from 'services/apiSliceTypes';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* createSession
|
||||||
|
*/
|
||||||
|
|
||||||
|
type CreateSessionArg = { requestBody?: Graph };
|
||||||
|
|
||||||
|
// createAppAsyncThunk provides typing for getState and dispatch
|
||||||
|
export const createSession = createAppAsyncThunk(
|
||||||
|
'api/createSession',
|
||||||
|
async (arg: CreateSessionArg, { getState, dispatch, ...moreThunkStuff }) => {
|
||||||
|
const response = await SessionsService.createSession(arg);
|
||||||
|
return response;
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// if this returns false, the api call is skipped
|
||||||
|
// we can guard in many places, and maybe this isn't right for us,
|
||||||
|
// but just trying it here
|
||||||
|
condition: (arg, { getState }) => {
|
||||||
|
const {
|
||||||
|
api: { status, sessionId },
|
||||||
|
} = getState();
|
||||||
|
|
||||||
|
// don't create session if we are processing already
|
||||||
|
if (status === STATUS.busy) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't create session if we have a sessionId
|
||||||
|
if (sessionId) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* invokeSession
|
||||||
|
*/
|
||||||
|
|
||||||
|
export const invokeSession = createAppAsyncThunk(
|
||||||
|
'api/invokeSession',
|
||||||
|
async (_arg, { getState }) => {
|
||||||
|
const {
|
||||||
|
api: { sessionId },
|
||||||
|
} = getState();
|
||||||
|
|
||||||
|
// i'd really like for the typing on the condition callback below to tell this
|
||||||
|
// function here that sessionId will never be empty, but guess we do not get
|
||||||
|
// that luxury
|
||||||
|
if (!sessionId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await SessionsService.invokeSession({
|
||||||
|
sessionId,
|
||||||
|
all: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
return response;
|
||||||
|
},
|
||||||
|
{
|
||||||
|
condition: (arg, { getState }) => {
|
||||||
|
const {
|
||||||
|
api: { status, sessionId },
|
||||||
|
} = getState();
|
||||||
|
|
||||||
|
// don't invoke if we are processing already
|
||||||
|
if (status === STATUS.busy) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't invoke if we don't have a sessionId
|
||||||
|
if (!sessionId) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
Loading…
Reference in New Issue
Block a user