feat(ui): wip events, comments, and general refactoring

This commit is contained in:
psychedelicious 2023-04-06 00:27:46 +10:00
parent 500bdfa7dd
commit f6c6f61da6
26 changed files with 358 additions and 635 deletions

View File

@ -7,4 +7,4 @@ index.html
.yarn/
*.scss
src/services/api/
src/services/openapi.json
src/services/fixtures/*

View File

@ -7,4 +7,4 @@ index.html
.yarn/
*.scss
src/services/api/
src/services/openapi.json
src/services/fixtures/*

View File

@ -7,8 +7,8 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/request.ts",
"api:file": "openapi -i openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/request.ts",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",

View File

@ -1,230 +0,0 @@
import { Flex, Image, Text } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import {
setProgress,
setProgressImage,
setStatus,
STATUS,
} from 'services/apiSlice';
import { useCallback, useEffect, useState } from 'react';
import {
GeneratorProgressEvent,
GraphExecutionStateCompleteEvent,
InvocationCompleteEvent,
InvocationStartedEvent,
} from 'services/events/types';
import {
cancelProcessing,
createSession,
invokeSession,
} from 'services/thunks/session';
import { io } from 'socket.io-client';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { RootState } from './store';
const socket_url = `ws://${window.location.host}`;
const socket = io(socket_url, {
path: '/ws/socket.io',
});
const NodeAPITest = () => {
const dispatch = useAppDispatch();
const { sessionId, progress, progressImage } = useAppSelector(
(state: RootState) => state.api
);
const [resultImages, setResultImages] = useState<string[]>([]);
const appendResultImage = useCallback(
(url: string) => {
setResultImages([...resultImages, url]);
},
[resultImages]
);
const handleCreateSession = () => {
dispatch(
createSession({
nodes: {
a: {
id: 'a',
type: 'txt2img',
prompt: 'pizza',
steps: 30,
},
b: {
id: 'b',
type: 'img2img',
prompt: 'dog',
steps: 30,
strength: 0.75,
},
c: {
id: 'c',
type: 'img2img',
prompt: 'cat',
steps: 30,
strength: 0.75,
},
d: {
id: 'd',
type: 'img2img',
prompt: 'jalapeno',
steps: 30,
strength: 0.75,
},
},
edges: [
{
source: { node_id: 'a', field: 'image' },
destination: { node_id: 'b', field: 'image' },
},
{
source: { node_id: 'b', field: 'image' },
destination: { node_id: 'c', field: 'image' },
},
{
source: { node_id: 'c', field: 'image' },
destination: { node_id: 'd', field: 'image' },
},
],
})
);
};
const handleInvokeSession = () => {
if (!sessionId) {
return;
}
dispatch(invokeSession({ sessionId }));
setResultImages([]);
};
const handleCancelProcessing = () => {
if (!sessionId) {
return;
}
dispatch(cancelProcessing({ sessionId }));
};
useEffect(() => {
if (!sessionId) {
return;
}
setResultImages([]);
// set up socket.io listeners
// TODO: suppose this should be handled in the socket.io middleware?
// subscribe to the current session
socket.emit('subscribe', { session: sessionId });
console.log('subscribe', { session: sessionId });
() => {
// cleanup
socket.emit('unsubscribe', { session: sessionId });
socket.removeAllListeners();
socket.disconnect();
};
}, [dispatch, sessionId]);
useEffect(() => {
/**
* `invocation_started`
*/
socket.on('invocation_started', (data: InvocationStartedEvent) => {
console.log('invocation_started', data);
dispatch(setStatus(STATUS.busy));
});
/**
* `generator_progress`
*/
socket.on('generator_progress', (data: GeneratorProgressEvent) => {
console.log('generator_progress', data);
dispatch(setProgress(data.step / data.total_steps));
if (data.progress_image) {
dispatch(setProgressImage(data.progress_image));
}
});
/**
* `invocation_complete`
*/
socket.on('invocation_complete', (data: InvocationCompleteEvent) => {
if (data.result.type === 'image') {
const url = `api/v1/images/${data.result.image.image_type}/${data.result.image.image_name}`;
appendResultImage(url);
}
console.log('invocation_complete', data);
dispatch(setProgress(null));
dispatch(setStatus(STATUS.idle));
console.log(data);
});
/**
* `graph_execution_state_complete`
*/
socket.on(
'graph_execution_state_complete',
(data: GraphExecutionStateCompleteEvent) => {
console.log(data);
}
);
}, [dispatch, appendResultImage]);
return (
<Flex
sx={{
flexDirection: 'column',
gap: 4,
p: 4,
alignItems: 'center',
borderRadius: 'base',
}}
>
<Text>Session: {sessionId ? sessionId : '...'}</Text>
<IAIButton onClick={handleCancelProcessing} colorScheme="error">
Cancel Processing
</IAIButton>
<IAIButton
onClick={handleCreateSession}
colorScheme="accent"
loadingText={`Invoking ${
progress === null ? '...' : `${Math.round(progress * 100)}%`
}`}
>
Create Session & Invoke
</IAIButton>
{/* <IAIButton
onClick={handleInvokeSession}
loadingText={`Invoking ${
progress === null ? '...' : `${Math.round(progress * 100)}%`
}`}
colorScheme="accent"
>
Invoke
</IAIButton> */}
<Flex wrap="wrap" gap={4} overflow="scroll">
<Image
src={progressImage?.dataURL}
width={progressImage?.width}
height={progressImage?.height}
sx={{
imageRendering: 'pixelated',
}}
/>
{resultImages.map((url) => (
<Image key={url} src={url} />
))}
</Flex>
</Flex>
);
};
export default NodeAPITest;

