mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): start hooking up dynamic txt2img node generation, create middleware for session invocation
This commit is contained in:
parent
3ebd289a59
commit
4fe49718e0
@ -114,6 +114,8 @@ const NodeAPITest = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
setResultImages([]);
|
||||
|
||||
// set up socket.io listeners
|
||||
|
||||
// TODO: suppose this should be handled in the socket.io middleware?
|
||||
@ -190,10 +192,16 @@ const NodeAPITest = () => {
|
||||
<IAIButton onClick={handleCancelProcessing} colorScheme="error">
|
||||
Cancel Processing
|
||||
</IAIButton>
|
||||
<IAIButton onClick={handleCreateSession} colorScheme="accent">
|
||||
Create Session
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
onClick={handleCreateSession}
|
||||
colorScheme="accent"
|
||||
loadingText={`Invoking ${
|
||||
progress === null ? '...' : `${Math.round(progress * 100)}%`
|
||||
}`}
|
||||
>
|
||||
Create Session & Invoke
|
||||
</IAIButton>
|
||||
{/* <IAIButton
|
||||
onClick={handleInvokeSession}
|
||||
loadingText={`Invoking ${
|
||||
progress === null ? '...' : `${Math.round(progress * 100)}%`
|
||||
@ -201,7 +209,7 @@ const NodeAPITest = () => {
|
||||
colorScheme="accent"
|
||||
>
|
||||
Invoke
|
||||
</IAIButton>
|
||||
</IAIButton> */}
|
||||
<Flex wrap="wrap" gap={4} overflow="scroll">
|
||||
<Image
|
||||
src={progressImage?.dataURL}
|
||||
|
@ -15,6 +15,7 @@ import uiReducer from 'features/ui/store/uiSlice';
|
||||
import apiReducer from 'services/apiSlice';
|
||||
|
||||
import { socketioMiddleware } from './socketio/middleware';
|
||||
import { invokeMiddleware } from 'services/invokeMiddleware';
|
||||
|
||||
/**
|
||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||
@ -103,7 +104,7 @@ export const store = configureStore({
|
||||
getDefaultMiddleware({
|
||||
immutableCheck: false,
|
||||
serializableCheck: false,
|
||||
}).concat(socketioMiddleware()),
|
||||
}).concat(socketioMiddleware(), invokeMiddleware),
|
||||
devTools: {
|
||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||
actionsDenylist: [
|
||||
|
31
invokeai/frontend/web/src/common/util/buildGraph.ts
Normal file
31
invokeai/frontend/web/src/common/util/buildGraph.ts
Normal file
@ -0,0 +1,31 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||
import { Graph } from 'services/api';
|
||||
import { buildTxt2ImgNode } from './buildNodes';
|
||||
|
||||
function mapTabToFunction(activeTabName: InvokeTabName) {
|
||||
switch (activeTabName) {
|
||||
case 'txt2img':
|
||||
return buildTxt2ImgNode;
|
||||
|
||||
default:
|
||||
return buildTxt2ImgNode;
|
||||
}
|
||||
}
|
||||
|
||||
export const buildGraph = (state: RootState): Graph => {
|
||||
const { activeTab } = state.ui;
|
||||
const activeTabName = tabMap[activeTab];
|
||||
const nodeId = uuidv4();
|
||||
|
||||
return {
|
||||
nodes: {
|
||||
[nodeId]: {
|
||||
id: nodeId,
|
||||
...mapTabToFunction(activeTabName)(state),
|
||||
},
|
||||
},
|
||||
};
|
||||
};
|
@ -1,4 +1,3 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
ImageToImageInvocation,
|
||||
@ -12,10 +11,12 @@ import {
|
||||
// be todo add symmetry fields
|
||||
// be todo variations....
|
||||
|
||||
export function buildTxt2ImgNode(state: RootState): TextToImageInvocation {
|
||||
export function buildTxt2ImgNode(
|
||||
state: RootState
|
||||
): Omit<TextToImageInvocation, 'id'> {
|
||||
const { generation, system } = state;
|
||||
|
||||
const { shouldDisplayInProgressType, openModel } = system;
|
||||
const { shouldDisplayInProgressType, model } = system;
|
||||
|
||||
const {
|
||||
prompt,
|
||||
@ -30,7 +31,6 @@ export function buildTxt2ImgNode(state: RootState): TextToImageInvocation {
|
||||
|
||||
// missing fields in TextToImageInvocation: strength, hires_fix
|
||||
return {
|
||||
id: uuidv4(),
|
||||
type: 'txt2img',
|
||||
prompt,
|
||||
seed,
|
||||
@ -40,12 +40,14 @@ export function buildTxt2ImgNode(state: RootState): TextToImageInvocation {
|
||||
cfg_scale,
|
||||
sampler_name: sampler as TextToImageInvocation['sampler_name'],
|
||||
seamless,
|
||||
model: openModel as string | undefined,
|
||||
model,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
};
|
||||
}
|
||||
|
||||
export function buildImg2ImgNode(state: RootState): ImageToImageInvocation {
|
||||
export function buildImg2ImgNode(
|
||||
state: RootState
|
||||
): Omit<ImageToImageInvocation, 'id'> {
|
||||
const { generation, system } = state;
|
||||
|
||||
const { shouldDisplayInProgressType, openModel: model } = system;
|
||||
@ -65,7 +67,6 @@ export function buildImg2ImgNode(state: RootState): ImageToImageInvocation {
|
||||
} = generation;
|
||||
|
||||
return {
|
||||
id: 'a',
|
||||
type: 'img2img',
|
||||
prompt,
|
||||
seed,
|
||||
@ -86,7 +87,9 @@ export function buildImg2ImgNode(state: RootState): ImageToImageInvocation {
|
||||
};
|
||||
}
|
||||
|
||||
export function buildFacetoolNode(state: RootState): RestoreFaceInvocation {
|
||||
export function buildFacetoolNode(
|
||||
state: RootState
|
||||
): Omit<RestoreFaceInvocation, 'id'> {
|
||||
const { generation, postprocessing } = state;
|
||||
|
||||
const { initialImage } = generation;
|
||||
@ -95,7 +98,6 @@ export function buildFacetoolNode(state: RootState): RestoreFaceInvocation {
|
||||
|
||||
// missing fields in RestoreFaceInvocation: type, codeformer_fidelity
|
||||
return {
|
||||
id: uuidv4(),
|
||||
type: 'restore_face',
|
||||
image: {
|
||||
image_name:
|
||||
@ -106,7 +108,9 @@ export function buildFacetoolNode(state: RootState): RestoreFaceInvocation {
|
||||
}
|
||||
|
||||
// is this ESRGAN??
|
||||
export function buildUpscaleNode(state: RootState): UpscaleInvocation {
|
||||
export function buildUpscaleNode(
|
||||
state: RootState
|
||||
): Omit<UpscaleInvocation, 'id'> {
|
||||
const { generation, postprocessing } = state;
|
||||
|
||||
const { initialImage } = generation;
|
||||
@ -115,7 +119,6 @@ export function buildUpscaleNode(state: RootState): UpscaleInvocation {
|
||||
|
||||
// missing fields in UpscaleInvocation: denoise_str
|
||||
return {
|
||||
id: uuidv4(),
|
||||
type: 'upscale',
|
||||
image: {
|
||||
image_name:
|
||||
|
@ -11,6 +11,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaPlay } from 'react-icons/fa';
|
||||
import { createSession } from 'services/thunks/session';
|
||||
|
||||
interface InvokeButton
|
||||
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
|
||||
@ -24,7 +25,8 @@ export default function InvokeButton(props: InvokeButton) {
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
|
||||
const handleClickGenerate = () => {
|
||||
dispatch(generateImage(activeTabName));
|
||||
// dispatch(generateImage(activeTabName));
|
||||
dispatch(createSession());
|
||||
};
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
17
invokeai/frontend/web/src/services/invokeMiddleware.ts
Normal file
17
invokeai/frontend/web/src/services/invokeMiddleware.ts
Normal file
@ -0,0 +1,17 @@
|
||||
import { Middleware } from '@reduxjs/toolkit';
|
||||
import { setSessionId } from './apiSlice';
|
||||
import { invokeSession } from './thunks/session';
|
||||
|
||||
export const invokeMiddleware: Middleware = (store) => (next) => (action) => {
|
||||
const { dispatch } = store;
|
||||
|
||||
if (action.type === 'api/createSession/fulfilled' && action?.payload?.id) {
|
||||
console.log('createSession.fulfilled');
|
||||
|
||||
dispatch(setSessionId(action.payload.id));
|
||||
// types are wrong but this works
|
||||
dispatch(invokeSession({ sessionId: action.payload.id }));
|
||||
} else {
|
||||
next(action);
|
||||
}
|
||||
};
|
@ -1,5 +1,6 @@
|
||||
import { createAppAsyncThunk } from 'app/storeUtils';
|
||||
import { SessionsService } from 'services/api';
|
||||
import { buildGraph } from 'common/util/buildGraph';
|
||||
|
||||
/**
|
||||
* createSession thunk
|
||||
@ -18,7 +19,16 @@ type CreateSessionRequestBody = Parameters<
|
||||
export const createSession = createAppAsyncThunk(
|
||||
'api/createSession',
|
||||
async (arg: CreateSessionRequestBody, _thunkApi) => {
|
||||
const response = await SessionsService.createSession({ requestBody: arg });
|
||||
let graph = arg;
|
||||
if (!arg) {
|
||||
const { getState } = _thunkApi;
|
||||
const state = getState();
|
||||
graph = buildGraph(state);
|
||||
}
|
||||
|
||||
const response = await SessionsService.createSession({
|
||||
requestBody: graph,
|
||||
});
|
||||
|
||||
return response;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user