feat(ui): wip canvas

This commit is contained in:
psychedelicious 2023-05-05 00:06:50 +10:00
parent 206e6b1730
commit 1c9429a6ea
33 changed files with 712 additions and 222 deletions

View File

@ -12,15 +12,11 @@ import { addImageResultReceivedListener } from './listeners/invocationComplete';
import { addImageUploadedListener } from './listeners/imageUploaded';
import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
import {
canvasGraphBuilt,
sessionCreated,
sessionInvoked,
} from 'services/thunks/session';
import { tabMap } from 'features/ui/store/tabMap';
import {
canvasSessionIdChanged,
stagingAreaInitialized,
} from 'features/canvas/store/canvasSlice';
addUserInvokedCanvasListener,
addUserInvokedCreateListener,
addUserInvokedNodesListener,
} from './listeners/userInvoked';
import { addCanvasGraphBuiltListener } from './listeners/canvasGraphBuilt';
export const listenerMiddleware = createListenerMiddleware();
@ -44,26 +40,7 @@ addImageUploadedListener();
addInitialImageSelectedListener();
addImageResultReceivedListener();
addRequestedImageDeletionListener();
startAppListening({
actionCreator: canvasGraphBuilt.fulfilled,
effect: async (action, { dispatch, getState, condition, fork, take }) => {
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
const { sessionId } = meta.arg;
const state = getState();
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
sessionId,
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
dispatch(canvasSessionIdChanged(sessionId));
},
});
addUserInvokedCanvasListener();
addUserInvokedCreateListener();
addUserInvokedNodesListener();
// addCanvasGraphBuiltListener();

View File

@ -0,0 +1,31 @@
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { startAppListening } from '..';
import {
canvasSessionIdChanged,
stagingAreaInitialized,
} from 'features/canvas/store/canvasSlice';
import { sessionInvoked } from 'services/thunks/session';
export const addCanvasGraphBuiltListener = () =>
startAppListening({
actionCreator: canvasGraphBuilt,
effect: async (action, { dispatch, getState, take }) => {
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
const { sessionId } = meta.arg;
const state = getState();
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
sessionId,
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
dispatch(canvasSessionIdChanged(sessionId));
},
});

View File

@ -0,0 +1,167 @@
import { createAction } from '@reduxjs/toolkit';
import { startAppListening } from '..';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { buildLinearGraph } from 'features/nodes/util/buildLinearGraph';
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
import { buildCanvasGraphAndBlobs } from 'features/nodes/util/buildCanvasGraph';
import { buildNodesGraph } from 'features/nodes/util/buildNodesGraph';
import { log } from 'app/logging/useLogger';
import {
canvasGraphBuilt,
createGraphBuilt,
nodesGraphBuilt,
} from 'features/nodes/store/actions';
import { imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api';
import {
canvasSessionIdChanged,
stagingAreaInitialized,
} from 'features/canvas/store/canvasSlice';
const moduleLog = log.child({ namespace: 'invoke' });
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');
export const addUserInvokedCreateListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'generate',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildLinearGraph(state);
dispatch(createGraphBuilt(graph));
moduleLog({ data: graph }, 'Create graph built');
dispatch(sessionCreated({ graph }));
},
});
};
export const addUserInvokedCanvasListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'unifiedCanvas',
effect: async (action, { getState, dispatch, take }) => {
const state = getState();
const data = await buildCanvasGraphAndBlobs(state);
if (!data) {
moduleLog.error('Problem building graph');
return;
}
const {
rangeNode,
iterateNode,
baseNode,
edges,
baseBlob,
maskBlob,
generationMode,
} = data;
const baseFilename = `${uuidv4()}.png`;
const maskFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
})
);
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const [{ payload: basePayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
const { image_name: baseName, image_type: baseType } =
basePayload.response;
baseNode.image = {
image_name: baseName,
image_type: baseType,
};
}
if (baseNode.type === 'inpaint') {
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
})
);
const [{ payload: maskPayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
const { image_name: maskName, image_type: maskType } =
maskPayload.response;
baseNode.mask = {
image_name: maskName,
image_type: maskType,
};
}
// Assemble!
const nodes: Graph['nodes'] = {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
};
const graph = { nodes, edges };
dispatch(canvasGraphBuilt(graph));
moduleLog({ data: graph }, 'Canvas graph built');
dispatch(sessionCreated({ graph }));
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
const { sessionId } = meta.arg;
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
sessionId,
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
dispatch(canvasSessionIdChanged(sessionId));
},
});
};
export const addUserInvokedNodesListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'nodes',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildNodesGraph(state);
dispatch(nodesGraphBuilt(graph));
moduleLog({ data: graph }, 'Nodes graph built');
dispatch(sessionCreated({ graph }));
},
});
};