View File

@ -1,43 +0,0 @@
import { createAction } from '@reduxjs/toolkit';
import {
GeneratorProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationStartedEvent,
} from 'services/events/types';
type SocketioPayload = {
timestamp: Date;
};
export const socketioConnected = createAction<SocketioPayload>(
'socketio/socketioConnected'
);
export const socketioDisconnected = createAction<SocketioPayload>(
'socketio/socketioDisconnected'
);
export const socketioSubscribed = createAction<
SocketioPayload & { sessionId: string }
>('socketio/socketioSubscribed');
export const socketioUnsubscribed = createAction<
SocketioPayload & { sessionId: string }
>('socketio/socketioUnsubscribed');
export const invocationStarted = createAction<
SocketioPayload & { data: InvocationStartedEvent }
>('socketio/invocationStarted');
export const invocationComplete = createAction<
SocketioPayload & { data: InvocationCompleteEvent }
>('socketio/invocationComplete');
export const invocationError = createAction<
SocketioPayload & { data: InvocationErrorEvent }
>('socketio/invocationError');
export const generatorProgress = createAction<
SocketioPayload & { data: GeneratorProgressEvent }
>('socketio/generatorProgress');

View File

@ -1,93 +0,0 @@
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import {
GeneratorProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationStartedEvent,
} from 'services/events/types';
import {
generatorProgress,
invocationComplete,
invocationError,
invocationStarted,
socketioConnected,
socketioDisconnected,
socketioSubscribed,
} from './actions';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { AppDispatch, RootState } from 'app/store';
const socket_url = `ws://${window.location.host}`;
const socketio = io(socket_url, {
timeout: 60000,
path: '/ws/socket.io',
});
export const socketioMiddleware = () => {
let areListenersSet = false;
const middleware: Middleware =
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
const { dispatch, getState } = store;
const timestamp = new Date();
if (!areListenersSet) {
socketio.on('connect', () => {
dispatch(socketioConnected({ timestamp }));
if (!getState().results.ids.length) {
dispatch(receivedResultImagesPage());
}
if (!getState().uploads.ids.length) {
dispatch(receivedUploadImagesPage());
}
});
socketio.on('disconnect', () => {
dispatch(socketioDisconnected({ timestamp }));
socketio.removeAllListeners();
});
}
areListenersSet = true;
if (invocationComplete.match(action)) {
socketio.emit('unsubscribe', {
session: action.payload.data.graph_execution_state_id,
});
socketio.removeAllListeners();
}
if (socketioSubscribed.match(action)) {
socketio.emit('subscribe', { session: action.payload.sessionId });
socketio.on('invocation_started', (data: InvocationStartedEvent) => {
dispatch(invocationStarted({ data, timestamp }));
});
socketio.on('generator_progress', (data: GeneratorProgressEvent) => {
dispatch(generatorProgress({ data, timestamp }));
});
socketio.on('invocation_error', (data: InvocationErrorEvent) => {
dispatch(invocationError({ data, timestamp }));
});
socketio.on('invocation_complete', (data: InvocationCompleteEvent) => {
dispatch(invocationComplete({ data, timestamp }));
});
}
next(action);
};
return middleware;
};

