mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): start hooking up dynamic txt2img node generation, create middleware for session invocation
This commit is contained in:
parent
3ebd289a59
commit
4fe49718e0
@ -114,6 +114,8 @@ const NodeAPITest = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setResultImages([]);
|
||||||
|
|
||||||
// set up socket.io listeners
|
// set up socket.io listeners
|
||||||
|
|
||||||
// TODO: suppose this should be handled in the socket.io middleware?
|
// TODO: suppose this should be handled in the socket.io middleware?
|
||||||
@ -190,10 +192,16 @@ const NodeAPITest = () => {
|
|||||||
<IAIButton onClick={handleCancelProcessing} colorScheme="error">
|
<IAIButton onClick={handleCancelProcessing} colorScheme="error">
|
||||||
Cancel Processing
|
Cancel Processing
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
<IAIButton onClick={handleCreateSession} colorScheme="accent">
|
|
||||||
Create Session
|
|
||||||
</IAIButton>
|
|
||||||
<IAIButton
|
<IAIButton
|
||||||
|
onClick={handleCreateSession}
|
||||||
|
colorScheme="accent"
|
||||||
|
loadingText={`Invoking ${
|
||||||
|
progress === null ? '...' : `${Math.round(progress * 100)}%`
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
Create Session & Invoke
|
||||||
|
</IAIButton>
|
||||||
|
{/* <IAIButton
|
||||||
onClick={handleInvokeSession}
|
onClick={handleInvokeSession}
|
||||||
loadingText={`Invoking ${
|
loadingText={`Invoking ${
|
||||||
progress === null ? '...' : `${Math.round(progress * 100)}%`
|
progress === null ? '...' : `${Math.round(progress * 100)}%`
|
||||||
@ -201,7 +209,7 @@ const NodeAPITest = () => {
|
|||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
>
|
>
|
||||||
Invoke
|
Invoke
|
||||||
</IAIButton>
|
</IAIButton> */}
|
||||||
<Flex wrap="wrap" gap={4} overflow="scroll">
|
<Flex wrap="wrap" gap={4} overflow="scroll">
|
||||||
<Image
|
<Image
|
||||||
src={progressImage?.dataURL}
|
src={progressImage?.dataURL}
|
||||||
|
@ -15,6 +15,7 @@ import uiReducer from 'features/ui/store/uiSlice';
|
|||||||
import apiReducer from 'services/apiSlice';
|
import apiReducer from 'services/apiSlice';
|
||||||
|
|
||||||
import { socketioMiddleware } from './socketio/middleware';
|
import { socketioMiddleware } from './socketio/middleware';
|
||||||
|
import { invokeMiddleware } from 'services/invokeMiddleware';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||||
@ -103,7 +104,7 @@ export const store = configureStore({
|
|||||||
getDefaultMiddleware({
|
getDefaultMiddleware({
|
||||||
immutableCheck: false,
|
immutableCheck: false,
|
||||||
serializableCheck: false,
|
serializableCheck: false,
|
||||||
}).concat(socketioMiddleware()),
|
}).concat(socketioMiddleware(), invokeMiddleware),
|
||||||
devTools: {
|
devTools: {
|
||||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||||
actionsDenylist: [
|
actionsDenylist: [
|
||||||
|
31
invokeai/frontend/web/src/common/util/buildGraph.ts
Normal file
31
invokeai/frontend/web/src/common/util/buildGraph.ts
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||||
|
import { Graph } from 'services/api';
|
||||||
|
import { buildTxt2ImgNode } from './buildNodes';
|
||||||
|
|
||||||
|
function mapTabToFunction(activeTabName: InvokeTabName) {
|
||||||
|
switch (activeTabName) {
|
||||||
|
case 'txt2img':
|
||||||
|
return buildTxt2ImgNode;
|
||||||
|
|
||||||
|
default:
|
||||||
|
return buildTxt2ImgNode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const buildGraph = (state: RootState): Graph => {
|
||||||
|
const { activeTab } = state.ui;
|
||||||
|
const activeTabName = tabMap[activeTab];
|
||||||
|
const nodeId = uuidv4();
|
||||||
|
|
||||||
|
return {
|
||||||
|
nodes: {
|
||||||
|
[nodeId]: {
|
||||||
|
id: nodeId,
|
||||||
|
...mapTabToFunction(activeTabName)(state),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
};
|
@ -1,4 +1,3 @@
|
|||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
import { RootState } from 'app/store';
|
import { RootState } from 'app/store';
|
||||||
import {
|
import {
|
||||||
ImageToImageInvocation,
|
ImageToImageInvocation,
|
||||||
@ -12,10 +11,12 @@ import {
|
|||||||
// be todo add symmetry fields
|
// be todo add symmetry fields
|
||||||
// be todo variations....
|
// be todo variations....
|
||||||
|
|
||||||
export function buildTxt2ImgNode(state: RootState): TextToImageInvocation {
|
export function buildTxt2ImgNode(
|
||||||
|
state: RootState
|
||||||
|
): Omit<TextToImageInvocation, 'id'> {
|
||||||
const { generation, system } = state;
|
const { generation, system } = state;
|
||||||
|
|
||||||
const { shouldDisplayInProgressType, openModel } = system;
|
const { shouldDisplayInProgressType, model } = system;
|
||||||
|
|
||||||
const {
|
const {
|
||||||
prompt,
|
prompt,
|
||||||
@ -30,7 +31,6 @@ export function buildTxt2ImgNode(state: RootState): TextToImageInvocation {
|
|||||||
|
|
||||||
// missing fields in TextToImageInvocation: strength, hires_fix
|
// missing fields in TextToImageInvocation: strength, hires_fix
|
||||||
return {
|
return {
|
||||||
id: uuidv4(),
|
|
||||||
type: 'txt2img',
|
type: 'txt2img',
|
||||||
prompt,
|
prompt,
|
||||||
seed,
|
seed,
|
||||||
@ -40,12 +40,14 @@ export function buildTxt2ImgNode(state: RootState): TextToImageInvocation {
|
|||||||
cfg_scale,
|
cfg_scale,
|
||||||
sampler_name: sampler as TextToImageInvocation['sampler_name'],
|
sampler_name: sampler as TextToImageInvocation['sampler_name'],
|
||||||
seamless,
|
seamless,
|
||||||
model: openModel as string | undefined,
|
model,
|
||||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export function buildImg2ImgNode(state: RootState): ImageToImageInvocation {
|
export function buildImg2ImgNode(
|
||||||
|
state: RootState
|
||||||
|
): Omit<ImageToImageInvocation, 'id'> {
|
||||||
const { generation, system } = state;
|
const { generation, system } = state;
|
||||||
|
|
||||||
const { shouldDisplayInProgressType, openModel: model } = system;
|
const { shouldDisplayInProgressType, openModel: model } = system;
|
||||||
@ -65,7 +67,6 @@ export function buildImg2ImgNode(state: RootState): ImageToImageInvocation {
|
|||||||
} = generation;
|
} = generation;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
id: 'a',
|
|
||||||
type: 'img2img',
|
type: 'img2img',
|
||||||
prompt,
|
prompt,
|
||||||
seed,
|
seed,
|
||||||
@ -86,7 +87,9 @@ export function buildImg2ImgNode(state: RootState): ImageToImageInvocation {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export function buildFacetoolNode(state: RootState): RestoreFaceInvocation {
|
export function buildFacetoolNode(
|
||||||
|
state: RootState
|
||||||
|
): Omit<RestoreFaceInvocation, 'id'> {
|
||||||
const { generation, postprocessing } = state;
|
const { generation, postprocessing } = state;
|
||||||
|
|
||||||
const { initialImage } = generation;
|
const { initialImage } = generation;
|
||||||
@ -95,7 +98,6 @@ export function buildFacetoolNode(state: RootState): RestoreFaceInvocation {
|
|||||||
|
|
||||||
// missing fields in RestoreFaceInvocation: type, codeformer_fidelity
|
// missing fields in RestoreFaceInvocation: type, codeformer_fidelity
|
||||||
return {
|
return {
|
||||||
id: uuidv4(),
|
|
||||||
type: 'restore_face',
|
type: 'restore_face',
|
||||||
image: {
|
image: {
|
||||||
image_name:
|
image_name:
|
||||||
@ -106,7 +108,9 @@ export function buildFacetoolNode(state: RootState): RestoreFaceInvocation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// is this ESRGAN??
|
// is this ESRGAN??
|
||||||
export function buildUpscaleNode(state: RootState): UpscaleInvocation {
|
export function buildUpscaleNode(
|
||||||
|
state: RootState
|
||||||
|
): Omit<UpscaleInvocation, 'id'> {
|
||||||
const { generation, postprocessing } = state;
|
const { generation, postprocessing } = state;
|
||||||
|
|
||||||
const { initialImage } = generation;
|
const { initialImage } = generation;
|
||||||
@ -115,7 +119,6 @@ export function buildUpscaleNode(state: RootState): UpscaleInvocation {
|
|||||||
|
|
||||||
// missing fields in UpscaleInvocation: denoise_str
|
// missing fields in UpscaleInvocation: denoise_str
|
||||||
return {
|
return {
|
||||||
id: uuidv4(),
|
|
||||||
type: 'upscale',
|
type: 'upscale',
|
||||||
image: {
|
image: {
|
||||||
image_name:
|
image_name:
|
||||||
|
@ -11,6 +11,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
|||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaPlay } from 'react-icons/fa';
|
import { FaPlay } from 'react-icons/fa';
|
||||||
|
import { createSession } from 'services/thunks/session';
|
||||||
|
|
||||||
interface InvokeButton
|
interface InvokeButton
|
||||||
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
|
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
|
||||||
@ -24,7 +25,8 @@ export default function InvokeButton(props: InvokeButton) {
|
|||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
|
|
||||||
const handleClickGenerate = () => {
|
const handleClickGenerate = () => {
|
||||||
dispatch(generateImage(activeTabName));
|
// dispatch(generateImage(activeTabName));
|
||||||
|
dispatch(createSession());
|
||||||
};
|
};
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
17
invokeai/frontend/web/src/services/invokeMiddleware.ts
Normal file
17
invokeai/frontend/web/src/services/invokeMiddleware.ts
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import { Middleware } from '@reduxjs/toolkit';
|
||||||
|
import { setSessionId } from './apiSlice';
|
||||||
|
import { invokeSession } from './thunks/session';
|
||||||
|
|
||||||
|
export const invokeMiddleware: Middleware = (store) => (next) => (action) => {
|
||||||
|
const { dispatch } = store;
|
||||||
|
|
||||||
|
if (action.type === 'api/createSession/fulfilled' && action?.payload?.id) {
|
||||||
|
console.log('createSession.fulfilled');
|
||||||
|
|
||||||
|
dispatch(setSessionId(action.payload.id));
|
||||||
|
// types are wrong but this works
|
||||||
|
dispatch(invokeSession({ sessionId: action.payload.id }));
|
||||||
|
} else {
|
||||||
|
next(action);
|
||||||
|
}
|
||||||
|
};
|
@ -1,5 +1,6 @@
|
|||||||
import { createAppAsyncThunk } from 'app/storeUtils';
|
import { createAppAsyncThunk } from 'app/storeUtils';
|
||||||
import { SessionsService } from 'services/api';
|
import { SessionsService } from 'services/api';
|
||||||
|
import { buildGraph } from 'common/util/buildGraph';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* createSession thunk
|
* createSession thunk
|
||||||
@ -18,7 +19,16 @@ type CreateSessionRequestBody = Parameters<
|
|||||||
export const createSession = createAppAsyncThunk(
|
export const createSession = createAppAsyncThunk(
|
||||||
'api/createSession',
|
'api/createSession',
|
||||||
async (arg: CreateSessionRequestBody, _thunkApi) => {
|
async (arg: CreateSessionRequestBody, _thunkApi) => {
|
||||||
const response = await SessionsService.createSession({ requestBody: arg });
|
let graph = arg;
|
||||||
|
if (!arg) {
|
||||||
|
const { getState } = _thunkApi;
|
||||||
|
const state = getState();
|
||||||
|
graph = buildGraph(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await SessionsService.createSession({
|
||||||
|
requestBody: graph,
|
||||||
|
});
|
||||||
|
|
||||||
return response;
|
return response;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user