feat(ui): wip canvas nodes migration

This commit is contained in:
psychedelicious 2023-05-02 20:11:12 +10:00
parent ff5e2a9a8c
commit 08ec12b391
19 changed files with 652 additions and 241 deletions

View File

@ -1,209 +1,209 @@
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit'; import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
// import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
// import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
// import { import {
// frontendToBackendParameters, frontendToBackendParameters,
// FrontendToBackendParametersConfig, FrontendToBackendParametersConfig,
// } from 'common/util/parameterTranslation'; } from 'common/util/parameterTranslation';
// import dateFormat from 'dateformat'; import dateFormat from 'dateformat';
// import { import {
// GalleryCategory, GalleryCategory,
// GalleryState, GalleryState,
// removeImage, removeImage,
// } from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
// import { import {
// generationRequested, generationRequested,
// modelChangeRequested, modelChangeRequested,
// modelConvertRequested, modelConvertRequested,
// modelMergingRequested, modelMergingRequested,
// setIsProcessing, setIsProcessing,
// } from 'features/system/store/systemSlice'; } from 'features/system/store/systemSlice';
// import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
// import { Socket } from 'socket.io-client'; import { Socket } from 'socket.io-client';
// /** /**
// * Returns an object containing all functions which use `socketio.emit()`. * Returns an object containing all functions which use `socketio.emit()`.
// * i.e. those which make server requests. * i.e. those which make server requests.
// */ */
// const makeSocketIOEmitters = ( const makeSocketIOEmitters = (
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>, store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
// socketio: Socket socketio: Socket
// ) => { ) => {
// // We need to dispatch actions to redux and get pieces of state from the store. // We need to dispatch actions to redux and get pieces of state from the store.
// const { dispatch, getState } = store; const { dispatch, getState } = store;
// return { return {
// emitGenerateImage: (generationMode: InvokeTabName) => { emitGenerateImage: (generationMode: InvokeTabName) => {
// dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
// const state: RootState = getState(); const state: RootState = getState();
// const { const {
// generation: generationState, generation: generationState,
// postprocessing: postprocessingState, postprocessing: postprocessingState,
// system: systemState, system: systemState,
// canvas: canvasState, canvas: canvasState,
// } = state; } = state;
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig = const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
// { {
// generationMode, generationMode,
// generationState, generationState,
// postprocessingState, postprocessingState,
// canvasState, canvasState,
// systemState, systemState,
// }; };
// dispatch(generationRequested()); dispatch(generationRequested());
// const { generationParameters, esrganParameters, facetoolParameters } = const { generationParameters, esrganParameters, facetoolParameters } =
// frontendToBackendParameters(frontendToBackendParametersConfig); frontendToBackendParameters(frontendToBackendParametersConfig);
// socketio.emit( socketio.emit(
// 'generateImage', 'generateImage',
// generationParameters, generationParameters,
// esrganParameters, esrganParameters,
// facetoolParameters facetoolParameters
// ); );
// // we need to truncate the init_mask base64 else it takes up the whole log // we need to truncate the init_mask base64 else it takes up the whole log
// // TODO: handle maintaining masks for reproducibility in future // TODO: handle maintaining masks for reproducibility in future
// if (generationParameters.init_mask) { if (generationParameters.init_mask) {
// generationParameters.init_mask = generationParameters.init_mask generationParameters.init_mask = generationParameters.init_mask
// .substr(0, 64) .substr(0, 64)
// .concat('...'); .concat('...');
// } }
// if (generationParameters.init_img) { if (generationParameters.init_img) {
// generationParameters.init_img = generationParameters.init_img generationParameters.init_img = generationParameters.init_img
// .substr(0, 64) .substr(0, 64)
// .concat('...'); .concat('...');
// } }
// dispatch( dispatch(
// addLogEntry({ addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'), timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image generation requested: ${JSON.stringify({ message: `Image generation requested: ${JSON.stringify({
// ...generationParameters, ...generationParameters,
// ...esrganParameters, ...esrganParameters,
// ...facetoolParameters, ...facetoolParameters,
// })}`, })}`,
// }) })
// ); );
// }, },
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => { emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
// const { const {
// postprocessing: { postprocessing: {
// upscalingLevel, upscalingLevel,
// upscalingDenoising, upscalingDenoising,
// upscalingStrength, upscalingStrength,
// }, },
// } = getState(); } = getState();
// const esrganParameters = { const esrganParameters = {
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength], upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
// }; };
// socketio.emit('runPostprocessing', imageToProcess, { socketio.emit('runPostprocessing', imageToProcess, {
// type: 'esrgan', type: 'esrgan',
// ...esrganParameters, ...esrganParameters,
// }); });
// dispatch( dispatch(
// addLogEntry({ addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'), timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `ESRGAN upscale requested: ${JSON.stringify({ message: `ESRGAN upscale requested: ${JSON.stringify({
// file: imageToProcess.url, file: imageToProcess.url,
// ...esrganParameters, ...esrganParameters,
// })}`, })}`,
// }) })
// ); );
// }, },
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => { emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
// const { const {
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity }, postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
// } = getState(); } = getState();
// const facetoolParameters: Record<string, unknown> = { const facetoolParameters: Record<string, unknown> = {
// facetool_strength: facetoolStrength, facetool_strength: facetoolStrength,
// }; };
// if (facetoolType === 'codeformer') { if (facetoolType === 'codeformer') {
// facetoolParameters.codeformer_fidelity = codeformerFidelity; facetoolParameters.codeformer_fidelity = codeformerFidelity;
// } }
// socketio.emit('runPostprocessing', imageToProcess, { socketio.emit('runPostprocessing', imageToProcess, {
// type: facetoolType, type: facetoolType,
// ...facetoolParameters, ...facetoolParameters,
// }); });
// dispatch( dispatch(
// addLogEntry({ addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'), timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify( message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
// { {
// file: imageToProcess.url, file: imageToProcess.url,
// ...facetoolParameters, ...facetoolParameters,
// } }
// )}`, )}`,
// }) })
// ); );
// }, },
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => { emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
// const { url, uuid, category, thumbnail } = imageToDelete; const { url, uuid, category, thumbnail } = imageToDelete;
// dispatch(removeImage(imageToDelete)); dispatch(removeImage(imageToDelete));
// socketio.emit('deleteImage', url, thumbnail, uuid, category); socketio.emit('deleteImage', url, thumbnail, uuid, category);
// }, },
// emitRequestImages: (category: GalleryCategory) => { emitRequestImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery; const gallery: GalleryState = getState().gallery;
// const { earliest_mtime } = gallery.categories[category]; const { earliest_mtime } = gallery.categories[category];
// socketio.emit('requestImages', category, earliest_mtime); socketio.emit('requestImages', category, earliest_mtime);
// }, },
// emitRequestNewImages: (category: GalleryCategory) => { emitRequestNewImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery; const gallery: GalleryState = getState().gallery;
// const { latest_mtime } = gallery.categories[category]; const { latest_mtime } = gallery.categories[category];
// socketio.emit('requestLatestImages', category, latest_mtime); socketio.emit('requestLatestImages', category, latest_mtime);
// }, },
// emitCancelProcessing: () => { emitCancelProcessing: () => {
// socketio.emit('cancel'); socketio.emit('cancel');
// }, },
// emitRequestSystemConfig: () => { emitRequestSystemConfig: () => {
// socketio.emit('requestSystemConfig'); socketio.emit('requestSystemConfig');
// }, },
// emitSearchForModels: (modelFolder: string) => { emitSearchForModels: (modelFolder: string) => {
// socketio.emit('searchForModels', modelFolder); socketio.emit('searchForModels', modelFolder);
// }, },
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => { emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
// socketio.emit('addNewModel', modelConfig); socketio.emit('addNewModel', modelConfig);
// }, },
// emitDeleteModel: (modelName: string) => { emitDeleteModel: (modelName: string) => {
// socketio.emit('deleteModel', modelName); socketio.emit('deleteModel', modelName);
// }, },
// emitConvertToDiffusers: ( emitConvertToDiffusers: (
// modelToConvert: InvokeAI.InvokeModelConversionProps modelToConvert: InvokeAI.InvokeModelConversionProps
// ) => { ) => {
// dispatch(modelConvertRequested()); dispatch(modelConvertRequested());
// socketio.emit('convertToDiffusers', modelToConvert); socketio.emit('convertToDiffusers', modelToConvert);
// }, },
// emitMergeDiffusersModels: ( emitMergeDiffusersModels: (
// modelMergeInfo: InvokeAI.InvokeModelMergingProps modelMergeInfo: InvokeAI.InvokeModelMergingProps
// ) => { ) => {
// dispatch(modelMergingRequested()); dispatch(modelMergingRequested());
// socketio.emit('mergeDiffusersModels', modelMergeInfo); socketio.emit('mergeDiffusersModels', modelMergeInfo);
// }, },
// emitRequestModelChange: (modelName: string) => { emitRequestModelChange: (modelName: string) => {
// dispatch(modelChangeRequested()); dispatch(modelChangeRequested());
// socketio.emit('requestModelChange', modelName); socketio.emit('requestModelChange', modelName);
// }, },
// emitSaveStagingAreaImageToGallery: (url: string) => { emitSaveStagingAreaImageToGallery: (url: string) => {
// socketio.emit('requestSaveStagingAreaImageToGallery', url); socketio.emit('requestSaveStagingAreaImageToGallery', url);
// }, },
// emitRequestEmptyTempFolder: () => { emitRequestEmptyTempFolder: () => {
// socketio.emit('requestEmptyTempFolder'); socketio.emit('requestEmptyTempFolder');
// }, },
// }; };
// }; };
// export default makeSocketIOEmitters; export default makeSocketIOEmitters;
export default {}; export default {};

View File

@ -0,0 +1,59 @@
export const getIsImageDataPartiallyTransparent = (imageData: ImageData) => {
let hasTransparency = false;
let isFullyTransparent = true;
const len = imageData.data.length;
let i = 3;
for (i; i < len; i += 4) {
if (imageData.data[i] !== 0) {
isFullyTransparent = false;
} else {
hasTransparency = true;
}
}
return { hasTransparency, isFullyTransparent };
};
export const getImageDataTransparency = (imageData: ImageData) => {
let isFullyTransparent = true;
let isPartiallyTransparent = false;
const len = imageData.data.length;
let i = 3;
for (i; i < len; i += 4) {
if (imageData.data[i] === 255) {
isFullyTransparent = false;
} else {
isPartiallyTransparent = true;
}
if (!isFullyTransparent && isPartiallyTransparent) {
return { isFullyTransparent, isPartiallyTransparent };
}
}
return { isFullyTransparent, isPartiallyTransparent };
};
export const areAnyPixelsBlack = (imageData: ImageData) => {
const len = imageData.data.length;
let i = 0;
for (i; i < len; ) {
if (
imageData.data[i++] === 255 &&
imageData.data[i++] === 255 &&
imageData.data[i++] === 255 &&
imageData.data[i++] === 255
) {
return true;
}
}
return false;
};
export const getIsImageDataWhite = (imageData: ImageData) => {
const len = imageData.data.length;
let i = 0;
for (i; i < len; ) {
if (imageData.data[i++] !== 255) {
return false;
}
}
return true;
};

View File

@ -19,6 +19,7 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
import openBase64ImageInTab from './openBase64ImageInTab'; import openBase64ImageInTab from './openBase64ImageInTab';
import randomInt from './randomInt'; import randomInt from './randomInt';
import { stringToSeedWeightsArray } from './seedWeightPairs'; import { stringToSeedWeightsArray } from './seedWeightPairs';
import { getIsImageDataTransparent, getIsImageDataWhite } from './arrayBuffer';
export type FrontendToBackendParametersConfig = { export type FrontendToBackendParametersConfig = {
generationMode: InvokeTabName; generationMode: InvokeTabName;
@ -256,7 +257,7 @@ export const frontendToBackendParameters = (
...boundingBoxDimensions, ...boundingBoxDimensions,
}; };
const maskDataURL = generateMask( const { dataURL: maskDataURL, imageData: maskImageData } = generateMask(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
boundingBox boundingBox
); );
@ -287,6 +288,19 @@ export const frontendToBackendParameters = (
height: boundingBox.height, height: boundingBox.height,
}); });
const ctx = canvasBaseLayer.getContext();
const imageData = ctx.getImageData(
boundingBox.x + absPos.x,
boundingBox.y + absPos.y,
boundingBox.width,
boundingBox.height
);
const doesBaseHaveTransparency = getIsImageDataTransparent(imageData);
const doesMaskHaveTransparency = getIsImageDataWhite(maskImageData);
console.log(doesBaseHaveTransparency, doesMaskHaveTransparency);
if (enableImageDebugging) { if (enableImageDebugging) {
openBase64ImageInTab([ openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask sent as init_mask' }, { base64: maskDataURL, caption: 'mask sent as init_mask' },

View File

@ -0,0 +1,39 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import {
FrontendToBackendParametersConfig,
frontendToBackendParameters,
} from 'common/util/parameterTranslation';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { canvasSelector } from '../store/canvasSelectors';
import { useCallback, useMemo } from 'react';
const selector = createSelector(
[generationSelector, postprocessingSelector, systemSelector, canvasSelector],
(generation, postprocessing, system, canvas) => {
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
{
generationMode: 'unifiedCanvas',
generationState: generation,
postprocessingState: postprocessing,
canvasState: canvas,
systemState: system,
};
return frontendToBackendParametersConfig;
}
);
export const usePrepareCanvasState = () => {
const frontendToBackendParametersConfig = useAppSelector(selector);
const getGenerationParameters = useCallback(() => {
const { generationParameters, esrganParameters, facetoolParameters } =
frontendToBackendParameters(frontendToBackendParametersConfig);
console.log(generationParameters);
}, [frontendToBackendParametersConfig]);
return getGenerationParameters;
};

View File

@ -156,22 +156,20 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => { setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload; state.cursorPosition = action.payload;
}, },
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => { setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
const image = action.payload; const image = action.payload;
const { width, height } = image.metadata;
const { stageDimensions } = state; const { stageDimensions } = state;
const newBoundingBoxDimensions = { const newBoundingBoxDimensions = {
width: roundDownToMultiple(clamp(image.width, 64, 512), 64), width: roundDownToMultiple(clamp(width, 64, 512), 64),
height: roundDownToMultiple(clamp(image.height, 64, 512), 64), height: roundDownToMultiple(clamp(height, 64, 512), 64),
}; };
const newBoundingBoxCoordinates = { const newBoundingBoxCoordinates = {
x: roundToMultiple( x: roundToMultiple(width / 2 - newBoundingBoxDimensions.width / 2, 64),
image.width / 2 - newBoundingBoxDimensions.width / 2,
64
),
y: roundToMultiple( y: roundToMultiple(
image.height / 2 - newBoundingBoxDimensions.height / 2, height / 2 - newBoundingBoxDimensions.height / 2,
64 64
), ),
}; };
@ -196,8 +194,8 @@ export const canvasSlice = createSlice({
layer: 'base', layer: 'base',
x: 0, x: 0,
y: 0, y: 0,
width: image.width, width: width,
height: image.height, height: height,
image: image, image: image,
}, },
], ],
@ -208,8 +206,8 @@ export const canvasSlice = createSlice({
const newScale = calculateScale( const newScale = calculateScale(
stageDimensions.width, stageDimensions.width,
stageDimensions.height, stageDimensions.height,
image.width, width,
image.height, height,
STAGE_PADDING_PERCENTAGE STAGE_PADDING_PERCENTAGE
); );
@ -218,8 +216,8 @@ export const canvasSlice = createSlice({
stageDimensions.height, stageDimensions.height,
0, 0,
0, 0,
image.width, width,
image.height, height,
newScale newScale
); );
state.stageScale = newScale; state.stageScale = newScale;

View File

@ -12,7 +12,10 @@ import { IRect } from 'konva/lib/types';
* drawing the mask and compositing everything correctly to output a valid * drawing the mask and compositing everything correctly to output a valid
* mask image. * mask image.
*/ */
const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => { const generateMask = (
lines: CanvasMaskLine[],
boundingBox: IRect
): { dataURL: string; imageData: ImageData } => {
// create an offscreen canvas and add the mask to it // create an offscreen canvas and add the mask to it
const { width, height } = boundingBox; const { width, height } = boundingBox;
@ -55,10 +58,19 @@ const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
stage.add(maskLayer); stage.add(maskLayer);
const dataURL = stage.toDataURL({ ...boundingBox }); const dataURL = stage.toDataURL({ ...boundingBox });
const imageData = stage
.toCanvas()
.getContext('2d')
?.getImageData(
boundingBox.x,
boundingBox.y,
boundingBox.width,
boundingBox.height
);
offscreenContainer.remove(); offscreenContainer.remove();
return dataURL; return { dataURL, imageData };
}; };
export default generateMask; export default generateMask;

View File

@ -0,0 +1,123 @@
import { RootState } from 'app/store/store';
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
import { isCanvasMaskLine } from '../store/canvasTypes';
import generateMask from './generateMask';
import { log } from 'app/logging/useLogger';
import {
areAnyPixelsBlack,
getImageDataTransparency,
getIsImageDataWhite,
} from 'common/util/arrayBuffer';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
export const getCanvasDataURLs = (state: RootState) => {
const canvasBaseLayer = getCanvasBaseLayer();
const canvasStage = getCanvasStage();
if (!canvasBaseLayer || !canvasStage) {
log.error(
{ namespace: 'getCanvasDataURLs' },
'Unable to find canvas / stage'
);
return;
}
const {
layerState: { objects },
boundingBoxCoordinates,
boundingBoxDimensions,
stageScale,
isMaskEnabled,
shouldPreserveMaskedArea,
boundingBoxScaleMethod: boundingBoxScale,
scaledBoundingBoxDimensions,
} = state.canvas;
const boundingBox = {
...boundingBoxCoordinates,
...boundingBoxDimensions,
};
// generationParameters.fit = false;
// generationParameters.strength = img2imgStrength;
// generationParameters.invert_mask = shouldPreserveMaskedArea;
// generationParameters.bounding_box = boundingBox;
const tempScale = canvasBaseLayer.scale();
canvasBaseLayer.scale({
x: 1 / stageScale,
y: 1 / stageScale,
});
const absPos = canvasBaseLayer.getAbsolutePosition();
const { dataURL: maskDataURL, imageData: maskImageData } = generateMask(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
{
x: boundingBox.x + absPos.x,
y: boundingBox.y + absPos.y,
width: boundingBox.width,
height: boundingBox.height,
}
);
const baseDataURL = canvasBaseLayer.toDataURL({
x: boundingBox.x + absPos.x,
y: boundingBox.y + absPos.y,
width: boundingBox.width,
height: boundingBox.height,
});
const ctx = canvasBaseLayer.getContext();
const baseImageData = ctx.getImageData(
boundingBox.x + absPos.x,
boundingBox.y + absPos.y,
boundingBox.width,
boundingBox.height
);
const {
isPartiallyTransparent: baseIsPartiallyTransparent,
isFullyTransparent: baseIsFullyTransparent,
} = getImageDataTransparency(baseImageData);
const doesMaskHaveBlackPixels = areAnyPixelsBlack(maskImageData);
if (state.system.enableImageDebugging) {
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask sent as init_mask' },
{ base64: baseDataURL, caption: 'image sent as init_img' },
]);
}
canvasBaseLayer.scale(tempScale);
// generationParameters.init_img = imageDataURL;
// generationParameters.progress_images = false;
// if (boundingBoxScale !== 'none') {
// generationParameters.inpaint_width = scaledBoundingBoxDimensions.width;
// generationParameters.inpaint_height = scaledBoundingBoxDimensions.height;
// }
// generationParameters.seam_size = seamSize;
// generationParameters.seam_blur = seamBlur;
// generationParameters.seam_strength = seamStrength;
// generationParameters.seam_steps = seamSteps;
// generationParameters.tile_size = tileSize;
// generationParameters.infill_method = infillMethod;
// generationParameters.force_outpaint = false;
return {
baseDataURL,
maskDataURL,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
};
};

View File

@ -17,7 +17,10 @@ import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal'; import DeleteImageModal from './DeleteImageModal';
import { ContextMenu } from 'chakra-ui-contextmenu'; import { ContextMenu } from 'chakra-ui-contextmenu';
import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
import { resizeAndScaleCanvas } from 'features/canvas/store/canvasSlice'; import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { gallerySelector } from 'features/gallery/store/gallerySelectors'; import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -159,7 +162,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
* TODO: the rest of these * TODO: the rest of these
*/ */
const handleSendToCanvas = () => { const handleSendToCanvas = () => {
// dispatch(setInitialCanvasImage(image)); dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas()); dispatch(resizeAndScaleCanvas());
@ -315,6 +318,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
sx={{ sx={{
width: '50%', width: '50%',
height: '50%', height: '50%',
maxWidth: '4rem',
maxHeight: '4rem',
fill: 'ok.500', fill: 'ok.500',
}} }}
/> />

View File

@ -4,12 +4,8 @@ import { GalleryState } from './gallerySlice';
* Gallery slice persist denylist * Gallery slice persist denylist
*/ */
const itemsToDenylist: (keyof GalleryState)[] = [ const itemsToDenylist: (keyof GalleryState)[] = [
'categories',
'currentCategory', 'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages', 'shouldAutoSwitchToNewImages',
'intermediateImage',
]; ];
export const galleryDenylist = itemsToDenylist.map( export const galleryDenylist = itemsToDenylist.map(

View File

@ -5,7 +5,7 @@ import { ResultsState } from './resultsSlice';
* *
* Currently denylisting results slice entirely, see persist config in store.ts * Currently denylisting results slice entirely, see persist config in store.ts
*/ */
const itemsToDenylist: (keyof ResultsState)[] = ['isLoading']; const itemsToDenylist: (keyof ResultsState)[] = [];
export const resultsDenylist = itemsToDenylist.map( export const resultsDenylist = itemsToDenylist.map(
(denylistItem) => `results.${denylistItem}` (denylistItem) => `results.${denylistItem}`

View File

@ -5,7 +5,7 @@ import { UploadsState } from './uploadsSlice';
* *
* Currently denylisting uploads slice entirely, see persist config in store.ts * Currently denylisting uploads slice entirely, see persist config in store.ts
*/ */
const itemsToDenylist: (keyof UploadsState)[] = ['isLoading']; const itemsToDenylist: (keyof UploadsState)[] = [];
export const uploadsDenylist = itemsToDenylist.map( export const uploadsDenylist = itemsToDenylist.map(
(denylistItem) => `uploads.${denylistItem}` (denylistItem) => `uploads.${denylistItem}`

View File

@ -0,0 +1,19 @@
export const getNodeType = (
baseIsPartiallyTransparent: boolean,
baseIsFullyTransparent: boolean,
doesMaskHaveBlackPixels: boolean
): 'txt2img' | `img2img` | 'inpaint' | 'outpaint' => {
if (baseIsPartiallyTransparent) {
if (baseIsFullyTransparent) {
return 'txt2img';
}
return 'outpaint';
} else {
if (doesMaskHaveBlackPixels) {
return 'inpaint';
}
return 'img2img';
}
};

View File

@ -5,10 +5,13 @@ import {
ImageToImageInvocation, ImageToImageInvocation,
TextToImageInvocation, TextToImageInvocation,
} from 'services/api'; } from 'services/api';
import { _Image } from 'app/types/invokeai';
import { initialImageSelector } from 'features/parameters/store/generationSelectors'; import { initialImageSelector } from 'features/parameters/store/generationSelectors';
import { O } from 'ts-toolbelt';
export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => { export const buildImg2ImgNode = (
state: RootState,
overrides: O.Partial<ImageToImageInvocation, 'deep'> = {}
): ImageToImageInvocation => {
const nodeId = uuidv4(); const nodeId = uuidv4();
const { generation, system, models } = state; const { generation, system, models } = state;
@ -33,7 +36,7 @@ export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => {
if (!initialImage) { if (!initialImage) {
// TODO: handle this // TODO: handle this
throw 'no initial image'; // throw 'no initial image';
} }
const imageToImageNode: ImageToImageInvocation = { const imageToImageNode: ImageToImageInvocation = {
@ -48,10 +51,12 @@ export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => {
seamless, seamless,
model: selectedModelName, model: selectedModelName,
progress_images: true, progress_images: true,
image: { image: initialImage
? {
image_name: initialImage.name, image_name: initialImage.name,
image_type: initialImage.type, image_type: initialImage.type,
}, }
: undefined,
strength, strength,
fit, fit,
}; };
@ -60,6 +65,8 @@ export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => {
imageToImageNode.seed = seed; imageToImageNode.seed = seed;
} }
Object.assign(imageToImageNode, overrides);
return imageToImageNode; return imageToImageNode;
}; };

View File

@ -1,10 +1,15 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { Graph } from 'services/api'; import { DataURLToImageInvocation, Graph } from 'services/api';
import { buildImg2ImgNode } from './buildImageToImageNode'; import { buildImg2ImgNode } from './buildImageToImageNode';
import { buildTxt2ImgNode } from './buildTextToImageNode'; import { buildTxt2ImgNode } from './buildTextToImageNode';
import { buildRangeNode } from './buildRangeNode'; import { buildRangeNode } from './buildRangeNode';
import { buildIterateNode } from './buildIterateNode'; import { buildIterateNode } from './buildIterateNode';
import { buildEdges } from './buildEdges'; import { buildEdges } from './buildEdges';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { getCanvasDataURLs } from 'features/canvas/util/getCanvasDataURLs';
import { log } from 'console';
import { getNodeType } from '../getNodeType';
import { v4 as uuidv4 } from 'uuid';
/** /**
* Builds the Linear workflow graph. * Builds the Linear workflow graph.
@ -37,3 +42,86 @@ export const buildLinearGraph = (state: RootState): Graph => {
return graph; return graph;
}; };
/**
* Builds the Linear workflow graph.
*/
export const buildCanvasGraph = (state: RootState): Graph => {
const c = getCanvasDataURLs(state);
if (!c) {
throw 'problm creating canvas graph';
}
const {
baseDataURL,
maskDataURL,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
} = c;
console.log({
baseDataURL,
maskDataURL,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
});
const nodeType = getNodeType(
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels
);
console.log(nodeType);
// The base node is either a txt2img or img2img node
const baseNode =
nodeType === 'img2img'
? buildImg2ImgNode(state, state.canvas.boundingBoxDimensions)
: buildTxt2ImgNode(state, state.canvas.boundingBoxDimensions);
const dataURLNode: DataURLToImageInvocation = {
id: uuidv4(),
type: 'dataURL_image',
dataURL: baseDataURL,
};
// We always range and iterate nodes, no matter the iteration count
// This is required to provide the correct seeds to the backend engine
const rangeNode = buildRangeNode(state);
const iterateNode = buildIterateNode();
// Build the edges for the nodes selected.
const edges = buildEdges(baseNode, rangeNode, iterateNode);
if (baseNode.type === 'img2img') {
edges.push({
source: {
node_id: dataURLNode.id,
field: 'image',
},
destination: {
node_id: baseNode.id,
field: 'image',
},
});
}
// Assemble!
const graph = {
nodes: {
[dataURLNode.id]: dataURLNode,
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
},
edges,
};
// TODO: hires fix requires latent space upscaling; we don't have nodes for this yet
return graph;
};

View File

@ -1,8 +1,12 @@
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { TextToImageInvocation } from 'services/api'; import { TextToImageInvocation } from 'services/api';
import { O } from 'ts-toolbelt';
export const buildTxt2ImgNode = (state: RootState): TextToImageInvocation => { export const buildTxt2ImgNode = (
state: RootState,
overrides: O.Partial<TextToImageInvocation, 'deep'> = {}
): TextToImageInvocation => {
const nodeId = uuidv4(); const nodeId = uuidv4();
const { generation, models } = state; const { generation, models } = state;
@ -39,5 +43,7 @@ export const buildTxt2ImgNode = (state: RootState): TextToImageInvocation => {
textToImageNode.seed = seed; textToImageNode.seed = seed;
} }
Object.assign(textToImageNode, overrides);
return textToImageNode; return textToImageNode;
}; };

View File

@ -28,7 +28,7 @@ const selector = createSelector(
(parameters, system, canvas) => { (parameters, system, canvas) => {
const { tileSize, infillMethod } = parameters; const { tileSize, infillMethod } = parameters;
const { infill_methods: availableInfillMethods } = system; const { infillMethods } = system;
const { const {
boundingBoxScaleMethod: boundingBoxScale, boundingBoxScaleMethod: boundingBoxScale,
@ -40,7 +40,7 @@ const selector = createSelector(
scaledBoundingBoxDimensions, scaledBoundingBoxDimensions,
tileSize, tileSize,
infillMethod, infillMethod,
availableInfillMethods, infillMethods,
isManual: boundingBoxScale === 'manual', isManual: boundingBoxScale === 'manual',
}; };
}, },
@ -56,7 +56,7 @@ const InfillAndScalingSettings = () => {
const { const {
tileSize, tileSize,
infillMethod, infillMethod,
availableInfillMethods, infillMethods,
boundingBoxScale, boundingBoxScale,
isManual, isManual,
scaledBoundingBoxDimensions, scaledBoundingBoxDimensions,
@ -147,7 +147,7 @@ const InfillAndScalingSettings = () => {
<IAISelect <IAISelect
label={t('parameters.infillMethod')} label={t('parameters.infillMethod')}
value={infillMethod} value={infillMethod}
validValues={availableInfillMethods} validValues={infillMethods}
onChange={(e) => dispatch(setInfillMethod(e.target.value))} onChange={(e) => dispatch(setInfillMethod(e.target.value))}
/> />
<IAISlider <IAISlider

View File

@ -5,13 +5,14 @@ import IAIButton, { IAIButtonProps } from 'common/components/IAIButton';
import IAIIconButton, { import IAIIconButton, {
IAIIconButtonProps, IAIIconButtonProps,
} from 'common/components/IAIIconButton'; } from 'common/components/IAIIconButton';
import { usePrepareCanvasState } from 'features/canvas/hooks/usePrepareCanvasState';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice'; import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa'; import { FaPlay } from 'react-icons/fa';
import { generateGraphBuilt } from 'services/thunks/session'; import { canvasGraphBuilt, generateGraphBuilt } from 'services/thunks/session';
interface InvokeButton interface InvokeButton
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> { extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
@ -23,11 +24,16 @@ export default function InvokeButton(props: InvokeButton) {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { isReady } = useAppSelector(readinessSelector); const { isReady } = useAppSelector(readinessSelector);
const activeTabName = useAppSelector(activeTabNameSelector); const activeTabName = useAppSelector(activeTabNameSelector);
// const getGenerationParameters = usePrepareCanvasState();
const handleInvoke = useCallback(() => { const handleInvoke = useCallback(() => {
dispatch(clampSymmetrySteps()); dispatch(clampSymmetrySteps());
if (activeTabName === 'unifiedCanvas') {
dispatch(canvasGraphBuilt());
} else {
dispatch(generateGraphBuilt()); dispatch(generateGraphBuilt());
}, [dispatch]); }
}, [dispatch, activeTabName]);
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -27,6 +27,8 @@ import { t } from 'i18next';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
export type InfillMethod = 'tile' | 'patchmatch';
export interface SystemState { export interface SystemState {
isGFPGANAvailable: boolean; isGFPGANAvailable: boolean;
isESRGANAvailable: boolean; isESRGANAvailable: boolean;
@ -79,7 +81,14 @@ export interface SystemState {
consoleLogLevel: InvokeLogLevel; consoleLogLevel: InvokeLogLevel;
shouldLogToConsole: boolean; shouldLogToConsole: boolean;
statusTranslationKey: TFuncKey; statusTranslationKey: TFuncKey;
/**
* When a session is canceled, its ID is stored here until a new session is created.
*/
canceledSession: string; canceledSession: string;
/**
* TODO: get this from backend
*/
infillMethods: InfillMethod[];
} }
const initialSystemState: SystemState = { const initialSystemState: SystemState = {
@ -111,6 +120,7 @@ const initialSystemState: SystemState = {
shouldLogToConsole: true, shouldLogToConsole: true,
statusTranslationKey: 'common.statusDisconnected', statusTranslationKey: 'common.statusDisconnected',
canceledSession: '', canceledSession: '',
infillMethods: ['tile'],
}; };
export const systemSlice = createSlice({ export const systemSlice = createSlice({

View File

@ -1,6 +1,9 @@
import { createAppAsyncThunk } from 'app/store/storeUtils'; import { createAppAsyncThunk } from 'app/store/storeUtils';
import { SessionsService } from 'services/api'; import { SessionsService } from 'services/api';
import { buildLinearGraph as buildGenerateGraph } from 'features/nodes/util/linearGraphBuilder/buildLinearGraph'; import {
buildCanvasGraph,
buildLinearGraph as buildGenerateGraph,
} from 'features/nodes/util/linearGraphBuilder/buildLinearGraph';
import { isAnyOf, isFulfilled } from '@reduxjs/toolkit'; import { isAnyOf, isFulfilled } from '@reduxjs/toolkit';
import { buildNodesGraph } from 'features/nodes/util/nodesGraphBuilder/buildNodesGraph'; import { buildNodesGraph } from 'features/nodes/util/nodesGraphBuilder/buildNodesGraph';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
@ -42,9 +45,27 @@ export const nodesGraphBuilt = createAppAsyncThunk(
} }
); );
export const canvasGraphBuilt = createAppAsyncThunk(
'api/canvasGraphBuilt',
async (_, { dispatch, getState, rejectWithValue }) => {
try {
const graph = buildCanvasGraph(getState());
dispatch(sessionCreated({ graph }));
return graph;
} catch (err: any) {
sessionLog.error(
{ error: serializeError(err) },
'Problem building graph'
);
return rejectWithValue(err.message);
}
}
);
export const isFulfilledAnyGraphBuilt = isAnyOf( export const isFulfilledAnyGraphBuilt = isAnyOf(
generateGraphBuilt.fulfilled, generateGraphBuilt.fulfilled,
nodesGraphBuilt.fulfilled nodesGraphBuilt.fulfilled,
canvasGraphBuilt.fulfilled
); );
type SessionCreatedArg = { type SessionCreatedArg = {
@ -58,14 +79,22 @@ type SessionCreatedArg = {
*/ */
export const sessionCreated = createAppAsyncThunk( export const sessionCreated = createAppAsyncThunk(
'api/sessionCreated', 'api/sessionCreated',
async (arg: SessionCreatedArg, { dispatch, getState }) => { async (arg: SessionCreatedArg, { rejectWithValue }) => {
try {
const response = await SessionsService.createSession({ const response = await SessionsService.createSession({
requestBody: arg.graph, requestBody: arg.graph,
}); });
sessionLog.info({ arg, response }, `Session created (${response.id})`); sessionLog.info({ arg, response }, `Session created (${response.id})`);
return response; return response;
} catch (err: any) {
sessionLog.error(
{
error: serializeError(err),
},
'Problem creating session'
);
return rejectWithValue(err.message);
}
} }
); );