feat(ui): start hooking up dynamic txt2img node generation, create middleware for session invocation

This commit is contained in:
maryhipp 2023-03-29 13:32:22 -07:00 committed by psychedelicious
parent 3ebd289a59
commit 4fe49718e0
7 changed files with 90 additions and 18 deletions

View File

@ -114,6 +114,8 @@ const NodeAPITest = () => {
return; return;
} }
setResultImages([]);
// set up socket.io listeners // set up socket.io listeners
// TODO: suppose this should be handled in the socket.io middleware? // TODO: suppose this should be handled in the socket.io middleware?
@ -190,10 +192,16 @@ const NodeAPITest = () => {
<IAIButton onClick={handleCancelProcessing} colorScheme="error"> <IAIButton onClick={handleCancelProcessing} colorScheme="error">
Cancel Processing Cancel Processing
</IAIButton> </IAIButton>
<IAIButton onClick={handleCreateSession} colorScheme="accent">
Create Session
</IAIButton>
<IAIButton <IAIButton
onClick={handleCreateSession}
colorScheme="accent"
loadingText={`Invoking ${
progress === null ? '...' : `${Math.round(progress * 100)}%`
}`}
>
Create Session & Invoke
</IAIButton>
{/* <IAIButton
onClick={handleInvokeSession} onClick={handleInvokeSession}
loadingText={`Invoking ${ loadingText={`Invoking ${
progress === null ? '...' : `${Math.round(progress * 100)}%` progress === null ? '...' : `${Math.round(progress * 100)}%`
@ -201,7 +209,7 @@ const NodeAPITest = () => {
colorScheme="accent" colorScheme="accent"
> >
Invoke Invoke
</IAIButton> </IAIButton> */}
<Flex wrap="wrap" gap={4} overflow="scroll"> <Flex wrap="wrap" gap={4} overflow="scroll">
<Image <Image
src={progressImage?.dataURL} src={progressImage?.dataURL}

View File

@ -15,6 +15,7 @@ import uiReducer from 'features/ui/store/uiSlice';
import apiReducer from 'services/apiSlice'; import apiReducer from 'services/apiSlice';
import { socketioMiddleware } from './socketio/middleware'; import { socketioMiddleware } from './socketio/middleware';
import { invokeMiddleware } from 'services/invokeMiddleware';
/** /**
* redux-persist provides an easy and reliable way to persist state across reloads. * redux-persist provides an easy and reliable way to persist state across reloads.
@ -103,7 +104,7 @@ export const store = configureStore({
getDefaultMiddleware({ getDefaultMiddleware({
immutableCheck: false, immutableCheck: false,
serializableCheck: false, serializableCheck: false,
}).concat(socketioMiddleware()), }).concat(socketioMiddleware(), invokeMiddleware),
devTools: { devTools: {
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable // Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
actionsDenylist: [ actionsDenylist: [

View File

@ -0,0 +1,31 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
import { Graph } from 'services/api';
import { buildTxt2ImgNode } from './buildNodes';
function mapTabToFunction(activeTabName: InvokeTabName) {
switch (activeTabName) {
case 'txt2img':
return buildTxt2ImgNode;
default:
return buildTxt2ImgNode;
}
}
export const buildGraph = (state: RootState): Graph => {
const { activeTab } = state.ui;
const activeTabName = tabMap[activeTab];
const nodeId = uuidv4();
return {
nodes: {
[nodeId]: {
id: nodeId,
...mapTabToFunction(activeTabName)(state),
},
},
};
};

View File

@ -1,4 +1,3 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store'; import { RootState } from 'app/store';
import { import {
ImageToImageInvocation, ImageToImageInvocation,
@ -12,10 +11,12 @@ import {
// be todo add symmetry fields // be todo add symmetry fields
// be todo variations.... // be todo variations....
export function buildTxt2ImgNode(state: RootState): TextToImageInvocation { export function buildTxt2ImgNode(
state: RootState
): Omit<TextToImageInvocation, 'id'> {
const { generation, system } = state; const { generation, system } = state;
const { shouldDisplayInProgressType, openModel } = system; const { shouldDisplayInProgressType, model } = system;
const { const {
prompt, prompt,
@ -30,7 +31,6 @@ export function buildTxt2ImgNode(state: RootState): TextToImageInvocation {
// missing fields in TextToImageInvocation: strength, hires_fix // missing fields in TextToImageInvocation: strength, hires_fix
return { return {
id: uuidv4(),
type: 'txt2img', type: 'txt2img',
prompt, prompt,
seed, seed,
@ -40,12 +40,14 @@ export function buildTxt2ImgNode(state: RootState): TextToImageInvocation {
cfg_scale, cfg_scale,
sampler_name: sampler as TextToImageInvocation['sampler_name'], sampler_name: sampler as TextToImageInvocation['sampler_name'],
seamless, seamless,
model: openModel as string | undefined, model,
progress_images: shouldDisplayInProgressType === 'full-res', progress_images: shouldDisplayInProgressType === 'full-res',
}; };
} }
export function buildImg2ImgNode(state: RootState): ImageToImageInvocation { export function buildImg2ImgNode(
state: RootState
): Omit<ImageToImageInvocation, 'id'> {
const { generation, system } = state; const { generation, system } = state;
const { shouldDisplayInProgressType, openModel: model } = system; const { shouldDisplayInProgressType, openModel: model } = system;
@ -65,7 +67,6 @@ export function buildImg2ImgNode(state: RootState): ImageToImageInvocation {
} = generation; } = generation;
return { return {
id: 'a',
type: 'img2img', type: 'img2img',
prompt, prompt,
seed, seed,
@ -86,7 +87,9 @@ export function buildImg2ImgNode(state: RootState): ImageToImageInvocation {
}; };
} }
export function buildFacetoolNode(state: RootState): RestoreFaceInvocation { export function buildFacetoolNode(
state: RootState
): Omit<RestoreFaceInvocation, 'id'> {
const { generation, postprocessing } = state; const { generation, postprocessing } = state;
const { initialImage } = generation; const { initialImage } = generation;
@ -95,7 +98,6 @@ export function buildFacetoolNode(state: RootState): RestoreFaceInvocation {
// missing fields in RestoreFaceInvocation: type, codeformer_fidelity // missing fields in RestoreFaceInvocation: type, codeformer_fidelity
return { return {
id: uuidv4(),
type: 'restore_face', type: 'restore_face',
image: { image: {
image_name: image_name:
@ -106,7 +108,9 @@ export function buildFacetoolNode(state: RootState): RestoreFaceInvocation {
} }
// is this ESRGAN?? // is this ESRGAN??
export function buildUpscaleNode(state: RootState): UpscaleInvocation { export function buildUpscaleNode(
state: RootState
): Omit<UpscaleInvocation, 'id'> {
const { generation, postprocessing } = state; const { generation, postprocessing } = state;
const { initialImage } = generation; const { initialImage } = generation;
@ -115,7 +119,6 @@ export function buildUpscaleNode(state: RootState): UpscaleInvocation {
// missing fields in UpscaleInvocation: denoise_str // missing fields in UpscaleInvocation: denoise_str
return { return {
id: uuidv4(),
type: 'upscale', type: 'upscale',
image: { image: {
image_name: image_name:

View File

@ -11,6 +11,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa'; import { FaPlay } from 'react-icons/fa';
import { createSession } from 'services/thunks/session';
interface InvokeButton interface InvokeButton
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> { extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
@ -24,7 +25,8 @@ export default function InvokeButton(props: InvokeButton) {
const activeTabName = useAppSelector(activeTabNameSelector); const activeTabName = useAppSelector(activeTabNameSelector);
const handleClickGenerate = () => { const handleClickGenerate = () => {
dispatch(generateImage(activeTabName)); // dispatch(generateImage(activeTabName));
dispatch(createSession());
}; };
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -0,0 +1,17 @@
import { Middleware } from '@reduxjs/toolkit';
import { setSessionId } from './apiSlice';
import { invokeSession } from './thunks/session';
export const invokeMiddleware: Middleware = (store) => (next) => (action) => {
const { dispatch } = store;
if (action.type === 'api/createSession/fulfilled' && action?.payload?.id) {
console.log('createSession.fulfilled');
dispatch(setSessionId(action.payload.id));
// types are wrong but this works
dispatch(invokeSession({ sessionId: action.payload.id }));
} else {
next(action);
}
};

View File

@ -1,5 +1,6 @@
import { createAppAsyncThunk } from 'app/storeUtils'; import { createAppAsyncThunk } from 'app/storeUtils';
import { SessionsService } from 'services/api'; import { SessionsService } from 'services/api';
import { buildGraph } from 'common/util/buildGraph';
/** /**
* createSession thunk * createSession thunk
@ -18,7 +19,16 @@ type CreateSessionRequestBody = Parameters<
export const createSession = createAppAsyncThunk( export const createSession = createAppAsyncThunk(
'api/createSession', 'api/createSession',
async (arg: CreateSessionRequestBody, _thunkApi) => { async (arg: CreateSessionRequestBody, _thunkApi) => {
const response = await SessionsService.createSession({ requestBody: arg }); let graph = arg;
if (!arg) {
const { getState } = _thunkApi;
const state = getState();
graph = buildGraph(state);
}
const response = await SessionsService.createSession({
requestBody: graph,
});
return response; return response;
} }