View File

@ -14,11 +14,9 @@ import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice';
import apiReducer from 'services/apiSlice';
import { socketioMiddleware } from './socketio/middleware';
import { socketioMiddleware as nodesSocketioMiddleware } from './nodesSocketio/middleware';
import { invokeMiddleware } from 'services/invokeMiddleware';
import { socketMiddleware } from 'services/events/middleware';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@ -81,7 +79,6 @@ const rootReducer = combineReducers({
canvas: canvasReducer,
ui: uiReducer,
lightbox: lightboxReducer,
api: apiReducer,
results: resultsReducer,
uploads: uploadsReducer,
});
@ -107,7 +104,7 @@ const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
function buildMiddleware() {
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
return [nodesSocketioMiddleware(), invokeMiddleware];
return [socketMiddleware()];
} else {
return [socketioMiddleware()];
}

View File

@ -1,5 +1,4 @@
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
import NodeAPITest from 'app/NodeAPITest';
import { useTranslation } from 'react-i18next';
import WorkInProgress from './WorkInProgress';
@ -10,13 +9,18 @@ export default function NodesWIP() {
<Flex
sx={{
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
w: '100%',
h: '100%',
gap: 4,
textAlign: 'center',
}}
>
{/* <NodeAPITest /> */}
<Heading>{t('common.nodes')}</Heading>
<VStack maxW="50rem" gap={4}>
<Text>{t('common.nodesDesc')}</Text>
</VStack>
</Flex>
</WorkInProgress>
);

View File

@ -0,0 +1,6 @@
import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');

View File

