diff --git a/invokeai/frontend/web/src/app/NodeAPITest.tsx b/invokeai/frontend/web/src/app/NodeAPITest.tsx index 0aa8e84f3a..f90326349a 100644 --- a/invokeai/frontend/web/src/app/NodeAPITest.tsx +++ b/invokeai/frontend/web/src/app/NodeAPITest.tsx @@ -1,14 +1,25 @@ -import { Flex, Heading, Text } from '@chakra-ui/react'; +import { Flex, Image, Text } from '@chakra-ui/react'; import IAIButton from 'common/components/IAIButton'; -import { useEffect, useState } from 'react'; -import { SessionsService } from 'services/api'; +import { + setProgress, + setProgressImage, + setSessionId, + setStatus, +} from 'services/apiSlice'; +import { useEffect } from 'react'; +import { STATUS, ProgressImage } from 'services/apiSliceTypes'; +import { getImage } from 'services/thunks/image'; +import { createSession, invokeSession } from 'services/thunks/session'; import { io } from 'socket.io-client'; +import { useAppDispatch, useAppSelector } from './storeHooks'; +import { RootState } from './store'; type GeneratorProgress = { session_id: string; invocation_id: string; + progress_image: ProgressImage; step: number; - percent: number; + total_steps: number; }; const socket_url = `ws://${window.location.host}`; @@ -16,98 +27,102 @@ const socket = io(socket_url, { path: '/ws/socket.io', }); -enum STATUS { - waiting = 'WAITING', - ready = 'READY', - preparing = 'PREPARING', - generating = 'GENERATING', - finished = 'FINISHED', -} - const NodeAPITest = () => { - const [invocationProgress, setInvocationProgress] = useState(); - const [status, setStatus] = useState(STATUS.waiting); - const [sessionId, setSessionId] = useState(null); + const dispatch = useAppDispatch(); + const { sessionId, status, progress, progressImage } = useAppSelector( + (state: RootState) => state.api + ); - const handleCreateSession = async () => { - // create a session with a simple graph - const payload = await SessionsService.createSession({ - nodes: { - a: { - id: 'a', - type: 'txt2img', - prompt: 'pizza', - steps: 10, + const handleCreateSession = () => { + dispatch( + createSession({ + requestBody: { + nodes: { + a: { + id: 'a', + type: 'txt2img', + prompt: 'pizza', + steps: 10, + }, + b: { + id: 'b', + type: 'show_image', + }, + }, + edges: [ + { + source: { node_id: 'a', field: 'image' }, + destination: { node_id: 'b', field: 'image' }, + }, + ], }, - b: { - id: 'b', - type: 'show_image', - }, - }, - edges: [ - { - source: { node_id: 'a', field: 'image' }, - destination: { node_id: 'b', field: 'image' }, - }, - ], - }); - - // the generated types have `id` as optional but i'm pretty sure we always get the id - setSessionId(payload.id!); - setStatus(STATUS.ready); - console.log('payload', payload); - - // subscribe to this session - socket.emit('subscribe', { session: payload.id }); - console.log('subscribe', { session: payload.id }); + }) + ); }; - const handleInvokeSession = async () => { + const handleInvokeSession = () => { + dispatch(invokeSession()); + }; + + useEffect(() => { if (!sessionId) { return; } - setStatus(STATUS.preparing); - // invoke the session, the resultant image should open in your platform's native image viewer when completed - await SessionsService.invokeSession(sessionId, true); - }; + // set up socket.io listeners - useEffect(() => { + // TODO: suppose this should be handled in the socket.io middleware + // TODO: write types for the socket.io payloads, haven't found a generator for them yet... + + // subscribe to the current session + socket.emit('subscribe', { session: sessionId }); + console.log('subscribe', { session: sessionId }); + + // received on each generation step socket.on('generator_progress', (data: GeneratorProgress) => { - // this is broken on the backend, the nodes web server does not get `step` or `steps`, so we don't get a percentage - // see https://github.com/invoke-ai/InvokeAI/issues/2951 console.log('generator_progress', data); - setInvocationProgress(data.percent); + dispatch(setProgress(data.step / data.total_steps)); + dispatch(setProgressImage(data.progress_image)); }); + + // received after invokeSession called socket.on('invocation_started', (data) => { console.log('invocation_started', data); - setStatus(STATUS.generating); + dispatch(setStatus(STATUS.busy)); }); + + // received when generation complete socket.on('invocation_complete', (data) => { // for now, just unsubscribe from the session when we finish a generation // in the future we will want to continue building the graph and executing etc - setStatus(STATUS.finished); console.log('invocation_complete', data); - socket.emit('unsubscribe', { session: data.session_id }); - console.log('unsubscribe', { session: data.session_id }); - setTimeout(() => { - setSessionId(null); - setStatus(STATUS.waiting); - }, 2000); + dispatch(setProgress(null)); + dispatch(setSessionId(null)); + dispatch(setStatus(STATUS.idle)); + + // think this gets a blob... + // dispatch( + // getImage({ + // imageType: data.result.image.image_type, + // imageName: data.result.image.image_name, + // }) + // ); + socket.emit('unsubscribe', { session: sessionId }); + console.log('unsubscribe', { session: sessionId }); }); + + // not sure when we get this? socket.on('session_complete', (data) => { - console.log('session_complete', data); - socket.emit('unsubscribe', { session: data.session_id }); - console.log('unsubscribe', { session: data.session_id }); - setSessionId(null); - setStatus(STATUS.waiting); + // console.log('session_complete', data); }); () => { + // cleanup + socket.emit('unsubscribe', { session: sessionId }); socket.removeAllListeners(); socket.disconnect(); }; - }, []); + }, [dispatch, sessionId]); return ( { borderRadius: 'base', }} > - Status: {status} Session: {sessionId ? sessionId : '...'} Create Session Invoke + ); }; diff --git a/invokeai/frontend/web/src/app/store.ts b/invokeai/frontend/web/src/app/store.ts index 29dbff3fba..152c988a1d 100644 --- a/invokeai/frontend/web/src/app/store.ts +++ b/invokeai/frontend/web/src/app/store.ts @@ -12,6 +12,7 @@ import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import systemReducer from 'features/system/store/systemSlice'; import uiReducer from 'features/ui/store/uiSlice'; +import apiReducer from 'services/apiSlice'; import { socketioMiddleware } from './socketio/middleware'; @@ -64,6 +65,10 @@ const lightboxBlacklist = ['isLightboxOpen'].map( (blacklistItem) => `lightbox.${blacklistItem}` ); +const apiBlacklist = ['sessionId', 'status', 'progress', 'progressImage'].map( + (blacklistItem) => `api.${blacklistItem}` +); + const rootReducer = combineReducers({ generation: generationReducer, postprocessing: postprocessingReducer, @@ -72,6 +77,7 @@ const rootReducer = combineReducers({ canvas: canvasReducer, ui: uiReducer, lightbox: lightboxReducer, + api: apiReducer, }); const rootPersistConfig = getPersistConfig({ @@ -83,6 +89,7 @@ const rootPersistConfig = getPersistConfig({ ...systemBlacklist, ...galleryBlacklist, ...lightboxBlacklist, + ...apiBlacklist, ], debounce: 300, }); diff --git a/invokeai/frontend/web/src/app/storeUtils.ts b/invokeai/frontend/web/src/app/storeUtils.ts new file mode 100644 index 0000000000..851c0ba09d --- /dev/null +++ b/invokeai/frontend/web/src/app/storeUtils.ts @@ -0,0 +1,8 @@ +import { createAsyncThunk } from '@reduxjs/toolkit'; +import { AppDispatch, RootState } from './store'; + +// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk +export const createAppAsyncThunk = createAsyncThunk.withTypes<{ + state: RootState; + dispatch: AppDispatch; +}>(); diff --git a/invokeai/frontend/web/src/services/apiSlice.ts b/invokeai/frontend/web/src/services/apiSlice.ts new file mode 100644 index 0000000000..9589e92631 --- /dev/null +++ b/invokeai/frontend/web/src/services/apiSlice.ts @@ -0,0 +1,79 @@ +import type { PayloadAction } from '@reduxjs/toolkit'; +import { createSlice } from '@reduxjs/toolkit'; +import { APIState, STATUS } from './apiSliceTypes'; +import { createSession, invokeSession } from 'services/thunks/session'; +import { getImage } from './thunks/image'; + +const initialSystemState: APIState = { + sessionId: null, + status: STATUS.idle, + progress: null, + progressImage: null, +}; + +export const apiSlice = createSlice({ + name: 'api', + initialState: initialSystemState, + reducers: { + setSessionId: (state, action: PayloadAction) => { + state.sessionId = action.payload; + }, + setStatus: (state, action: PayloadAction) => { + state.status = action.payload; + }, + setProgressImage: ( + state, + action: PayloadAction + ) => { + state.progressImage = action.payload; + }, + setProgress: (state, action: PayloadAction) => { + state.progress = action.payload; + }, + }, + extraReducers: (builder) => { + builder.addCase(createSession.fulfilled, (state, { payload: { id } }) => { + // HTTP 200 + // state.networkStatus = 'idle' + state.sessionId = id; + }); + builder.addCase(createSession.pending, (state, action) => { + // HTTP request pending + // state.networkStatus = 'busy' + }); + builder.addCase(createSession.rejected, (state, action) => { + // !HTTP 200 + console.error('createSession rejected: ', action); + // state.networkStatus = 'idle' + }); + builder.addCase(invokeSession.fulfilled, (state, action) => { + // HTTP 200 + // state.networkStatus = 'idle' + }); + builder.addCase(invokeSession.pending, (state, action) => { + // HTTP request pending + // state.networkStatus = 'busy' + }); + builder.addCase(invokeSession.rejected, (state, action) => { + // state.networkStatus = 'idle' + }); + builder.addCase(getImage.fulfilled, (state, action) => { + // !HTTP 200 + console.log(action.payload); + // state.networkStatus = 'idle' + }); + builder.addCase(getImage.pending, (state, action) => { + // HTTP request pending + // state.networkStatus = 'busy' + }); + builder.addCase(getImage.rejected, (state, action) => { + // !HTTP 200 + // state.networkStatus = 'idle' + }); + }, +}); + +export const { setSessionId, setStatus, setProgressImage, setProgress } = + apiSlice.actions; + +export default apiSlice.reducer; diff --git a/invokeai/frontend/web/src/services/apiSliceTypes.ts b/invokeai/frontend/web/src/services/apiSliceTypes.ts new file mode 100644 index 0000000000..f6733078f6 --- /dev/null +++ b/invokeai/frontend/web/src/services/apiSliceTypes.ts @@ -0,0 +1,18 @@ +export enum STATUS { + idle = 'IDLE', + busy = 'BUSY', + error = 'ERROR', +} + +export type ProgressImage = { + width: number; + height: number; + dataURL: string; +}; + +export interface APIState { + sessionId: string | null; + progressImage: ProgressImage | null; + progress: number | null; + status: STATUS; +} diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts new file mode 100644 index 0000000000..c2fa83db98 --- /dev/null +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -0,0 +1,28 @@ +import { createAppAsyncThunk } from 'app/storeUtils'; +import { ImagesService, ImageType } from 'services/api'; + +type GetImageArg = { + /** + * The type of image to get + */ + imageType: ImageType; + /** + * The name of the image to get + */ + imageName: string; +}; + +// createAppAsyncThunk provides typing for getState and dispatch +export const getImage = createAppAsyncThunk( + 'api/getImage', + async (arg: GetImageArg, { getState, dispatch, ...moreThunkStuff }) => { + const response = await ImagesService.getImage(arg); + return response; + }, + { + condition: (arg, { getState }) => { + // we can get an image at any time + return true; + }, + } +); diff --git a/invokeai/frontend/web/src/services/thunks/session.ts b/invokeai/frontend/web/src/services/thunks/session.ts new file mode 100644 index 0000000000..3af130e840 --- /dev/null +++ b/invokeai/frontend/web/src/services/thunks/session.ts @@ -0,0 +1,82 @@ +import { createAppAsyncThunk } from 'app/storeUtils'; +import { Graph, SessionsService } from 'services/api'; +import { STATUS } from 'services/apiSliceTypes'; + +/** + * createSession + */ + +type CreateSessionArg = { requestBody?: Graph }; + +// createAppAsyncThunk provides typing for getState and dispatch +export const createSession = createAppAsyncThunk( + 'api/createSession', + async (arg: CreateSessionArg, { getState, dispatch, ...moreThunkStuff }) => { + const response = await SessionsService.createSession(arg); + return response; + }, + { + // if this returns false, the api call is skipped + // we can guard in many places, and maybe this isn't right for us, + // but just trying it here + condition: (arg, { getState }) => { + const { + api: { status, sessionId }, + } = getState(); + + // don't create session if we are processing already + if (status === STATUS.busy) { + return false; + } + + // don't create session if we have a sessionId + if (sessionId) { + return false; + } + }, + } +); + +/** + * invokeSession + */ + +export const invokeSession = createAppAsyncThunk( + 'api/invokeSession', + async (_arg, { getState }) => { + const { + api: { sessionId }, + } = getState(); + + // i'd really like for the typing on the condition callback below to tell this + // function here that sessionId will never be empty, but guess we do not get + // that luxury + if (!sessionId) { + return; + } + + const response = await SessionsService.invokeSession({ + sessionId, + all: true, + }); + + return response; + }, + { + condition: (arg, { getState }) => { + const { + api: { status, sessionId }, + } = getState(); + + // don't invoke if we are processing already + if (status === STATUS.busy) { + return false; + } + + // don't invoke if we don't have a sessionId + if (!sessionId) { + return false; + } + }, + } +);