diff --git a/invokeai/frontend/web/src/app/NodeAPITest.tsx b/invokeai/frontend/web/src/app/NodeAPITest.tsx index 6c891a78f4..a6f4427b60 100644 --- a/invokeai/frontend/web/src/app/NodeAPITest.tsx +++ b/invokeai/frontend/web/src/app/NodeAPITest.tsx @@ -3,12 +3,16 @@ import IAIButton from 'common/components/IAIButton'; import { setProgress, setProgressImage, - setSessionId, setStatus, + STATUS, } from 'services/apiSlice'; -import { useEffect } from 'react'; -import { STATUS, ProgressImage } from 'services/apiSliceTypes'; -import { getImage } from 'services/thunks/image'; +import { useCallback, useEffect, useState } from 'react'; +import { + GeneratorProgressEvent, + GraphExecutionStateCompleteEvent, + InvocationCompleteEvent, + InvocationStartedEvent, +} from 'services/events/types'; import { cancelProcessing, createSession, @@ -18,14 +22,6 @@ import { io } from 'socket.io-client'; import { useAppDispatch, useAppSelector } from './storeHooks'; import { RootState } from './store'; -type GeneratorProgress = { - session_id: string; - invocation_id: string; - progress_image: ProgressImage; - step: number; - total_steps: number; -}; - const socket_url = `ws://${window.location.host}`; const socket = io(socket_url, { path: '/ws/socket.io', @@ -33,60 +29,84 @@ const socket = io(socket_url, { const NodeAPITest = () => { const dispatch = useAppDispatch(); - const { sessionId, status, progress, progressImage } = useAppSelector( + const { sessionId, progress, progressImage } = useAppSelector( (state: RootState) => state.api ); + const [resultImages, setResultImages] = useState([]); + + const appendResultImage = useCallback( + (url: string) => { + setResultImages([...resultImages, url]); + }, + [resultImages] + ); + const handleCreateSession = () => { dispatch( createSession({ - requestBody: { - nodes: { - a: { - id: 'a', - type: 'txt2img', - prompt: 'pizza', - steps: 50, - seed: 123, - }, - b: { - id: 'b', - type: 'img2img', - prompt: 'dog', - steps: 50, - seed: 123, - strength: 0.9, - }, - c: { - id: 'c', - type: 'img2img', - prompt: 'cat', - steps: 50, - seed: 123, - strength: 0.9, - }, + nodes: { + a: { + id: 'a', + type: 'txt2img', + prompt: 'pizza', + steps: 30, + }, + b: { + id: 'b', + type: 'img2img', + prompt: 'dog', + steps: 30, + strength: 0.75, + }, + c: { + id: 'c', + type: 'img2img', + prompt: 'cat', + steps: 30, + strength: 0.75, + }, + d: { + id: 'd', + type: 'img2img', + prompt: 'jalapeno', + steps: 30, + strength: 0.75, }, - edges: [ - { - source: { node_id: 'a', field: 'image' }, - destination: { node_id: 'b', field: 'image' }, - }, - { - source: { node_id: 'b', field: 'image' }, - destination: { node_id: 'c', field: 'image' }, - }, - ], }, + edges: [ + { + source: { node_id: 'a', field: 'image' }, + destination: { node_id: 'b', field: 'image' }, + }, + { + source: { node_id: 'b', field: 'image' }, + destination: { node_id: 'c', field: 'image' }, + }, + { + source: { node_id: 'c', field: 'image' }, + destination: { node_id: 'd', field: 'image' }, + }, + ], }) ); }; const handleInvokeSession = () => { - dispatch(invokeSession()); + if (!sessionId) { + return; + } + + dispatch(invokeSession({ sessionId })); + setResultImages([]); }; const handleCancelProcessing = () => { - dispatch(cancelProcessing()); + if (!sessionId) { + return; + } + + dispatch(cancelProcessing({ sessionId })); }; useEffect(() => { @@ -96,51 +116,12 @@ const NodeAPITest = () => { // set up socket.io listeners - // TODO: suppose this should be handled in the socket.io middleware - // TODO: write types for the socket.io payloads, haven't found a generator for them yet... + // TODO: suppose this should be handled in the socket.io middleware? // subscribe to the current session socket.emit('subscribe', { session: sessionId }); console.log('subscribe', { session: sessionId }); - // received on each generation step - socket.on('generator_progress', (data: GeneratorProgress) => { - console.log('generator_progress', data); - dispatch(setProgress(data.step / data.total_steps)); - dispatch(setProgressImage(data.progress_image)); - }); - - // received after invokeSession called - socket.on('invocation_started', (data) => { - console.log('invocation_started', data); - dispatch(setStatus(STATUS.busy)); - }); - - // received when generation complete - socket.on('invocation_complete', (data) => { - // for now, just unsubscribe from the session when we finish a generation - // in the future we will want to continue building the graph and executing etc - console.log('invocation_complete', data); - dispatch(setProgress(null)); - // dispatch(setSessionId(null)); - dispatch(setStatus(STATUS.idle)); - - // think this gets a blob... - // dispatch( - // getImage({ - // imageType: data.result.image.image_type, - // imageName: data.result.image.image_name, - // }) - // ); - }); - - // not sure when we get this? - socket.on('session_complete', (data) => { - // socket.emit('unsubscribe', { session: sessionId }); - // console.log('unsubscribe', { session: sessionId }); - // console.log('session_complete', data); - }); - () => { // cleanup socket.emit('unsubscribe', { session: sessionId }); @@ -149,6 +130,52 @@ const NodeAPITest = () => { }; }, [dispatch, sessionId]); + useEffect(() => { + /** + * `invocation_started` + */ + socket.on('invocation_started', (data: InvocationStartedEvent) => { + console.log('invocation_started', data); + dispatch(setStatus(STATUS.busy)); + }); + + /** + * `generator_progress` + */ + socket.on('generator_progress', (data: GeneratorProgressEvent) => { + console.log('generator_progress', data); + dispatch(setProgress(data.step / data.total_steps)); + if (data.progress_image) { + dispatch(setProgressImage(data.progress_image)); + } + }); + + /** + * `invocation_complete` + */ + socket.on('invocation_complete', (data: InvocationCompleteEvent) => { + if (data.result.type === 'image') { + const url = `api/v1/images/${data.result.image.image_type}/${data.result.image.image_name}`; + appendResultImage(url); + } + + console.log('invocation_complete', data); + dispatch(setProgress(null)); + dispatch(setStatus(STATUS.idle)); + console.log(data); + }); + + /** + * `graph_execution_state_complete` + */ + socket.on( + 'graph_execution_state_complete', + (data: GraphExecutionStateCompleteEvent) => { + console.log(data); + } + ); + }, [dispatch, appendResultImage]); + return ( { }} > Session: {sessionId ? sessionId : '...'} - + Cancel Processing - + Create Session { > Invoke - + + + {resultImages.map((url) => ( + + ))} + ); }; diff --git a/invokeai/frontend/web/src/services/apiSlice.ts b/invokeai/frontend/web/src/services/apiSlice.ts index 9589e92631..0eff6206c3 100644 --- a/invokeai/frontend/web/src/services/apiSlice.ts +++ b/invokeai/frontend/web/src/services/apiSlice.ts @@ -1,9 +1,28 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { APIState, STATUS } from './apiSliceTypes'; +import { ProgressImage } from './events/types'; import { createSession, invokeSession } from 'services/thunks/session'; import { getImage } from './thunks/image'; +/** + * Just temp until we work out better statuses + */ +export enum STATUS { + idle = 'IDLE', + busy = 'BUSY', + error = 'ERROR', +} + +/** + * Type for the temp (?) API slice. + */ +export interface APIState { + sessionId: string | null; + progressImage: ProgressImage | null; + progress: number | null; + status: STATUS; +} + const initialSystemState: APIState = { sessionId: null, status: STATUS.idle, @@ -32,7 +51,10 @@ export const apiSlice = createSlice({ }, }, extraReducers: (builder) => { - builder.addCase(createSession.fulfilled, (state, { payload: { id } }) => { + builder.addCase(createSession.fulfilled, (state, action) => { + const { + payload: { id }, + } = action; // HTTP 200 // state.networkStatus = 'idle' state.sessionId = id; @@ -47,6 +69,7 @@ export const apiSlice = createSlice({ // state.networkStatus = 'idle' }); builder.addCase(invokeSession.fulfilled, (state, action) => { + console.log('invokeSession.fulfilled: ', action.payload); // HTTP 200 // state.networkStatus = 'idle' }); diff --git a/invokeai/frontend/web/src/services/apiSliceTypes.ts b/invokeai/frontend/web/src/services/apiSliceTypes.ts deleted file mode 100644 index f6733078f6..0000000000 --- a/invokeai/frontend/web/src/services/apiSliceTypes.ts +++ /dev/null @@ -1,18 +0,0 @@ -export enum STATUS { - idle = 'IDLE', - busy = 'BUSY', - error = 'ERROR', -} - -export type ProgressImage = { - width: number; - height: number; - dataURL: string; -}; - -export interface APIState { - sessionId: string | null; - progressImage: ProgressImage | null; - progress: number | null; - status: STATUS; -} diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index c2fa83db98..b52af6f3f0 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -1,28 +1,13 @@ import { createAppAsyncThunk } from 'app/storeUtils'; -import { ImagesService, ImageType } from 'services/api'; +import { ImagesService } from 'services/api'; -type GetImageArg = { - /** - * The type of image to get - */ - imageType: ImageType; - /** - * The name of the image to get - */ - imageName: string; -}; +type GetImageArg = Parameters<(typeof ImagesService)['getImage']>[0]; // createAppAsyncThunk provides typing for getState and dispatch export const getImage = createAppAsyncThunk( 'api/getImage', - async (arg: GetImageArg, { getState, dispatch, ...moreThunkStuff }) => { + async (arg: GetImageArg, _thunkApi) => { const response = await ImagesService.getImage(arg); return response; - }, - { - condition: (arg, { getState }) => { - // we can get an image at any time - return true; - }, } ); diff --git a/invokeai/frontend/web/src/services/thunks/session.ts b/invokeai/frontend/web/src/services/thunks/session.ts index 67cdf967af..0c22fcd474 100644 --- a/invokeai/frontend/web/src/services/thunks/session.ts +++ b/invokeai/frontend/web/src/services/thunks/session.ts @@ -1,39 +1,57 @@ import { createAppAsyncThunk } from 'app/storeUtils'; -import { Graph, SessionsService } from 'services/api'; -import { STATUS } from 'services/apiSliceTypes'; +import { SessionsService } from 'services/api'; /** - * createSession + * createSession thunk */ -type CreateSessionArg = { requestBody?: Graph }; +/** + * Extract the type of the requestBody from the generated API client. + * + * Would really like for this to be generated but it's easy enough to extract it. + */ + +type CreateSessionRequestBody = Parameters< + (typeof SessionsService)['createSession'] +>[0]['requestBody']; -// createAppAsyncThunk provides typing for getState and dispatch export const createSession = createAppAsyncThunk( 'api/createSession', - async (arg: CreateSessionArg, { getState, dispatch, ...moreThunkStuff }) => { - const response = await SessionsService.createSession(arg); + async (arg: CreateSessionRequestBody, _thunkApi) => { + const response = await SessionsService.createSession({ requestBody: arg }); + return response; } ); /** - * invokeSession + * addNode thunk + */ + +type AddNodeRequestBody = Parameters< + (typeof SessionsService)['addNode'] +>[0]['requestBody']; + +export const addNode = createAppAsyncThunk( + 'api/addNode', + async (arg: { node: AddNodeRequestBody; sessionId: string }, _thunkApi) => { + const response = await SessionsService.addNode({ + requestBody: arg.node, + sessionId: arg.sessionId, + }); + + return response; + } +); + +/** + * invokeSession thunk */ export const invokeSession = createAppAsyncThunk( 'api/invokeSession', - async (_arg, { getState }) => { - const { - api: { sessionId }, - } = getState(); - - // i'd really like for the typing on the condition callback below to tell this - // function here that sessionId will never be empty, but guess we do not get - // that luxury - if (!sessionId) { - return; - } + async (arg: { sessionId: string }, _thunkApi) => { + const { sessionId } = arg; const response = await SessionsService.invokeSession({ sessionId, @@ -43,26 +61,15 @@ export const invokeSession = createAppAsyncThunk( return response; } ); + /** - * invokeSession + * invokeSession thunk */ export const cancelProcessing = createAppAsyncThunk( 'api/cancelProcessing', - async (_arg, { getState }) => { - console.log('before canceling'); - const { - api: { sessionId }, - } = getState(); - - // i'd really like for the typing on the condition callback below to tell this - // function here that sessionId will never be empty, but guess we do not get - // that luxury - if (!sessionId) { - return; - } - - console.log('canceling'); + async (arg: { sessionId: string }, _thunkApi) => { + const { sessionId } = arg; const response = await SessionsService.cancelSessionInvoke({ sessionId, diff --git a/invokeai/frontend/web/src/services/util.ts b/invokeai/frontend/web/src/services/util.ts new file mode 100644 index 0000000000..79859150fd --- /dev/null +++ b/invokeai/frontend/web/src/services/util.ts @@ -0,0 +1,24 @@ +import { Graph, TextToImageInvocation } from './api'; + +/** + * Make a graph of however many images + */ +export const makeGraphOfXImages = (numberOfImages: string) => + Array.from(Array(numberOfImages)) + .map( + (_val, i): TextToImageInvocation => ({ + id: i.toString(), + type: 'txt2img', + prompt: 'pizza', + steps: 50, + seed: 123, + sampler_name: 'ddim', + }) + ) + .reduce( + (acc, val: TextToImageInvocation) => { + acc.nodes![val.id] = val; + return acc; + }, + { nodes: {} } as Graph + );