View File

@ -1,8 +1,10 @@
import {
Action,
AnyAction,
ThunkDispatch,
combineReducers,
configureStore,
isAnyOf,
} from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist';
@ -33,9 +35,10 @@ import { nodesDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { postprocessingDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { systemDenylist } from 'features/system/store/systemPersistDenylist';
import { uiDenylist } from 'features/ui/store/uiPersistDenylist';
import { resultsDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { forEach } from 'lodash-es';
import { Graph } from 'services/api';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@ -101,6 +104,27 @@ const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
// }
// }
// const actionSanitizer = (action: AnyAction): AnyAction => {
// if (isAnyGraphBuilt(action)) {
// if (action.payload.nodes) {
// const sanitizedNodes: Graph['nodes'] = {};
// forEach(action.payload.nodes, (node, key) => {
// if (node.type === 'dataURL_image') {
// const { dataURL, ...rest } = node;
// sanitizedNodes[key] = { ...rest, dataURL: '<<dataURL>>' };
// }
// });
// const sanitizedAction: AnyAction = {
// ...action,
// payload: { ...action.payload, nodes: sanitizedNodes },
// };
// return sanitizedAction;
// }
// }
// return action;
// };
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
@ -123,6 +147,31 @@ export const store = configureStore({
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
],
actionSanitizer: (action) => {
if (isAnyGraphBuilt(action)) {
if (action.payload.nodes) {
const sanitizedNodes: Graph['nodes'] = {};
forEach(action.payload.nodes, (node, key) => {
if (node.type === 'dataURL_image') {
const { dataURL, ...rest } = node;
sanitizedNodes[key] = { ...rest, dataURL: '<<dataURL>>' };
} else {
sanitizedNodes[key] = { ...node };
}
});
return {
...action,
payload: { ...action.payload, nodes: sanitizedNodes },
};
}
}
return action;
},
// stateSanitizer: (state) =>
// state.data ? { ...state, data: '<<LONG_BLOB>>' } : state,
},
});

View File

@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
const fileAcceptedCallback = useCallback(
async (file: File) => {
dispatch(imageUploaded({ formData: { file } }));
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
},
[dispatch]
);
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
return;
}
dispatch(imageUploaded({ formData: { file } }));
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
};
document.addEventListener('paste', pasteImageListener);
return () => {

View File

@ -1,5 +1,4 @@
export const getImageDataTransparency = (pixels: Uint8ClampedArray) => {
console.log(pixels);
let isFullyTransparent = true;
let isPartiallyTransparent = false;
const len = pixels.length;

View File

@ -299,8 +299,6 @@ export const frontendToBackendParameters = (
const doesBaseHaveTransparency = getIsImageDataTransparent(imageData);
const doesMaskHaveTransparency = getIsImageDataWhite(maskImageData);
console.log(doesBaseHaveTransparency, doesMaskHaveTransparency);
if (enableImageDebugging) {
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask sent as init_mask' },

View File

@ -1,39 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import {
FrontendToBackendParametersConfig,
frontendToBackendParameters,
} from 'common/util/parameterTranslation';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { canvasSelector } from '../store/canvasSelectors';
import { useCallback, useMemo } from 'react';
const selector = createSelector(
[generationSelector, postprocessingSelector, systemSelector, canvasSelector],
(generation, postprocessing, system, canvas) => {
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
{
generationMode: 'unifiedCanvas',
generationState: generation,
postprocessingState: postprocessing,
canvasState: canvas,
systemState: system,
};
return frontendToBackendParametersConfig;
}
);
export const usePrepareCanvasState = () => {
const frontendToBackendParametersConfig = useAppSelector(selector);
const getGenerationParameters = useCallback(() => {
const { generationParameters, esrganParameters, facetoolParameters } =
frontendToBackendParameters(frontendToBackendParametersConfig);
console.log(generationParameters);
}, [frontendToBackendParametersConfig]);
return getGenerationParameters;
};

View File

@ -0,0 +1,13 @@
/**
* Gets a Blob from a canvas.
*/
export const canvasToBlob = async (canvas: HTMLCanvasElement): Promise<Blob> =>
new Promise((resolve, reject) => {
canvas.toBlob((blob) => {
if (blob) {
resolve(blob);
return;
}
reject('Unable to create Blob');
});
});

View File

@ -1,3 +1,6 @@
/**
* Gets an ImageData object from an image dataURL by drawing it to a canvas.
*/
export const dataURLToImageData = async (
dataURL: string,
width: number,

View File

@ -104,6 +104,7 @@
import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
import Konva from 'konva';
import { IRect } from 'konva/lib/types';
import { canvasToBlob } from './canvasToBlob';
/**
* Generating a mask image from InpaintingCanvas.tsx is not as simple
@ -115,7 +116,7 @@ import { IRect } from 'konva/lib/types';
* drawing the mask and compositing everything correctly to output a valid
* mask image.
*/
const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
const generateMask = async (lines: CanvasMaskLine[], boundingBox: IRect) => {
// create an offscreen canvas and add the mask to it
const { width, height } = boundingBox;
@ -157,11 +158,13 @@ const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
stage.add(baseLayer);
stage.add(maskLayer);
const dataURL = stage.toDataURL({ ...boundingBox });
const maskDataURL = stage.toDataURL(boundingBox);
const maskBlob = await canvasToBlob(stage.toCanvas(boundingBox));
offscreenContainer.remove();
return dataURL;
return { maskDataURL, maskBlob };
};
export default generateMask;

View File

@ -8,7 +8,8 @@ import {
} from 'common/util/arrayBuffer';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import generateMask from './generateMask';
import { dataURLToImageData } from './dataURLToUint8ClampedArray';
import { dataURLToImageData } from './dataURLToImageData';
import { canvasToBlob } from './canvasToBlob';
const moduleLog = log.child({ namespace: 'getCanvasDataURLs' });
@ -62,10 +63,13 @@ export const getCanvasData = async (state: RootState) => {
};
const baseDataURL = canvasBaseLayer.toDataURL(offsetBoundingBox);
const baseBlob = await canvasToBlob(
canvasBaseLayer.toCanvas(offsetBoundingBox)
);
canvasBaseLayer.scale(tempScale);
const maskDataURL = generateMask(
const { maskDataURL, maskBlob } = await generateMask(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
boundingBox
);
@ -82,9 +86,6 @@ export const getCanvasData = async (state: RootState) => {
boundingBox.height
);
console.log('baseImageData', baseImageData);
console.log('maskImageData', maskImageData);
const {
isPartiallyTransparent: baseIsPartiallyTransparent,
isFullyTransparent: baseIsFullyTransparent,
@ -117,7 +118,9 @@ export const getCanvasData = async (state: RootState) => {
return {
baseDataURL,
baseBlob,
maskDataURL,
maskBlob,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,

View File

@ -1,16 +1,16 @@
import { HStack } from '@chakra-ui/react';
import { userInvoked } from 'app/store/middleware/listenerMiddleware/listeners/userInvoked';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { memo, useCallback } from 'react';
import { Panel } from 'reactflow';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { nodesGraphBuilt } from 'services/thunks/session';
const TopCenterPanel = () => {
const dispatch = useAppDispatch();
const handleInvoke = useCallback(() => {
dispatch(nodesGraphBuilt());
dispatch(userInvoked('nodes'));
}, [dispatch]);
const handleReloadSchema = useCallback(() => {

View File

@ -0,0 +1,12 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import { Graph } from 'services/api';
export const createGraphBuilt = createAction<Graph>('nodes/createGraphBuilt');
export const canvasGraphBuilt = createAction<Graph>('nodes/canvasGraphBuilt');
export const nodesGraphBuilt = createAction<Graph>('nodes/nodesGraphBuilt');
export const isAnyGraphBuilt = isAnyOf(
createGraphBuilt,
canvasGraphBuilt,
nodesGraphBuilt
);

View File

@ -13,11 +13,11 @@ import {
} from 'reactflow';
import { Graph, ImageField } from 'services/api';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { isFulfilledAnyGraphBuilt } from 'services/thunks/session';
import { InvocationTemplate, InvocationValue } from '../types/types';
import { parseSchema } from '../util/parseSchema';
import { log } from 'app/logging/useLogger';
import { size } from 'lodash-es';
import { isAnyGraphBuilt } from './actions';
export type NodesState = {
nodes: Node<InvocationValue>[];
@ -25,7 +25,6 @@ export type NodesState = {
schema: OpenAPIV3.Document | null;
invocationTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
lastGraph: Graph | null;
shouldShowGraphOverlay: boolean;
};
@ -35,7 +34,6 @@ export const initialNodesState: NodesState = {
schema: null,
invocationTemplates: {},
connectionStartParams: null,
lastGraph: null,
shouldShowGraphOverlay: false,
};
@ -104,8 +102,9 @@ const nodesSlice = createSlice({
state.schema = action.payload;
});
builder.addMatcher(isFulfilledAnyGraphBuilt, (state, action) => {
state.lastGraph = action.payload;
builder.addMatcher(isAnyGraphBuilt, (state, action) => {
// TODO: Achtung! Side effect in a reducer!
log.info({ namespace: 'nodes', data: action.payload }, 'Graph built');
});
},
});

View File

@ -1,5 +1,15 @@
import { RootState } from 'app/store/store';
import { DataURLToImageInvocation, Graph } from 'services/api';
import {
DataURLToImageInvocation,
Edge,
Graph,
ImageToImageInvocation,
InpaintInvocation,
IterateInvocation,
RandomRangeInvocation,
RangeInvocation,
TextToImageInvocation,
} from 'services/api';
import { buildImg2ImgNode } from './linearGraphBuilder/buildImageToImageNode';
import { buildTxt2ImgNode } from './linearGraphBuilder/buildTextToImageNode';
import { buildRangeNode } from './linearGraphBuilder/buildRangeNode';
@ -7,18 +17,54 @@ import { buildIterateNode } from './linearGraphBuilder/buildIterateNode';
import { buildEdges } from './linearGraphBuilder/buildEdges';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getNodeType } from './getNodeType';
import { getGenerationMode } from './getGenerationMode';
import { v4 as uuidv4 } from 'uuid';
import { log } from 'app/logging/useLogger';
import { buildInpaintNode } from './linearGraphBuilder/buildInpaintNode';
const moduleLog = log.child({ namespace: 'buildCanvasGraph' });
/**
* Builds the Canvas workflow graph.
*/
export const buildCanvasGraph = async (
const buildBaseNode = (
nodeType: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
state: RootState
): Promise<Graph | undefined> => {
):
| TextToImageInvocation
| ImageToImageInvocation
| InpaintInvocation
| undefined => {
if (nodeType === 'txt2img') {
return buildTxt2ImgNode(state, state.canvas.boundingBoxDimensions);
}
if (nodeType === 'img2img') {
return buildImg2ImgNode(state, state.canvas.boundingBoxDimensions);
}
if (nodeType === 'inpaint' || nodeType === 'outpaint') {
return buildInpaintNode(state, state.canvas.boundingBoxDimensions);
}
};
/**
* Builds the Canvas workflow graph and image blobs.
*/
export const buildCanvasGraphAndBlobs = async (
state: RootState
): Promise<
| {
rangeNode: RangeInvocation | RandomRangeInvocation;
iterateNode: IterateInvocation;
baseNode:
| TextToImageInvocation
| ImageToImageInvocation
| InpaintInvocation;
edges: Edge[];
baseBlob: Blob;
maskBlob: Blob;
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
}
| undefined
> => {
const c = await getCanvasData(state);
if (!c) {
@ -26,35 +72,68 @@ export const buildCanvasGraph = async (
return;
}
moduleLog.debug({ data: c }, 'Built canvas data');
const {
baseDataURL,
baseBlob,
maskDataURL,
maskBlob,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
} = c;
const nodeType = getNodeType(
moduleLog.debug(
{
data: {
// baseDataURL,
// maskDataURL,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
},
},
'Built canvas data'
);
const generationMode = getGenerationMode(
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels
);
moduleLog.debug(`Node type ${nodeType}`);
moduleLog.debug(`Generation mode: ${generationMode}`);
// The base node is either a txt2img or img2img node
const baseNode =
nodeType === 'img2img'
? buildImg2ImgNode(state, state.canvas.boundingBoxDimensions)
: buildTxt2ImgNode(state, state.canvas.boundingBoxDimensions);
// The base node is a txt2img, img2img or inpaint node
const baseNode = buildBaseNode(generationMode, state);
const dataURLNode: DataURLToImageInvocation = {
id: uuidv4(),
type: 'dataURL_image',
dataURL: baseDataURL,
};
if (!baseNode) {
moduleLog.error('Problem building base node');
return;
}
if (baseNode.type === 'inpaint') {
const {
seamSize,
seamBlur,
seamSteps,
seamStrength,
tileSize,
infillMethod,
} = state.generation;
// generationParameters.invert_mask = shouldPreserveMaskedArea;
// if (boundingBoxScale !== 'none') {
// generationParameters.inpaint_width = scaledBoundingBoxDimensions.width;
// generationParameters.inpaint_height = scaledBoundingBoxDimensions.height;
// }
baseNode.seam_size = seamSize;
baseNode.seam_blur = seamBlur;
baseNode.seam_strength = seamStrength;
baseNode.seam_steps = seamSteps;
baseNode.tile_size = tileSize;
// baseNode.infill_method = infillMethod;
// baseNode.force_outpaint = false;
}
// We always range and iterate nodes, no matter the iteration count
// This is required to provide the correct seeds to the backend engine
@ -64,31 +143,13 @@ export const buildCanvasGraph = async (
// Build the edges for the nodes selected.
const edges = buildEdges(baseNode, rangeNode, iterateNode);
if (baseNode.type === 'img2img') {
edges.push({
source: {
node_id: dataURLNode.id,
field: 'image',
},
destination: {
node_id: baseNode.id,
field: 'image',
},
});
}
// Assemble!
const graph = {
nodes: {
[dataURLNode.id]: dataURLNode,
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
},
return {
rangeNode,
iterateNode,
baseNode,
edges,
baseBlob,
maskBlob,
generationMode,
};
// TODO: hires fix requires latent space upscaling; we don't have nodes for this yet
return graph;
};

View File

@ -1,4 +1,4 @@
export const getNodeType = (
export const getGenerationMode = (
baseIsPartiallyTransparent: boolean,
baseIsFullyTransparent: boolean,
doesMaskHaveBlackPixels: boolean

View File

@ -1,6 +1,7 @@
import {
Edge,
ImageToImageInvocation,
InpaintInvocation,
IterateInvocation,
RandomRangeInvocation,
RangeInvocation,
@ -8,7 +9,7 @@ import {
} from 'services/api';
export const buildEdges = (
baseNode: TextToImageInvocation | ImageToImageInvocation,
baseNode: TextToImageInvocation | ImageToImageInvocation | InpaintInvocation,
rangeNode: RangeInvocation | RandomRangeInvocation,
iterateNode: IterateInvocation
): Edge[] => {

View File

@ -0,0 +1,72 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import {
Edge,
ImageToImageInvocation,
InpaintInvocation,
TextToImageInvocation,
} from 'services/api';
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
import { O } from 'ts-toolbelt';
export const buildInpaintNode = (
state: RootState,
overrides: O.Partial<InpaintInvocation, 'deep'> = {}
): InpaintInvocation => {
const nodeId = uuidv4();
const { generation, system, models } = state;
const { selectedModelName } = models;
const {
prompt,
negativePrompt,
seed,
steps,
width,
height,
cfgScale,
sampler,
seamless,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
} = generation;
const initialImage = initialImageSelector(state);
if (!initialImage) {
// TODO: handle this
// throw 'no initial image';
}
const imageToImageNode: InpaintInvocation = {
id: nodeId,
type: 'inpaint',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler: sampler as InpaintInvocation['scheduler'],
seamless,
model: selectedModelName,
progress_images: true,
image: initialImage
? {
image_name: initialImage.name,
image_type: initialImage.type,
}
: undefined,
strength,
fit,
};
if (!shouldRandomizeSeed) {
imageToImageNode.seed = seed;
}
Object.assign(imageToImageNode, overrides);
return imageToImageNode;
};

View File

@ -1,18 +1,17 @@
import { Box } from '@chakra-ui/react';
import { readinessSelector } from 'app/selectors/readinessSelector';
import { userInvoked } from 'app/store/middleware/listenerMiddleware/listeners/userInvoked';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton, { IAIButtonProps } from 'common/components/IAIButton';
import IAIIconButton, {
IAIIconButtonProps,
} from 'common/components/IAIIconButton';
import { usePrepareCanvasState } from 'features/canvas/hooks/usePrepareCanvasState';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa';
import { canvasGraphBuilt, generateGraphBuilt } from 'services/thunks/session';
interface InvokeButton
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
@ -24,15 +23,10 @@ export default function InvokeButton(props: InvokeButton) {
const dispatch = useAppDispatch();
const { isReady } = useAppSelector(readinessSelector);
const activeTabName = useAppSelector(activeTabNameSelector);
// const getGenerationParameters = usePrepareCanvasState();
const handleInvoke = useCallback(() => {
dispatch(clampSymmetrySteps());
if (activeTabName === 'unifiedCanvas') {
dispatch(canvasGraphBuilt());
} else {
dispatch(generateGraphBuilt());
}
dispatch(userInvoked(activeTabName));
}, [dispatch, activeTabName]);
const { t } = useTranslation();

View File

@ -1,7 +1,7 @@
import { Box, FormControl, Textarea } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useRef } from 'react';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { readinessSelector } from 'app/selectors/readinessSelector';
@ -15,7 +15,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { generateGraphBuilt } from 'services/thunks/session';
import { userInvoked } from 'app/store/middleware/listenerMiddleware/listeners/userInvoked';
const promptInputSelector = createSelector(
[(state: RootState) => state.generation, activeTabNameSelector],
@ -37,7 +37,7 @@ const promptInputSelector = createSelector(
*/
const PromptInput = () => {
const dispatch = useAppDispatch();
const { prompt } = useAppSelector(promptInputSelector);
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const { isReady } = useAppSelector(readinessSelector);
const promptRef = useRef<HTMLTextAreaElement>(null);
@ -56,13 +56,16 @@ const PromptInput = () => {
[]
);
const handleKeyDown = (e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(generateGraphBuilt());
}
};
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
},
[dispatch, activeTabName, isReady]
);
return (
<Box>

View File

@ -120,7 +120,7 @@ const initialSystemState: SystemState = {
shouldLogToConsole: true,
statusTranslationKey: 'common.statusDisconnected',
canceledSession: '',
infillMethods: ['tile'],
infillMethods: ['tile', 'patchmatch'],
};
export const systemSlice = createSlice({

View File

@ -12,6 +12,7 @@ export type { Body_upload_image } from './models/Body_upload_image';
export type { CkptModelInfo } from './models/CkptModelInfo';
export type { CollectInvocation } from './models/CollectInvocation';
export type { CollectInvocationOutput } from './models/CollectInvocationOutput';
export type { ColorField } from './models/ColorField';
export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CropImageInvocation } from './models/CropImageInvocation';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
@ -76,6 +77,7 @@ export { $Body_upload_image } from './schemas/$Body_upload_image';
export { $CkptModelInfo } from './schemas/$CkptModelInfo';
export { $CollectInvocation } from './schemas/$CollectInvocation';
export { $CollectInvocationOutput } from './schemas/$CollectInvocationOutput';
export { $ColorField } from './schemas/$ColorField';
export { $CreateModelRequest } from './schemas/$CreateModelRequest';
export { $CropImageInvocation } from './schemas/$CropImageInvocation';
export { $CvInpaintInvocation } from './schemas/$CvInpaintInvocation';

View File

@ -0,0 +1,23 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type ColorField = {
/**
* The red component
*/
'r': number;
/**
* The blue component
*/
'b': number;
/**
* The green component
*/
'g': number;
/**
* The alpha component
*/
'a'?: number;
};

View File

@ -2,6 +2,7 @@
/* tslint:disable */
/* eslint-disable */
import type { ColorField } from './ColorField';
import type { ImageField } from './ImageField';
/**
@ -69,6 +70,42 @@ export type InpaintInvocation = {
* The mask
*/
mask?: ImageField;
/**
* The seam inpaint size (px)
*/
seam_size?: number;
/**
* The seam inpaint blur radius (px)
*/
seam_blur?: number;
/**
* The seam inpaint strength
*/
seam_strength?: number;
/**
* The number of steps to use for seam inpaint
*/
seam_steps?: number;
/**
* The tile infill method size (px)
*/
tile_size?: number;
/**
* The method used to infill empty regions (px)
*/
infill_method?: 'patchmatch' | 'tile' | 'solid';
/**
* The width of the inpaint region (px)
*/
inpaint_width?: number;
/**
* The height of the inpaint region (px)
*/
inpaint_height?: number;
/**
* The solid infill method color
*/
inpaint_fill?: ColorField;
/**
* The amount by which to replace masked areas with latent noise
*/

View File

@ -0,0 +1,30 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $ColorField = {
properties: {
'r': {
type: 'number',
description: `The red component`,
isRequired: true,
maximum: 255,
},
'b': {
type: 'number',
description: `The blue component`,
isRequired: true,
maximum: 255,
},
'g': {
type: 'number',
description: `The green component`,
isRequired: true,
maximum: 255,
},
'a': {
type: 'number',
description: `The alpha component`,
maximum: 255,
},
},
} as const;

View File

@ -29,16 +29,17 @@ export const $ImageToImageInvocation = {
width: {
type: 'number',
description: `The width of the resulting image`,
multipleOf: 64,
multipleOf: 8,
},
height: {
type: 'number',
description: `The height of the resulting image`,
multipleOf: 64,
multipleOf: 8,
},
cfg_scale: {
type: 'number',
description: `The Classifier-Free Guidance, higher values may result in a result closer to the prompt`,
exclusiveMinimum: 1,
},
scheduler: {
type: 'Enum',

View File

@ -29,16 +29,17 @@ export const $InpaintInvocation = {
width: {
type: 'number',
description: `The width of the resulting image`,
multipleOf: 64,
multipleOf: 8,
},
height: {
type: 'number',
description: `The height of the resulting image`,
multipleOf: 64,
multipleOf: 8,
},
cfg_scale: {
type: 'number',
description: `The Classifier-Free Guidance, higher values may result in a result closer to the prompt`,
exclusiveMinimum: 1,
},
scheduler: {
type: 'Enum',
@ -78,6 +79,50 @@ export const $InpaintInvocation = {
type: 'ImageField',
}],
},
seam_size: {
type: 'number',
description: `The seam inpaint size (px)`,
minimum: 1,
},
seam_blur: {
type: 'number',
description: `The seam inpaint blur radius (px)`,
},
seam_strength: {
type: 'number',
description: `The seam inpaint strength`,
maximum: 1,
},
seam_steps: {
type: 'number',
description: `The number of steps to use for seam inpaint`,
minimum: 1,
},
tile_size: {
type: 'number',
description: `The tile infill method size (px)`,
minimum: 1,
},
infill_method: {
type: 'Enum',
},
inpaint_width: {
type: 'number',
description: `The width of the inpaint region (px)`,
multipleOf: 8,
},
inpaint_height: {
type: 'number',
description: `The height of the inpaint region (px)`,
multipleOf: 8,
},
inpaint_fill: {
type: 'all-of',
description: `The solid infill method color`,
contains: [{
type: 'ColorField',
}],
},
inpaint_replace: {
type: 'number',
description: `The amount by which to replace masked areas with latent noise`,

View File

@ -20,12 +20,12 @@ export const $NoiseInvocation = {
width: {
type: 'number',
description: `The width of the resulting noise`,
multipleOf: 64,
multipleOf: 8,
},
height: {
type: 'number',
description: `The height of the resulting noise`,
multipleOf: 64,
multipleOf: 8,
},
},
} as const;

View File

@ -29,16 +29,17 @@ export const $TextToImageInvocation = {
width: {
type: 'number',
description: `The width of the resulting image`,
multipleOf: 64,
multipleOf: 8,
},
height: {
type: 'number',
description: `The height of the resulting image`,
multipleOf: 64,
multipleOf: 8,
},
cfg_scale: {
type: 'number',
description: `The Classifier-Free Guidance, higher values may result in a result closer to the prompt`,
exclusiveMinimum: 1,
},
scheduler: {
type: 'Enum',

View File

@ -114,13 +114,18 @@ export class ImagesService {
* @throws ApiError
*/
public static uploadImage({
imageType,
formData,
}: {
imageType: ImageType,
formData: Body_upload_image,
}): CancelablePromise<ImageResponse> {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/images/uploads/',
query: {
'image_type': imageType,
},
formData: formData,
mediaType: 'multipart/form-data',
errors: {

View File

@ -1,7 +1,7 @@
import { createAppAsyncThunk } from 'app/store/storeUtils';
import { SessionsService } from 'services/api';
import { buildLinearGraph as buildGenerateGraph } from 'features/nodes/util/buildLinearGraph';
import { buildCanvasGraph } from 'features/nodes/util/buildCanvasGraph';
import { buildCanvasGraphAndBlobs } from 'features/nodes/util/buildCanvasGraph';
import { isAnyOf, isFulfilled } from '@reduxjs/toolkit';
import { buildNodesGraph } from 'features/nodes/util/buildNodesGraph';
import { log } from 'app/logging/useLogger';
@ -9,62 +9,62 @@ import { serializeError } from 'serialize-error';
const sessionLog = log.child({ namespace: 'session' });
export const generateGraphBuilt = createAppAsyncThunk(
'api/generateGraphBuilt',
async (_, { dispatch, getState, rejectWithValue }) => {
try {
const graph = buildGenerateGraph(getState());
dispatch(sessionCreated({ graph }));
return graph;
} catch (err: any) {
sessionLog.error(
{ error: serializeError(err) },
'Problem building graph'
);
return rejectWithValue(err.message);
}
}
);
// export const generateGraphBuilt = createAppAsyncThunk(
// 'api/generateGraphBuilt',
// async (_, { dispatch, getState, rejectWithValue }) => {
// try {
// const graph = buildGenerateGraph(getState());
// dispatch(sessionCreated({ graph }));
// return graph;
// } catch (err: any) {
// sessionLog.error(
// { error: serializeError(err) },
// 'Problem building graph'
// );
// return rejectWithValue(err.message);
// }
// }
// );
export const nodesGraphBuilt = createAppAsyncThunk(
'api/nodesGraphBuilt',
async (_, { dispatch, getState, rejectWithValue }) => {
try {
const graph = buildNodesGraph(getState());
dispatch(sessionCreated({ graph }));
return graph;
} catch (err: any) {
sessionLog.error(
{ error: serializeError(err) },
'Problem building graph'
);
return rejectWithValue(err.message);
}
}
);
// export const nodesGraphBuilt = createAppAsyncThunk(
// 'api/nodesGraphBuilt',
// async (_, { dispatch, getState, rejectWithValue }) => {
// try {
// const graph = buildNodesGraph(getState());
// dispatch(sessionCreated({ graph }));
// return graph;
// } catch (err: any) {
// sessionLog.error(
// { error: serializeError(err) },
// 'Problem building graph'
// );
// return rejectWithValue(err.message);
// }
// }
// );
export const canvasGraphBuilt = createAppAsyncThunk(
'api/canvasGraphBuilt',
async (_, { dispatch, getState, rejectWithValue }) => {
try {
const graph = await buildCanvasGraph(getState());
dispatch(sessionCreated({ graph }));
return graph;
} catch (err: any) {
sessionLog.error(
{ error: serializeError(err) },
'Problem building graph'
);
return rejectWithValue(err.message);
}
}
);
// export const canvasGraphBuilt = createAppAsyncThunk(
// 'api/canvasGraphBuilt',
// async (_, { dispatch, getState, rejectWithValue }) => {
// try {
// const graph = await buildCanvasGraph(getState());
// dispatch(sessionCreated({ graph }));
// return graph;
// } catch (err: any) {
// sessionLog.error(
// { error: serializeError(err) },
// 'Problem building graph'
// );
// return rejectWithValue(err.message);
// }
// }
// );
export const isFulfilledAnyGraphBuilt = isAnyOf(
generateGraphBuilt.fulfilled,
nodesGraphBuilt.fulfilled,
canvasGraphBuilt.fulfilled
);
// export const isFulfilledAnyGraphBuilt = isAnyOf(
// generateGraphBuilt.fulfilled,
// nodesGraphBuilt.fulfilled,
// canvasGraphBuilt.fulfilled
// );
type SessionCreatedArg = {
graph: Parameters<