feat(ui): make generation mode calculation more granular

This commit is contained in:
psychedelicious 2023-07-24 18:16:15 +10:00
parent 28031ead70
commit 61fa960a18
5 changed files with 103 additions and 53 deletions

View File

@ -39,8 +39,22 @@ export const addUserInvokedCanvasListener = () => {
const state = getState(); const state = getState();
const {
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea,
} = state.canvas;
// Build canvas blobs // Build canvas blobs
const canvasBlobsAndImageData = await getCanvasData(state.canvas); const canvasBlobsAndImageData = await getCanvasData(
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) { if (!canvasBlobsAndImageData) {
log.error('Unable to create canvas data'); log.error('Unable to create canvas data');

View File

@ -0,0 +1,72 @@
import { useAppSelector } from 'app/store/storeHooks';
import { GenerationMode } from 'features/canvas/store/canvasTypes';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { useEffect, useState } from 'react';
import { useDebounce } from 'react-use';
export const useCanvasGenerationMode = () => {
const layerState = useAppSelector((state) => state.canvas.layerState);
const boundingBoxCoordinates = useAppSelector(
(state) => state.canvas.boundingBoxCoordinates
);
const boundingBoxDimensions = useAppSelector(
(state) => state.canvas.boundingBoxDimensions
);
const isMaskEnabled = useAppSelector((state) => state.canvas.isMaskEnabled);
const shouldPreserveMaskedArea = useAppSelector(
(state) => state.canvas.shouldPreserveMaskedArea
);
const [generationMode, setGenerationMode] = useState<
GenerationMode | undefined
>();
useEffect(() => {
setGenerationMode(undefined);
}, [
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea,
]);
useDebounce(
async () => {
// Build canvas blobs
const canvasBlobsAndImageData = await getCanvasData(
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { baseImageData, maskImageData } = canvasBlobsAndImageData;
// Determine the generation mode
const generationMode = getCanvasGenerationMode(
baseImageData,
maskImageData
);
setGenerationMode(generationMode);
},
1000,
[
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea,
]
);
return generationMode;
};

View File

@ -30,7 +30,6 @@ import {
CanvasState, CanvasState,
CanvasTool, CanvasTool,
Dimensions, Dimensions,
GenerationMode,
isCanvasAnyLine, isCanvasAnyLine,
isCanvasBaseImage, isCanvasBaseImage,
isCanvasMaskLine, isCanvasMaskLine,
@ -859,9 +858,6 @@ export const canvasSlice = createSlice({
state.isMovingBoundingBox = false; state.isMovingBoundingBox = false;
state.isTransformingBoundingBox = false; state.isTransformingBoundingBox = false;
}, },
generationModeChanged: (state, action: PayloadAction<GenerationMode>) => {
state.generationMode = action.payload;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(sessionCanceled.pending, (state) => { builder.addCase(sessionCanceled.pending, (state) => {
@ -959,7 +955,6 @@ export const {
stagingAreaInitialized, stagingAreaInitialized,
canvasSessionIdChanged, canvasSessionIdChanged,
setShouldAntialias, setShouldAntialias,
generationModeChanged,
} = canvasSlice.actions; } = canvasSlice.actions;
export default canvasSlice.reducer; export default canvasSlice.reducer;

View File

@ -1,5 +1,10 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { CanvasState, isCanvasMaskLine } from '../store/canvasTypes'; import { Vector2d } from 'konva/lib/types';
import {
CanvasLayerState,
Dimensions,
isCanvasMaskLine,
} from '../store/canvasTypes';
import createMaskStage from './createMaskStage'; import createMaskStage from './createMaskStage';
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider'; import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
import { konvaNodeToBlob } from './konvaNodeToBlob'; import { konvaNodeToBlob } from './konvaNodeToBlob';
@ -8,7 +13,13 @@ import { konvaNodeToImageData } from './konvaNodeToImageData';
/** /**
* Gets Blob and ImageData objects for the base and mask layers * Gets Blob and ImageData objects for the base and mask layers
*/ */
export const getCanvasData = async (canvasState: CanvasState) => { export const getCanvasData = async (
layerState: CanvasLayerState,
boundingBoxCoordinates: Vector2d,
boundingBoxDimensions: Dimensions,
isMaskEnabled: boolean,
shouldPreserveMaskedArea: boolean
) => {
const log = logger('canvas'); const log = logger('canvas');
const canvasBaseLayer = getCanvasBaseLayer(); const canvasBaseLayer = getCanvasBaseLayer();
@ -19,14 +30,6 @@ export const getCanvasData = async (canvasState: CanvasState) => {
return; return;
} }
const {
layerState: { objects },
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea,
} = canvasState;
const boundingBox = { const boundingBox = {
...boundingBoxCoordinates, ...boundingBoxCoordinates,
...boundingBoxDimensions, ...boundingBoxDimensions,
@ -57,7 +60,7 @@ export const getCanvasData = async (canvasState: CanvasState) => {
// For the mask layer, use the normal boundingBox // For the mask layer, use the normal boundingBox
const maskStage = await createMaskStage( const maskStage = await createMaskStage(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled isMaskEnabled ? layerState.objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
boundingBox, boundingBox,
shouldPreserveMaskedArea shouldPreserveMaskedArea
); );

View File

@ -1,9 +1,5 @@
import { Box } from '@chakra-ui/react'; import { Box } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useCanvasGenerationMode } from 'features/canvas/hooks/useCanvasGenerationMode';
import { generationModeChanged } from 'features/canvas/store/canvasSlice';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { useDebounce } from 'react-use';
const GENERATION_MODE_NAME_MAP = { const GENERATION_MODE_NAME_MAP = {
txt2img: 'Text to Image', txt2img: 'Text to Image',
@ -12,38 +8,8 @@ const GENERATION_MODE_NAME_MAP = {
outpaint: 'Inpaint', outpaint: 'Inpaint',
}; };
export const useGenerationMode = () => {
const dispatch = useAppDispatch();
const canvasState = useAppSelector((state) => state.canvas);
useDebounce(
async () => {
// Build canvas blobs
const canvasBlobsAndImageData = await getCanvasData(canvasState);
if (!canvasBlobsAndImageData) {
return;
}
const { baseImageData, maskImageData } = canvasBlobsAndImageData;
// Determine the generation mode
const generationMode = getCanvasGenerationMode(
baseImageData,
maskImageData
);
dispatch(generationModeChanged(generationMode));
},
1000,
[dispatch, canvasState, generationModeChanged]
);
};
const GenerationModeStatusText = () => { const GenerationModeStatusText = () => {
const generationMode = useAppSelector((state) => state.canvas.generationMode); const generationMode = useCanvasGenerationMode();
useGenerationMode();
return ( return (
<Box> <Box>