@ -1,11 +1,12 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'app/nodesSocketio/actions';
import { invocationComplete } from 'services/events/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { clamp } from 'lodash';
import { isImageOutput } from 'services/types/guards';
import { uploadImage } from 'services/thunks/image';
export type GalleryCategory = 'user' | 'result';
@ -25,9 +26,25 @@ export type Gallery = {
};
export interface GalleryState {
/**
* The selected image's unique name
* Use `selectedImageSelector` to access the image
*/
selectedImageName: string;
/**
* The currently selected image
* @deprecated See `state.gallery.selectedImageName`
*/
currentImage?: InvokeAI._Image;
/**
* The currently selected image's uuid.
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
*/
currentImageUuid: string;
/**
* The current progress image
* @deprecated See `state.system.progressImage`
*/
intermediateImage?: InvokeAI._Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
@ -263,6 +280,9 @@ export const gallerySlice = createSlice({
},
},
extraReducers(builder) {
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
if (isImageOutput(data.result)) {
@ -270,6 +290,15 @@ export const gallerySlice = createSlice({
state.intermediateImage = undefined;
}
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(uploadImage.fulfilled, (state, action) => {
const location = action.payload;
const imageName = location.split('/').pop() || '';
state.selectedImageName = imageName;
});
},
});

View File

@ -1,6 +1,6 @@
import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { invocationComplete } from 'app/nodesSocketio/actions';
import { invocationComplete } from 'services/events/actions';
import { RootState } from 'app/store';
import {
@ -11,7 +11,6 @@ import { isImageOutput } from 'services/types/guards';
import { deserializeImageField } from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
// import { deserializeImageField } from 'services/util/deserializeImageField';
import { setCurrentCategory } from './gallerySlice';
// use `createEntityAdapter` to create a slice for results images
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
@ -55,10 +54,17 @@ const resultsSlice = createSlice({
extraReducers: (builder) => {
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
// because we pass in the fulfilled thunk action creator, everything is typed
/**
* Received Result Images Page - PENDING
*/
builder.addCase(receivedResultImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
@ -75,6 +81,9 @@ const resultsSlice = createSlice({
state.isLoading = false;
});
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;

View File

@ -6,6 +6,7 @@ import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { uploadImage } from 'services/thunks/image';
import { deserializeImageField } from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
@ -33,9 +34,16 @@ const uploadsSlice = createSlice({
uploadAdded: uploadsAdapter.addOne,
},
extraReducers: (builder) => {
/**
* Received Upload Images Page - PENDING
*/
builder.addCase(receivedUploadImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
@ -48,6 +56,20 @@ const uploadsSlice = createSlice({
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(uploadImage.fulfilled, (state, action) => {
const location = action.payload;
const uploadedImage = deserializeImageField({
image_name: location.split('/').pop() || '',
image_type: 'uploads',
});
uploadsAdapter.addOne(state, uploadedImage);
});
},
});

View File

@ -1,9 +1,24 @@
import { useToast } from '@chakra-ui/react';
import { useToast, UseToastOptions } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { toastQueueSelector } from 'features/system/store/systemSelectors';
import { clearToastQueue } from 'features/system/store/systemSlice';
import { useEffect } from 'react';
export type MakeToastArg = string | UseToastOptions;
export const makeToast = (arg: MakeToastArg): UseToastOptions => {
if (typeof arg === 'string') {
return {
title: arg,
status: 'info',
isClosable: true,
duration: 2500,
};
}
return { status: 'info', isClosable: true, duration: 2500, ...arg };
};
const useToastWatcher = () => {
const dispatch = useAppDispatch();
const toastQueue = useAppSelector(toastQueueSelector);

View File

@ -1,4 +1,4 @@
import { ExpandedIndex, StatHelpText, UseToastOptions } from '@chakra-ui/react';
import { ExpandedIndex, UseToastOptions } from '@chakra-ui/react';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
@ -7,16 +7,15 @@ import {
invocationComplete,
invocationError,
invocationStarted,
socketioConnected,
socketioDisconnected,
} from 'app/nodesSocketio/actions';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import dateFormat from 'dateformat';
socketConnected,
socketDisconnected,
} from 'services/events/actions';
import i18n from 'i18n';
import { isImageOutput } from 'services/types/guards';
import { ProgressImage } from 'services/events/types';
import { initialImageSelected } from 'features/parameters/store/generationSlice';
import { makeToast } from '../hooks/useToastWatcher';
export type LogLevel = 'info' | 'warning' | 'error';
@ -70,6 +69,9 @@ export interface SystemState
cancelType: CancelType;
cancelAfter: number | null;
};
/**
* The current progress image
*/
progressImage: ProgressImage | null;
}
@ -287,45 +289,55 @@ export const systemSlice = createSlice({
setCancelAfter: (state, action: PayloadAction<number | null>) => {
state.cancelOptions.cancelAfter = action.payload;
},
// socketioConnected: (state) => {
// state.isConnected = true;
// state.currentStatus = i18n.t('common.statusConnected');
// },
// socketioDisconnected: (state) => {
// state.isConnected = false;
// state.currentStatus = i18n.t('common.statusDisconnected');
// },
},
extraReducers(builder) {
builder.addCase(socketioConnected, (state, action) => {
/**
* Socket Connected
*/
builder.addCase(socketConnected, (state, action) => {
const { timestamp } = action.payload;
state.isConnected = true;
state.currentStatus = i18n.t('common.statusConnected');
state.log.push({
timestamp: dateFormat(timestamp, 'isoDateTime'),
timestamp,
message: `Connected to server`,
level: 'info',
});
state.toastQueue.push(
makeToast({ title: i18n.t('toast.connected'), status: 'success' })
);
});
builder.addCase(socketioDisconnected, (state, action) => {
/**
* Socket Disconnected
*/
builder.addCase(socketDisconnected, (state, action) => {
const { timestamp } = action.payload;
state.isConnected = false;
state.currentStatus = i18n.t('common.statusDisconnected');
state.log.push({
timestamp: dateFormat(timestamp, 'isoDateTime'),
timestamp,
message: `Disconnected from server`,
level: 'warning',
level: 'error',
});
state.toastQueue.push(
makeToast({ title: i18n.t('toast.disconnected'), status: 'error' })
);
});
builder.addCase(invocationStarted, (state, action) => {
/**
* Invocation Started
*/
builder.addCase(invocationStarted, (state) => {
state.isProcessing = true;
state.currentStatusHasSteps = false;
});
/**
* Generator Progress
*/
builder.addCase(generatorProgress, (state, action) => {
const { step, total_steps, progress_image } = action.payload.data;
@ -335,6 +347,9 @@ export const systemSlice = createSlice({
state.progressImage = progress_image ?? null;
});
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data, timestamp } = action.payload;
@ -346,18 +361,21 @@ export const systemSlice = createSlice({
// TODO: handle logging for other invocation types
if (isImageOutput(data.result)) {
state.log.push({
timestamp: dateFormat(timestamp, 'isoDateTime'),
timestamp,
message: `Generated: ${data.result.image.image_name}`,
level: 'info',
});
}
});
/**
* Invocation Error
*/
builder.addCase(invocationError, (state, action) => {
const { data, timestamp } = action.payload;
state.log.push({
timestamp: dateFormat(timestamp, 'isoDateTime'),
timestamp,
message: `Server error: ${data.error}`,
level: 'error',
});
@ -365,15 +383,16 @@ export const systemSlice = createSlice({
state.wasErrorSeen = true;
state.progressImage = null;
state.isProcessing = false;
state.toastQueue.push(
makeToast({ title: i18n.t('toast.serverError'), status: 'error' })
);
});
/**
* Initial Image Selected
*/
builder.addCase(initialImageSelected, (state) => {
state.toastQueue.push({
title: i18n.t('toast.sentToImageToImage'),
status: 'success',
duration: 2500,
isClosable: true,
});
state.toastQueue.push(makeToast(i18n.t('toast.sentToImageToImage')));
});
},
});
@ -410,8 +429,6 @@ export const {
setOpenModel,
setCancelType,
setCancelAfter,
// socketioConnected,
// socketioDisconnected,
} = systemSlice.actions;
export default systemSlice.reducer;

View File

@ -1,119 +0,0 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { ProgressImage } from './events/types';
import { createSession, invokeSession } from 'services/thunks/session';
import { getImage, uploadImage } from './thunks/image';
import { invocationComplete } from 'app/nodesSocketio/actions';
/**
* Just temp until we work out better statuses
*/
export enum STATUS {
idle = 'IDLE',
busy = 'BUSY',
error = 'ERROR',
}
/**
* Type for the temp (?) API slice.
*/
export interface APIState {
sessionId: string;
progressImage: ProgressImage | null;
progress: number | null;
status: STATUS;
}
const initialSystemState: APIState = {
sessionId: '',
status: STATUS.idle,
progress: null,
progressImage: null,
};
export const apiSlice = createSlice({
name: 'api',
initialState: initialSystemState,
reducers: {
setSessionId: (state, action: PayloadAction<APIState['sessionId']>) => {
state.sessionId = action.payload;
},
setStatus: (state, action: PayloadAction<APIState['status']>) => {
state.status = action.payload;
},
setProgressImage: (
state,
action: PayloadAction<APIState['progressImage']>
) => {
state.progressImage = action.payload;
},
setProgress: (state, action: PayloadAction<APIState['progress']>) => {
state.progress = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(createSession.fulfilled, (state, action) => {
const {
payload: { id },
} = action;
// HTTP 200
// state.networkStatus = 'idle'
state.sessionId = id;
});
builder.addCase(createSession.pending, (state, action) => {
// HTTP request pending
// state.networkStatus = 'busy'
});
builder.addCase(createSession.rejected, (state, action) => {
// !HTTP 200
console.error('createSession rejected: ', action);
// state.networkStatus = 'idle'
});
builder.addCase(invokeSession.fulfilled, (state, action) => {
console.log('invokeSession.fulfilled: ', action.payload);
// HTTP 200
// state.networkStatus = 'idle'
});
builder.addCase(invokeSession.pending, (state, action) => {
// HTTP request pending
// state.networkStatus = 'busy'
});
builder.addCase(invokeSession.rejected, (state, action) => {
// state.networkStatus = 'idle'
});
builder.addCase(getImage.fulfilled, (state, action) => {
// !HTTP 200
console.log(action.payload);
// state.networkStatus = 'idle'
});
builder.addCase(getImage.pending, (state, action) => {
// HTTP request pending
// state.networkStatus = 'busy'
});
builder.addCase(getImage.rejected, (state, action) => {
// !HTTP 200
// state.networkStatus = 'idle'
});
builder.addCase(uploadImage.fulfilled, (state, action) => {
// !HTTP 200
console.log(action.payload);
// state.networkStatus = 'idle'
});
builder.addCase(uploadImage.pending, (state, action) => {
// HTTP request pending
// state.networkStatus = 'busy'
});
builder.addCase(uploadImage.rejected, (state, action) => {
// !HTTP 200
// state.networkStatus = 'idle'
});
builder.addCase(invocationComplete, (state) => {
state.sessionId = '';
});
},
});
export const { setSessionId, setStatus, setProgressImage, setProgress } =
apiSlice.actions;
export default apiSlice.reducer;

View File

@ -0,0 +1,47 @@
import { createAction } from '@reduxjs/toolkit';
import {
GeneratorProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationStartedEvent,
} from 'services/events/types';
// Common socket action payload data
type BaseSocketPayload = {
timestamp: string;
};
// Create actions for each socket event
// Middleware and redux can then respond to them as needed
export const socketConnected = createAction<BaseSocketPayload>(
'socket/socketConnected'
);
export const socketDisconnected = createAction<BaseSocketPayload>(
'socket/socketDisconnected'
);
export const socketSubscribed = createAction<
BaseSocketPayload & { sessionId: string }
>('socket/socketSubscribed');
export const socketUnsubscribed = createAction<
BaseSocketPayload & { sessionId: string }
>('socket/socketUnsubscribed');
export const invocationStarted = createAction<
BaseSocketPayload & { data: InvocationStartedEvent }
>('socket/invocationStarted');
export const invocationComplete = createAction<
BaseSocketPayload & { data: InvocationCompleteEvent }
>('socket/invocationComplete');
export const invocationError = createAction<
BaseSocketPayload & { data: InvocationErrorEvent }
>('socket/invocationError');
export const generatorProgress = createAction<
BaseSocketPayload & { data: GeneratorProgressEvent }
>('socket/generatorProgress');

View File

@ -0,0 +1,128 @@
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import {
GeneratorProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationStartedEvent,
} from 'services/events/types';
import {
generatorProgress,
invocationComplete,
invocationError,
invocationStarted,
socketConnected,
socketDisconnected,
socketSubscribed,
socketUnsubscribed,
} from './actions';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { AppDispatch, RootState } from 'app/store';
import { getTimestamp } from 'common/util/getTimestamp';
import {
invokeSession,
isFulfilledCreateSession,
} from 'services/thunks/session';
const socket_url = `ws://${window.location.host}`;
const socket = io(socket_url, {
timeout: 60000,
path: '/ws/socket.io',
});
export const socketMiddleware = () => {
let areListenersSet = false;
const middleware: Middleware =
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
const { dispatch, getState } = store;
// Set listeners for `connect` and `disconnect` events once
// Must happen in middleware to get access to `dispatch`
if (!areListenersSet) {
socket.on('connect', () => {
dispatch(socketConnected({ timestamp: getTimestamp() }));
// These thunks need to be dispatch in middleware; cannot handle in a reducer
if (!getState().results.ids.length) {
dispatch(receivedResultImagesPage());
}
if (!getState().uploads.ids.length) {
dispatch(receivedUploadImagesPage());
}
});
socket.on('disconnect', () => {
dispatch(socketDisconnected({ timestamp: getTimestamp() }));
});
areListenersSet = true;
}
// Everything else only happens once we have created a session
if (isFulfilledCreateSession(action)) {
const sessionId = action.payload.id;
// After a session is created, we immediately subscribe to events and then invoke the session
socket.emit('subscribe', { session: sessionId });
// Always dispatch the event actions for other consumers who want to know when we subscribed
dispatch(
socketSubscribed({
sessionId,
timestamp: getTimestamp(),
})
);
// Set up listeners for the present subscription
socket.on('invocation_started', (data: InvocationStartedEvent) => {
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
});
socket.on('generator_progress', (data: GeneratorProgressEvent) => {
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
});
socket.on('invocation_error', (data: InvocationErrorEvent) => {
dispatch(invocationError({ data, timestamp: getTimestamp() }));
});
socket.on('invocation_complete', (data: InvocationCompleteEvent) => {
const sessionId = data.graph_execution_state_id;
// Unsubscribe when invocations complete
socket.emit('unsubscribe', {
session: sessionId,
});
dispatch(
socketUnsubscribed({ sessionId, timestamp: getTimestamp() })
);
// Remove listeners for these events; we need to set them up fresh whenever we subscribe
[
'invocation_started',
'generator_progress',
'invocation_error',
'invocation_complete',
].forEach((event) => socket.removeAllListeners(event));
dispatch(invocationComplete({ data, timestamp: getTimestamp() }));
});
// Finally we actually invoke the session, starting processing
dispatch(invokeSession({ sessionId }));
}
// Always pass the action on so other middleware and reducers can handle it
next(action);
};
return middleware;
};

View File

@ -8,6 +8,8 @@
* Patches the request logic in such a way that we can extract headers from requests.
*
* Copied from https://github.com/ferdikoomen/openapi-typescript-codegen/issues/829#issuecomment-1228224477
*
* This file should be excluded in `tsconfig.json` and ignored by prettier/eslint!
*/
import axios from 'axios';

View File

@ -1,13 +0,0 @@
import json
from invokeai.app.api_app import app
from fastapi.openapi.utils import get_openapi
openapi_doc = get_openapi(
title=app.title,
version=app.version,
openapi_version=app.openapi_version,
routes=app.routes,
)
with open("./openapi.json", "w") as f:
json.dump(openapi_doc, f)

View File

@ -1,65 +0,0 @@
import { isFulfilled, Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import { socketioSubscribed } from 'app/nodesSocketio/actions';
import { AppDispatch, RootState } from 'app/store';
import { setSessionId } from './apiSlice';
import { uploadImage } from './thunks/image';
import { createSession, invokeSession } from './thunks/session';
import * as InvokeAI from 'app/invokeai';
import { addImage } from 'features/gallery/store/gallerySlice';
import { tabMap } from 'features/ui/store/tabMap';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { initialImageSelected as initialImageSet } from 'features/parameters/store/generationSlice';
/**
* `redux-toolkit` provides nice matching utilities, which can be used as type guards
* See: https://redux-toolkit.js.org/api/matching-utilities
*/
const isFulfilledCreateSession = isFulfilled(createSession);
const isFulfilledUploadImage = isFulfilled(uploadImage);
export const invokeMiddleware: Middleware =
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
const { dispatch, getState } = store;
const timestamp = new Date();
if (isFulfilledCreateSession(action)) {
const sessionId = action.payload.id;
console.log('createSession.fulfilled');
dispatch(setSessionId(sessionId));
dispatch(socketioSubscribed({ sessionId, timestamp }));
dispatch(invokeSession({ sessionId }));
} else if (isFulfilledUploadImage(action)) {
const uploadLocation = action.payload;
console.log('uploadImage.fulfilled');
// TODO: actually get correct attributes here
const newImage: InvokeAI._Image = {
uuid: uuidv4(),
category: 'user',
url: uploadLocation,
width: 512,
height: 512,
mtime: new Date().getTime(),
thumbnail: uploadLocation,
};
dispatch(addImage({ image: newImage, category: 'user' }));
const { activeTab } = getState().ui;
const activeTabName = tabMap[activeTab];
if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(newImage));
} else if (activeTabName === 'img2img') {
// dispatch(setInitialImage(newImage));
dispatch(initialImageSet(newImage.uuid));
}
} else {
next(action);
}
};

View File

@ -1,10 +1,13 @@
import { isFulfilled } from '@reduxjs/toolkit';
import { createAppAsyncThunk } from 'app/storeUtils';
import { ImagesService } from 'services/api';
import { getHeaders } from 'services/util/getHeaders';
type GetImageArg = Parameters<(typeof ImagesService)['getImage']>[0];
// createAppAsyncThunk provides typing for getState and dispatch
/**
* `ImagesService.getImage()` thunk
*/
export const getImage = createAppAsyncThunk(
'api/getImage',
async (arg: GetImageArg, _thunkApi) => {
@ -15,6 +18,9 @@ export const getImage = createAppAsyncThunk(
type UploadImageArg = Parameters<(typeof ImagesService)['uploadImage']>[0];
/**
* `ImagesService.uploadImage()` thunk
*/
export const uploadImage = createAppAsyncThunk(
'api/uploadImage',
async (arg: UploadImageArg, _thunkApi) => {
@ -23,3 +29,8 @@ export const uploadImage = createAppAsyncThunk(
return location;
}
);
/**
* Function to check if an action is a fulfilled `ImagesService.uploadImage()` thunk
*/
export const isFulfilledUploadImage = isFulfilled(uploadImage);

View File

@ -1,21 +1,15 @@
import { createAppAsyncThunk } from 'app/storeUtils';
import { SessionsService } from 'services/api';
import { buildGraph } from 'common/util/buildGraph';
/**
* createSession thunk
*/
/**
* Extract the type of the requestBody from the generated API client.
*
* Would really like for this to be generated but it's easy enough to extract it.
*/
import { isFulfilled } from '@reduxjs/toolkit';
type CreateSessionArg = Parameters<
(typeof SessionsService)['createSession']
>[0];
/**
* `SessionsService.createSession()` thunk
*/
export const createSession = createAppAsyncThunk(
'api/createSession',
async (arg: CreateSessionArg['requestBody'], _thunkApi) => {
@ -35,11 +29,15 @@ export const createSession = createAppAsyncThunk(
);
/**
* addNode thunk
* Function to check if an action is a fulfilled `SessionsService.createSession()` thunk
*/
export const isFulfilledCreateSession = isFulfilled(createSession);
type AddNodeArg = Parameters<(typeof SessionsService)['addNode']>[0];
/**
* `SessionsService.addNode()` thunk
*/
export const addNode = createAppAsyncThunk(
'api/addNode',
async (
@ -56,9 +54,8 @@ export const addNode = createAppAsyncThunk(
);
/**
* invokeSession thunk
* `SessionsService.invokeSession()` thunk
*/
export const invokeSession = createAppAsyncThunk(
'api/invokeSession',
async (arg: { sessionId: string }, _thunkApi) => {
@ -73,14 +70,13 @@ export const invokeSession = createAppAsyncThunk(
}
);
/**
* cancelSession thunk
*/
type CancelSessionArg = Parameters<
(typeof SessionsService)['cancelSessionInvoke']
>[0];
/**
* `SessionsService.cancelSession()` thunk
*/
export const cancelProcessing = createAppAsyncThunk(
'api/cancelProcessing',
async (arg: CancelSessionArg, _thunkApi) => {
@ -94,12 +90,11 @@ export const cancelProcessing = createAppAsyncThunk(
}
);
/**
* listSessions thunk
*/
type ListSessionsArg = Parameters<(typeof SessionsService)['listSessions']>[0];
/**
* `SessionsService.listSessions()` thunk
*/
export const listSessions = createAppAsyncThunk(
'api/listSessions',
async (arg: ListSessionsArg, _thunkApi) => {

View File

@ -28,8 +28,11 @@ export const extractTimestampFromImageName = (imageName: string) => {
};
/**
* Process ImageField objects. These come from `invocation_complete` events and do not contain all the data we need.
* This is a WIP on the server side.
* Process ImageField objects. These come from `invocation_complete` events and do not contain all
* the data we need, so we need to do some janky stuff to get urls and timestamps.
*
* TODO: do some more janky stuff here to get image dimensions instead of defaulting to 512x512?
* TODO: better yet, improve the nodes server (wip)
*/
export const deserializeImageField = (image: ImageField): Image => {
const name = image.image_name;
@ -46,7 +49,7 @@ export const deserializeImageField = (image: ImageField): Image => {
thumbnail,
metadata: {
timestamp,
height: 512, // TODO: need the server to give this to us
height: 512,
width: 512,
},
};

View File

@ -1,7 +1,7 @@
import { HEADERS } from '../api/core/request';
/**
* Returns the headers of a given response object
* Returns the response headers of the response received by the generated API client.
*/
export const getHeaders = (response: any): Record<string, string> => {
if (!(HEADERS in response)) {

View File

@ -18,5 +18,6 @@
"jsx": "react-jsx"
},
"include": ["src", "index.d.ts"],
"exclude": ["src/services/fixtures/*"],
"references": [{ "path": "./tsconfig.node.json" }]
}