feat(ui): wip canvas nodes migration

psychedelicious 2023-05-02 20:11:12 +10:00
parent ff5e2a9a8c
commit 08ec12b391
19 changed files with 652 additions and 241 deletions

View File

@ -1,209 +1,209 @@
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
// import * as InvokeAI from 'app/types/invokeai';
// import type { RootState } from 'app/store/store';
// import {
// frontendToBackendParameters,
// FrontendToBackendParametersConfig,
// } from 'common/util/parameterTranslation';
// import dateFormat from 'dateformat';
// import {
// GalleryCategory,
// GalleryState,
// removeImage,
// } from 'features/gallery/store/gallerySlice';
// import {
// generationRequested,
// modelChangeRequested,
// modelConvertRequested,
// modelMergingRequested,
// setIsProcessing,
// } from 'features/system/store/systemSlice';
// import { InvokeTabName } from 'features/ui/store/tabMap';
// import { Socket } from 'socket.io-client';
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import {
frontendToBackendParameters,
FrontendToBackendParametersConfig,
} from 'common/util/parameterTranslation';
import dateFormat from 'dateformat';
import {
GalleryCategory,
GalleryState,
removeImage,
} from 'features/gallery/store/gallerySlice';
import {
generationRequested,
modelChangeRequested,
modelConvertRequested,
modelMergingRequested,
setIsProcessing,
} from 'features/system/store/systemSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { Socket } from 'socket.io-client';
// /**
// * Returns an object containing all functions which use `socketio.emit()`.
// * i.e. those which make server requests.
// */
// const makeSocketIOEmitters = (
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
// socketio: Socket
// ) => {
// // We need to dispatch actions to redux and get pieces of state from the store.
// const { dispatch, getState } = store;
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
// return {
// emitGenerateImage: (generationMode: InvokeTabName) => {
// dispatch(setIsProcessing(true));
return {
emitGenerateImage: (generationMode: InvokeTabName) => {
dispatch(setIsProcessing(true));
// const state: RootState = getState();
const state: RootState = getState();
// const {
// generation: generationState,
// postprocessing: postprocessingState,
// system: systemState,
// canvas: canvasState,
// } = state;
const {
generation: generationState,
postprocessing: postprocessingState,
system: systemState,
canvas: canvasState,
} = state;
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
// {
// generationMode,
// generationState,
// postprocessingState,
// canvasState,
// systemState,
// };
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
{
generationMode,
generationState,
postprocessingState,
canvasState,
systemState,
};
// dispatch(generationRequested());
dispatch(generationRequested());
// const { generationParameters, esrganParameters, facetoolParameters } =
// frontendToBackendParameters(frontendToBackendParametersConfig);
const { generationParameters, esrganParameters, facetoolParameters } =
frontendToBackendParameters(frontendToBackendParametersConfig);
// socketio.emit(
// 'generateImage',
// generationParameters,
// esrganParameters,
// facetoolParameters
// );
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
facetoolParameters
);
// // we need to truncate the init_mask base64 else it takes up the whole log
// // TODO: handle maintaining masks for reproducibility in future
// if (generationParameters.init_mask) {
// generationParameters.init_mask = generationParameters.init_mask
// .substr(0, 64)
// .concat('...');
// }
// if (generationParameters.init_img) {
// generationParameters.init_img = generationParameters.init_img
// .substr(0, 64)
// .concat('...');
// }
// we need to truncate the init_mask base64 else it takes up the whole log
// TODO: handle maintaining masks for reproducibility in future
if (generationParameters.init_mask) {
generationParameters.init_mask = generationParameters.init_mask
.substr(0, 64)
.concat('...');
}
if (generationParameters.init_img) {
generationParameters.init_img = generationParameters.init_img
.substr(0, 64)
.concat('...');
}
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image generation requested: ${JSON.stringify({
// ...generationParameters,
// ...esrganParameters,
// ...facetoolParameters,
// })}`,
// })
// );
// },
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...facetoolParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// const {
// postprocessing: {
// upscalingLevel,
// upscalingDenoising,
// upscalingStrength,
// },
// } = getState();
const {
postprocessing: {
upscalingLevel,
upscalingDenoising,
upscalingStrength,
},
} = getState();
// const esrganParameters = {
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
// };
// socketio.emit('runPostprocessing', imageToProcess, {
// type: 'esrgan',
// ...esrganParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `ESRGAN upscale requested: ${JSON.stringify({
// file: imageToProcess.url,
// ...esrganParameters,
// })}`,
// })
// );
// },
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
const esrganParameters = {
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
};
socketio.emit('runPostprocessing', imageToProcess, {
type: 'esrgan',
...esrganParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// const {
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
// } = getState();
const {
postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
} = getState();
// const facetoolParameters: Record<string, unknown> = {
// facetool_strength: facetoolStrength,
// };
const facetoolParameters: Record<string, unknown> = {
facetool_strength: facetoolStrength,
};
// if (facetoolType === 'codeformer') {
// facetoolParameters.codeformer_fidelity = codeformerFidelity;
// }
if (facetoolType === 'codeformer') {
facetoolParameters.codeformer_fidelity = codeformerFidelity;
}
// socketio.emit('runPostprocessing', imageToProcess, {
// type: facetoolType,
// ...facetoolParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
// {
// file: imageToProcess.url,
// ...facetoolParameters,
// }
// )}`,
// })
// );
// },
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
// const { url, uuid, category, thumbnail } = imageToDelete;
// dispatch(removeImage(imageToDelete));
// socketio.emit('deleteImage', url, thumbnail, uuid, category);
// },
// emitRequestImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { earliest_mtime } = gallery.categories[category];
// socketio.emit('requestImages', category, earliest_mtime);
// },
// emitRequestNewImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { latest_mtime } = gallery.categories[category];
// socketio.emit('requestLatestImages', category, latest_mtime);
// },
// emitCancelProcessing: () => {
// socketio.emit('cancel');
// },
// emitRequestSystemConfig: () => {
// socketio.emit('requestSystemConfig');
// },
// emitSearchForModels: (modelFolder: string) => {
// socketio.emit('searchForModels', modelFolder);
// },
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
// socketio.emit('addNewModel', modelConfig);
// },
// emitDeleteModel: (modelName: string) => {
// socketio.emit('deleteModel', modelName);
// },
// emitConvertToDiffusers: (
// modelToConvert: InvokeAI.InvokeModelConversionProps
// ) => {
// dispatch(modelConvertRequested());
// socketio.emit('convertToDiffusers', modelToConvert);
// },
// emitMergeDiffusersModels: (
// modelMergeInfo: InvokeAI.InvokeModelMergingProps
// ) => {
// dispatch(modelMergingRequested());
// socketio.emit('mergeDiffusersModels', modelMergeInfo);
// },
// emitRequestModelChange: (modelName: string) => {
// dispatch(modelChangeRequested());
// socketio.emit('requestModelChange', modelName);
// },
// emitSaveStagingAreaImageToGallery: (url: string) => {
// socketio.emit('requestSaveStagingAreaImageToGallery', url);
// },
// emitRequestEmptyTempFolder: () => {
// socketio.emit('requestEmptyTempFolder');
// },
// };
// };
socketio.emit('runPostprocessing', imageToProcess, {
type: facetoolType,
...facetoolParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
{
file: imageToProcess.url,
...facetoolParameters,
}
)}`,
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);
},
emitRequestImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { earliest_mtime } = gallery.categories[category];
socketio.emit('requestImages', category, earliest_mtime);
},
emitRequestNewImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { latest_mtime } = gallery.categories[category];
socketio.emit('requestLatestImages', category, latest_mtime);
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitRequestSystemConfig: () => {
socketio.emit('requestSystemConfig');
},
emitSearchForModels: (modelFolder: string) => {
socketio.emit('searchForModels', modelFolder);
},
emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
socketio.emit('addNewModel', modelConfig);
},
emitDeleteModel: (modelName: string) => {
socketio.emit('deleteModel', modelName);
},
emitConvertToDiffusers: (
modelToConvert: InvokeAI.InvokeModelConversionProps
) => {
dispatch(modelConvertRequested());
socketio.emit('convertToDiffusers', modelToConvert);
},
emitMergeDiffusersModels: (
modelMergeInfo: InvokeAI.InvokeModelMergingProps
) => {
dispatch(modelMergingRequested());
socketio.emit('mergeDiffusersModels', modelMergeInfo);
},
emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested());
socketio.emit('requestModelChange', modelName);
},
emitSaveStagingAreaImageToGallery: (url: string) => {
socketio.emit('requestSaveStagingAreaImageToGallery', url);
},
emitRequestEmptyTempFolder: () => {
socketio.emit('requestEmptyTempFolder');
},
};
};
// export default makeSocketIOEmitters;
export default makeSocketIOEmitters;
export default {};

View File

@ -0,0 +1,59 @@
export const getIsImageDataPartiallyTransparent = (imageData: ImageData) => {
let hasTransparency = false;
let isFullyTransparent = true;
const len = imageData.data.length;
let i = 3;
for (i; i < len; i += 4) {
if (imageData.data[i] !== 0) {
isFullyTransparent = false;
} else {
hasTransparency = true;
}
}
return { hasTransparency, isFullyTransparent };
};
export const getImageDataTransparency = (imageData: ImageData) => {
let isFullyTransparent = true;
let isPartiallyTransparent = false;
const len = imageData.data.length;
let i = 3;
for (i; i < len; i += 4) {
if (imageData.data[i] === 255) {
isFullyTransparent = false;
} else {
isPartiallyTransparent = true;
}
if (!isFullyTransparent && isPartiallyTransparent) {
return { isFullyTransparent, isPartiallyTransparent };
}
}
return { isFullyTransparent, isPartiallyTransparent };
};
export const areAnyPixelsBlack = (imageData: ImageData) => {
const len = imageData.data.length;
let i = 0;
for (i; i < len; ) {
if (
imageData.data[i++] === 255 &&
imageData.data[i++] === 255 &&
imageData.data[i++] === 255 &&
imageData.data[i++] === 255
) {
return true;
}
}
return false;
};
export const getIsImageDataWhite = (imageData: ImageData) => {
const len = imageData.data.length;
let i = 0;
for (i; i < len; ) {
if (imageData.data[i++] !== 255) {
return false;
}
}
return true;
};
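
Note (not part of the diff): a minimal usage sketch of the helpers above, assuming they are exported from common/util/arrayBuffer as imported elsewhere in this commit; classifyRegion and its canvas argument are hypothetical.

import {
  areAnyPixelsBlack,
  getImageDataTransparency,
} from 'common/util/arrayBuffer';

// Hypothetical helper: classify a rectangular region of a canvas by scanning
// its ImageData with the utilities above.
const classifyRegion = (
  canvas: HTMLCanvasElement,
  x: number,
  y: number,
  width: number,
  height: number
) => {
  const ctx = canvas.getContext('2d');
  if (!ctx) {
    throw new Error('Unable to get 2d context');
  }
  const imageData = ctx.getImageData(x, y, width, height);
  // Alpha-channel scan: every 4th byte, starting at index 3.
  const { isFullyTransparent, isPartiallyTransparent } =
    getImageDataTransparency(imageData);
  // Scan used later in this commit on the mask layer's ImageData to decide
  // whether the user has painted anything.
  const doesHaveMaskPixels = areAnyPixelsBlack(imageData);
  return { isFullyTransparent, isPartiallyTransparent, doesHaveMaskPixels };
};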

View File

@ -19,6 +19,7 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
import openBase64ImageInTab from './openBase64ImageInTab';
import randomInt from './randomInt';
import { stringToSeedWeightsArray } from './seedWeightPairs';
import { getIsImageDataTransparent, getIsImageDataWhite } from './arrayBuffer';
export type FrontendToBackendParametersConfig = {
generationMode: InvokeTabName;
@ -256,7 +257,7 @@ export const frontendToBackendParameters = (
...boundingBoxDimensions,
};
const maskDataURL = generateMask(
const { dataURL: maskDataURL, imageData: maskImageData } = generateMask(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
boundingBox
);
@ -287,6 +288,19 @@ export const frontendToBackendParameters = (
height: boundingBox.height,
});
const ctx = canvasBaseLayer.getContext();
const imageData = ctx.getImageData(
boundingBox.x + absPos.x,
boundingBox.y + absPos.y,
boundingBox.width,
boundingBox.height
);
const doesBaseHaveTransparency = getIsImageDataTransparent(imageData);
const doesMaskHaveTransparency = getIsImageDataWhite(maskImageData);
console.log(doesBaseHaveTransparency, doesMaskHaveTransparency);
if (enableImageDebugging) {
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask sent as init_mask' },

View File

@ -0,0 +1,39 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import {
FrontendToBackendParametersConfig,
frontendToBackendParameters,
} from 'common/util/parameterTranslation';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { canvasSelector } from '../store/canvasSelectors';
import { useCallback, useMemo } from 'react';
const selector = createSelector(
[generationSelector, postprocessingSelector, systemSelector, canvasSelector],
(generation, postprocessing, system, canvas) => {
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
{
generationMode: 'unifiedCanvas',
generationState: generation,
postprocessingState: postprocessing,
canvasState: canvas,
systemState: system,
};
return frontendToBackendParametersConfig;
}
);
export const usePrepareCanvasState = () => {
const frontendToBackendParametersConfig = useAppSelector(selector);
const getGenerationParameters = useCallback(() => {
const { generationParameters, esrganParameters, facetoolParameters } =
frontendToBackendParameters(frontendToBackendParametersConfig);
console.log(generationParameters);
}, [frontendToBackendParametersConfig]);
return getGenerationParameters;
};
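
Note (not part of the diff): a sketch of how a component might consume this hook; the component name is hypothetical, and for now the hook only logs the translated generation parameters.

import { usePrepareCanvasState } from 'features/canvas/hooks/usePrepareCanvasState';

// Hypothetical debug component: logs the unified-canvas generation parameters
// produced by frontendToBackendParameters when clicked.
const CanvasParametersDebugButton = () => {
  const getGenerationParameters = usePrepareCanvasState();

  return (
    <button onClick={getGenerationParameters}>Log canvas parameters</button>
  );
};

export default CanvasParametersDebugButton;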

View File

@ -156,22 +156,20 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload;
},
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
const image = action.payload;
const { width, height } = image.metadata;
const { stageDimensions } = state;
const newBoundingBoxDimensions = {
width: roundDownToMultiple(clamp(image.width, 64, 512), 64),
height: roundDownToMultiple(clamp(image.height, 64, 512), 64),
width: roundDownToMultiple(clamp(width, 64, 512), 64),
height: roundDownToMultiple(clamp(height, 64, 512), 64),
};
const newBoundingBoxCoordinates = {
x: roundToMultiple(
image.width / 2 - newBoundingBoxDimensions.width / 2,
64
),
x: roundToMultiple(width / 2 - newBoundingBoxDimensions.width / 2, 64),
y: roundToMultiple(
image.height / 2 - newBoundingBoxDimensions.height / 2,
height / 2 - newBoundingBoxDimensions.height / 2,
64
),
};
@ -196,8 +194,8 @@ export const canvasSlice = createSlice({
layer: 'base',
x: 0,
y: 0,
width: image.width,
height: image.height,
width: width,
height: height,
image: image,
},
],
@ -208,8 +206,8 @@ export const canvasSlice = createSlice({
const newScale = calculateScale(
stageDimensions.width,
stageDimensions.height,
image.width,
image.height,
width,
height,
STAGE_PADDING_PERCENTAGE
);
@ -218,8 +216,8 @@ export const canvasSlice = createSlice({
stageDimensions.height,
0,
0,
image.width,
image.height,
width,
height,
newScale
);
state.stageScale = newScale;
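
Note (not part of the diff): setInitialCanvasImage now reads dimensions from image.metadata rather than from the image object itself. A sketch of the assumed payload shape; the field values are illustrative and the exact InvokeAI.Image type is defined elsewhere in the codebase.

import { Dispatch } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';

// Hypothetical helper: set an image as the initial canvas image.
const sendImageToCanvas = (dispatch: Dispatch, image: InvokeAI.Image) => {
  dispatch(setInitialCanvasImage(image));
};

// Illustrative payload: the reducer above reads width/height from metadata.
const exampleImage = {
  name: 'example.png', // hypothetical
  type: 'results', // hypothetical
  url: 'outputs/example.png', // hypothetical
  metadata: { width: 512, height: 768 },
} as unknown as InvokeAI.Image;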

View File

@ -12,7 +12,10 @@ import { IRect } from 'konva/lib/types';
* drawing the mask and compositing everything correctly to output a valid
* mask image.
*/
const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
const generateMask = (
lines: CanvasMaskLine[],
boundingBox: IRect
): { dataURL: string; imageData: ImageData } => {
// create an offscreen canvas and add the mask to it
const { width, height } = boundingBox;
@ -55,10 +58,19 @@ const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
stage.add(maskLayer);
const dataURL = stage.toDataURL({ ...boundingBox });
const imageData = stage
.toCanvas()
.getContext('2d')
?.getImageData(
boundingBox.x,
boundingBox.y,
boundingBox.width,
boundingBox.height
);
offscreenContainer.remove();
return dataURL;
return { dataURL, imageData };
};
export default generateMask;
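
Note (not part of the diff): generateMask now returns the mask both as a data URL (sent to the backend) and as raw ImageData (scanned locally). A consumer sketch, assuming CanvasMaskLine is exported from the canvas store types and the module paths match the relative imports used in getCanvasDataURLs; because the stage's 2d context lookup above is optional-chained, the sketch guards the ImageData defensively.

import { IRect } from 'konva/lib/types';
import { areAnyPixelsBlack } from 'common/util/arrayBuffer';
import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
import generateMask from 'features/canvas/util/generateMask';

// Hypothetical consumer: returns the data URL to attach to the request and a
// flag indicating whether the user actually painted anything on the mask layer.
const prepareMask = (maskLines: CanvasMaskLine[], boundingBox: IRect) => {
  const { dataURL, imageData } = generateMask(maskLines, boundingBox);

  if (!imageData) {
    // Guard: the 2d context lookup in generateMask can come back undefined.
    throw new Error('Unable to read mask pixels');
  }

  const doesMaskHaveBlackPixels = areAnyPixelsBlack(imageData);

  return { dataURL, doesMaskHaveBlackPixels };
};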

View File

@ -0,0 +1,123 @@
import { RootState } from 'app/store/store';
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
import { isCanvasMaskLine } from '../store/canvasTypes';
import generateMask from './generateMask';
import { log } from 'app/logging/useLogger';
import {
areAnyPixelsBlack,
getImageDataTransparency,
getIsImageDataWhite,
} from 'common/util/arrayBuffer';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
export const getCanvasDataURLs = (state: RootState) => {
const canvasBaseLayer = getCanvasBaseLayer();
const canvasStage = getCanvasStage();
if (!canvasBaseLayer || !canvasStage) {
log.error(
{ namespace: 'getCanvasDataURLs' },
'Unable to find canvas / stage'
);
return;
}
const {
layerState: { objects },
boundingBoxCoordinates,
boundingBoxDimensions,
stageScale,
isMaskEnabled,
shouldPreserveMaskedArea,
boundingBoxScaleMethod: boundingBoxScale,
scaledBoundingBoxDimensions,
} = state.canvas;
const boundingBox = {
...boundingBoxCoordinates,
...boundingBoxDimensions,
};
// generationParameters.fit = false;
// generationParameters.strength = img2imgStrength;
// generationParameters.invert_mask = shouldPreserveMaskedArea;
// generationParameters.bounding_box = boundingBox;
const tempScale = canvasBaseLayer.scale();
canvasBaseLayer.scale({
x: 1 / stageScale,
y: 1 / stageScale,
});
const absPos = canvasBaseLayer.getAbsolutePosition();
const { dataURL: maskDataURL, imageData: maskImageData } = generateMask(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
{
x: boundingBox.x + absPos.x,
y: boundingBox.y + absPos.y,
width: boundingBox.width,
height: boundingBox.height,
}
);
const baseDataURL = canvasBaseLayer.toDataURL({
x: boundingBox.x + absPos.x,
y: boundingBox.y + absPos.y,
width: boundingBox.width,
height: boundingBox.height,
});
const ctx = canvasBaseLayer.getContext();
const baseImageData = ctx.getImageData(
boundingBox.x + absPos.x,
boundingBox.y + absPos.y,
boundingBox.width,
boundingBox.height
);
const {
isPartiallyTransparent: baseIsPartiallyTransparent,
isFullyTransparent: baseIsFullyTransparent,
} = getImageDataTransparency(baseImageData);
const doesMaskHaveBlackPixels = areAnyPixelsBlack(maskImageData);
if (state.system.enableImageDebugging) {
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask sent as init_mask' },
{ base64: baseDataURL, caption: 'image sent as init_img' },
]);
}
canvasBaseLayer.scale(tempScale);
// generationParameters.init_img = imageDataURL;
// generationParameters.progress_images = false;
// if (boundingBoxScale !== 'none') {
// generationParameters.inpaint_width = scaledBoundingBoxDimensions.width;
// generationParameters.inpaint_height = scaledBoundingBoxDimensions.height;
// }
// generationParameters.seam_size = seamSize;
// generationParameters.seam_blur = seamBlur;
// generationParameters.seam_strength = seamStrength;
// generationParameters.seam_steps = seamSteps;
// generationParameters.tile_size = tileSize;
// generationParameters.infill_method = infillMethod;
// generationParameters.force_outpaint = false;
return {
baseDataURL,
maskDataURL,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
};
};
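
Note (not part of the diff): a sketch of how the flags returned here feed node selection via getNodeType (added later in this commit); the getNodeType module path is inferred from the relative import in buildLinearGraph.

import { RootState } from 'app/store/store';
import { getCanvasDataURLs } from 'features/canvas/util/getCanvasDataURLs';
import { getNodeType } from 'features/nodes/util/getNodeType';

// Hypothetical helper: decide which generation node the current canvas calls for.
const resolveCanvasNodeType = (state: RootState) => {
  const canvasData = getCanvasDataURLs(state);

  if (!canvasData) {
    // getCanvasDataURLs returns undefined when the Konva stage/layer is missing.
    return undefined;
  }

  // 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'
  return getNodeType(
    canvasData.baseIsPartiallyTransparent,
    canvasData.baseIsFullyTransparent,
    canvasData.doesMaskHaveBlackPixels
  );
};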

View File

@ -17,7 +17,10 @@ import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal';
import { ContextMenu } from 'chakra-ui-contextmenu';
import * as InvokeAI from 'app/types/invokeai';
import { resizeAndScaleCanvas } from 'features/canvas/store/canvasSlice';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
@ -159,7 +162,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
* TODO: the rest of these
*/
const handleSendToCanvas = () => {
// dispatch(setInitialCanvasImage(image));
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
@ -315,6 +318,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
sx={{
width: '50%',
height: '50%',
maxWidth: '4rem',
maxHeight: '4rem',
fill: 'ok.500',
}}
/>

View File

@ -4,12 +4,8 @@ import { GalleryState } from './gallerySlice';
* Gallery slice persist denylist
*/
const itemsToDenylist: (keyof GalleryState)[] = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
];
export const galleryDenylist = itemsToDenylist.map(

View File

@ -5,7 +5,7 @@ import { ResultsState } from './resultsSlice';
*
* Currently denylisting results slice entirely, see persist config in store.ts
*/
const itemsToDenylist: (keyof ResultsState)[] = ['isLoading'];
const itemsToDenylist: (keyof ResultsState)[] = [];
export const resultsDenylist = itemsToDenylist.map(
(denylistItem) => `results.${denylistItem}`

View File

@ -5,7 +5,7 @@ import { UploadsState } from './uploadsSlice';
*
* Currently denylisting uploads slice entirely, see persist config in store.ts
*/
const itemsToDenylist: (keyof UploadsState)[] = ['isLoading'];
const itemsToDenylist: (keyof UploadsState)[] = [];
export const uploadsDenylist = itemsToDenylist.map(
(denylistItem) => `uploads.${denylistItem}`

View File

@ -0,0 +1,19 @@
export const getNodeType = (
baseIsPartiallyTransparent: boolean,
baseIsFullyTransparent: boolean,
doesMaskHaveBlackPixels: boolean
): 'txt2img' | 'img2img' | 'inpaint' | 'outpaint' => {
if (baseIsPartiallyTransparent) {
if (baseIsFullyTransparent) {
return 'txt2img';
}
return 'outpaint';
} else {
if (doesMaskHaveBlackPixels) {
return 'inpaint';
}
return 'img2img';
}
};
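
Note (not part of the diff): the branches above encode the following mapping, spelled out as assertions; the module path is inferred from the relative import in buildLinearGraph.

import { getNodeType } from 'features/nodes/util/getNodeType';

// Fully transparent base layer: nothing to condition on, so plain txt2img.
console.assert(getNodeType(true, true, false) === 'txt2img');
// Partially transparent base layer: the transparent regions need outpainting.
console.assert(getNodeType(true, false, false) === 'outpaint');
// Opaque base layer with a painted mask: inpaint the masked region.
console.assert(getNodeType(false, false, true) === 'inpaint');
// Opaque base layer and no mask: straight img2img.
console.assert(getNodeType(false, false, false) === 'img2img');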

View File

@ -5,10 +5,13 @@ import {
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api';
import { _Image } from 'app/types/invokeai';
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
import { O } from 'ts-toolbelt';
export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => {
export const buildImg2ImgNode = (
state: RootState,
overrides: O.Partial<ImageToImageInvocation, 'deep'> = {}
): ImageToImageInvocation => {
const nodeId = uuidv4();
const { generation, system, models } = state;
@ -33,7 +36,7 @@ export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => {
if (!initialImage) {
// TODO: handle this
throw 'no initial image';
// throw 'no initial image';
}
const imageToImageNode: ImageToImageInvocation = {
@ -48,10 +51,12 @@ export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => {
seamless,
model: selectedModelName,
progress_images: true,
image: {
image: initialImage
? {
image_name: initialImage.name,
image_type: initialImage.type,
},
}
: undefined,
strength,
fit,
};
@ -60,6 +65,8 @@ export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => {
imageToImageNode.seed = seed;
}
Object.assign(imageToImageNode, overrides);
return imageToImageNode;
};
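
Note (not part of the diff): the new overrides parameter lets callers pin fields after the node has been assembled from state, which is how buildCanvasGraph below forces the canvas bounding box dimensions; note that Object.assign applies the overrides shallowly even though the type allows deep partials. The module path below is inferred from the relative import in buildLinearGraph.

import { RootState } from 'app/store/store';
import { buildImg2ImgNode } from 'features/nodes/util/linearGraphBuilder/buildImageToImageNode';

// Hypothetical call: force the invocation's dimensions to match the canvas
// bounding box, overriding whatever was derived from generation state.
const buildCanvasSizedImg2ImgNode = (state: RootState) =>
  buildImg2ImgNode(state, {
    width: state.canvas.boundingBoxDimensions.width,
    height: state.canvas.boundingBoxDimensions.height,
  });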

View File

@ -1,10 +1,15 @@
import { RootState } from 'app/store/store';
import { Graph } from 'services/api';
import { DataURLToImageInvocation, Graph } from 'services/api';
import { buildImg2ImgNode } from './buildImageToImageNode';
import { buildTxt2ImgNode } from './buildTextToImageNode';
import { buildRangeNode } from './buildRangeNode';
import { buildIterateNode } from './buildIterateNode';
import { buildEdges } from './buildEdges';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { getCanvasDataURLs } from 'features/canvas/util/getCanvasDataURLs';
import { log } from 'console';
import { getNodeType } from '../getNodeType';
import { v4 as uuidv4 } from 'uuid';
/**
* Builds the Linear workflow graph.
@ -37,3 +42,86 @@ export const buildLinearGraph = (state: RootState): Graph => {
return graph;
};
/**
* Builds the Canvas workflow graph.
*/
export const buildCanvasGraph = (state: RootState): Graph => {
const c = getCanvasDataURLs(state);
if (!c) {
throw 'problem creating canvas graph';
}
const {
baseDataURL,
maskDataURL,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
} = c;
console.log({
baseDataURL,
maskDataURL,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
});
const nodeType = getNodeType(
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels
);
console.log(nodeType);
// The base node is either a txt2img or img2img node
const baseNode =
nodeType === 'img2img'
? buildImg2ImgNode(state, state.canvas.boundingBoxDimensions)
: buildTxt2ImgNode(state, state.canvas.boundingBoxDimensions);
const dataURLNode: DataURLToImageInvocation = {
id: uuidv4(),
type: 'dataURL_image',
dataURL: baseDataURL,
};
// We always add range and iterate nodes, no matter the iteration count
// This is required to provide the correct seeds to the backend engine
const rangeNode = buildRangeNode(state);
const iterateNode = buildIterateNode();
// Build the edges for the nodes selected.
const edges = buildEdges(baseNode, rangeNode, iterateNode);
if (baseNode.type === 'img2img') {
edges.push({
source: {
node_id: dataURLNode.id,
field: 'image',
},
destination: {
node_id: baseNode.id,
field: 'image',
},
});
}
// Assemble!
const graph = {
nodes: {
[dataURLNode.id]: dataURLNode,
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
},
edges,
};
// TODO: hires fix requires latent space upscaling; we don't have nodes for this yet
return graph;
};
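
Note (not part of the diff): a rough sketch of the graph shape assembled above for the img2img case; ids are uuids at runtime, and the range/iterate node internals are elided since they come from buildRangeNode and buildIterateNode.

// Illustrative shape only; field values are placeholders.
const exampleCanvasGraph = {
  nodes: {
    'dataurl-node': {
      id: 'dataurl-node',
      type: 'dataURL_image',
      dataURL: 'data:image/png;base64,<snipped>', // baseDataURL from getCanvasDataURLs
    },
    'range-node': { id: 'range-node' /* seed range, from buildRangeNode */ },
    'iterate-node': { id: 'iterate-node' /* from buildIterateNode */ },
    'img2img-node': { id: 'img2img-node', type: 'img2img' /* prompt, steps, ... */ },
  },
  edges: [
    // ...seed edges from buildEdges (range -> iterate -> img2img)...
    {
      source: { node_id: 'dataurl-node', field: 'image' },
      destination: { node_id: 'img2img-node', field: 'image' },
    },
  ],
};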

View File

@ -1,8 +1,12 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { TextToImageInvocation } from 'services/api';
import { O } from 'ts-toolbelt';
export const buildTxt2ImgNode = (state: RootState): TextToImageInvocation => {
export const buildTxt2ImgNode = (
state: RootState,
overrides: O.Partial<TextToImageInvocation, 'deep'> = {}
): TextToImageInvocation => {
const nodeId = uuidv4();
const { generation, models } = state;
@ -39,5 +43,7 @@ export const buildTxt2ImgNode = (state: RootState): TextToImageInvocation => {
textToImageNode.seed = seed;
}
Object.assign(textToImageNode, overrides);
return textToImageNode;
};

View File

@ -28,7 +28,7 @@ const selector = createSelector(
(parameters, system, canvas) => {
const { tileSize, infillMethod } = parameters;
const { infill_methods: availableInfillMethods } = system;
const { infillMethods } = system;
const {
boundingBoxScaleMethod: boundingBoxScale,
@ -40,7 +40,7 @@ const selector = createSelector(
scaledBoundingBoxDimensions,
tileSize,
infillMethod,
availableInfillMethods,
infillMethods,
isManual: boundingBoxScale === 'manual',
};
},
@ -56,7 +56,7 @@ const InfillAndScalingSettings = () => {
const {
tileSize,
infillMethod,
availableInfillMethods,
infillMethods,
boundingBoxScale,
isManual,
scaledBoundingBoxDimensions,
@ -147,7 +147,7 @@ const InfillAndScalingSettings = () => {
<IAISelect
label={t('parameters.infillMethod')}
value={infillMethod}
validValues={availableInfillMethods}
validValues={infillMethods}
onChange={(e) => dispatch(setInfillMethod(e.target.value))}
/>
<IAISlider

View File

@ -5,13 +5,14 @@ import IAIButton, { IAIButtonProps } from 'common/components/IAIButton';
import IAIIconButton, {
IAIIconButtonProps,
} from 'common/components/IAIIconButton';
import { usePrepareCanvasState } from 'features/canvas/hooks/usePrepareCanvasState';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa';
import { generateGraphBuilt } from 'services/thunks/session';
import { canvasGraphBuilt, generateGraphBuilt } from 'services/thunks/session';
interface InvokeButton
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
@ -23,11 +24,16 @@ export default function InvokeButton(props: InvokeButton) {
const dispatch = useAppDispatch();
const { isReady } = useAppSelector(readinessSelector);
const activeTabName = useAppSelector(activeTabNameSelector);
// const getGenerationParameters = usePrepareCanvasState();
const handleInvoke = useCallback(() => {
dispatch(clampSymmetrySteps());
if (activeTabName === 'unifiedCanvas') {
dispatch(canvasGraphBuilt());
} else {
dispatch(generateGraphBuilt());
}, [dispatch]);
}
}, [dispatch, activeTabName]);
const { t } = useTranslation();

View File

@ -27,6 +27,8 @@ import { t } from 'i18next';
export type CancelStrategy = 'immediate' | 'scheduled';
export type InfillMethod = 'tile' | 'patchmatch';
export interface SystemState {
isGFPGANAvailable: boolean;
isESRGANAvailable: boolean;
@ -79,7 +81,14 @@ export interface SystemState {
consoleLogLevel: InvokeLogLevel;
shouldLogToConsole: boolean;
statusTranslationKey: TFuncKey;
/**
* When a session is canceled, its ID is stored here until a new session is created.
*/
canceledSession: string;
/**
* TODO: get this from backend
*/
infillMethods: InfillMethod[];
}
const initialSystemState: SystemState = {
@ -111,6 +120,7 @@ const initialSystemState: SystemState = {
shouldLogToConsole: true,
statusTranslationKey: 'common.statusDisconnected',
canceledSession: '',
infillMethods: ['tile'],
};
export const systemSlice = createSlice({

View File

@ -1,6 +1,9 @@
import { createAppAsyncThunk } from 'app/store/storeUtils';
import { SessionsService } from 'services/api';
import { buildLinearGraph as buildGenerateGraph } from 'features/nodes/util/linearGraphBuilder/buildLinearGraph';
import {
buildCanvasGraph,
buildLinearGraph as buildGenerateGraph,
} from 'features/nodes/util/linearGraphBuilder/buildLinearGraph';
import { isAnyOf, isFulfilled } from '@reduxjs/toolkit';
import { buildNodesGraph } from 'features/nodes/util/nodesGraphBuilder/buildNodesGraph';
import { log } from 'app/logging/useLogger';
@ -42,9 +45,27 @@ export const nodesGraphBuilt = createAppAsyncThunk(
}
);
export const canvasGraphBuilt = createAppAsyncThunk(
'api/canvasGraphBuilt',
async (_, { dispatch, getState, rejectWithValue }) => {
try {
const graph = buildCanvasGraph(getState());
dispatch(sessionCreated({ graph }));
return graph;
} catch (err: any) {
sessionLog.error(
{ error: serializeError(err) },
'Problem building graph'
);
return rejectWithValue(err.message);
}
}
);
export const isFulfilledAnyGraphBuilt = isAnyOf(
generateGraphBuilt.fulfilled,
nodesGraphBuilt.fulfilled
nodesGraphBuilt.fulfilled,
canvasGraphBuilt.fulfilled
);
type SessionCreatedArg = {
@ -58,14 +79,22 @@ type SessionCreatedArg = {
*/
export const sessionCreated = createAppAsyncThunk(
'api/sessionCreated',
async (arg: SessionCreatedArg, { dispatch, getState }) => {
async (arg: SessionCreatedArg, { rejectWithValue }) => {
try {
const response = await SessionsService.createSession({
requestBody: arg.graph,
});
sessionLog.info({ arg, response }, `Session created (${response.id})`);
return response;
} catch (err: any) {
sessionLog.error(
{
error: serializeError(err),
},
'Problem creating session'
);
return rejectWithValue(err.message);
}
}
);
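
Note (not part of the diff): a sketch of dispatching the new canvasGraphBuilt thunk outside a component and branching on the result; the dispatch typing is elided here, and InvokeButton above dispatches it the same way from the UI.

import { canvasGraphBuilt } from 'services/thunks/session';

// Hypothetical handler; `dispatch` is the app store's dispatch.
const invokeCanvas = async (dispatch: (action: any) => any) => {
  const action = await dispatch(canvasGraphBuilt());

  if (canvasGraphBuilt.fulfilled.match(action)) {
    // The graph was built and sessionCreated was dispatched inside the thunk.
    return action.payload;
  }

  // Rejected: the payload carries the message passed to rejectWithValue above.
  return undefined;
};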