mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): make generation mode calculation more granular
This commit is contained in:
parent
28031ead70
commit
61fa960a18
@ -39,8 +39,22 @@ export const addUserInvokedCanvasListener = () => {
|
|||||||
|
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
|
const {
|
||||||
|
layerState,
|
||||||
|
boundingBoxCoordinates,
|
||||||
|
boundingBoxDimensions,
|
||||||
|
isMaskEnabled,
|
||||||
|
shouldPreserveMaskedArea,
|
||||||
|
} = state.canvas;
|
||||||
|
|
||||||
// Build canvas blobs
|
// Build canvas blobs
|
||||||
const canvasBlobsAndImageData = await getCanvasData(state.canvas);
|
const canvasBlobsAndImageData = await getCanvasData(
|
||||||
|
layerState,
|
||||||
|
boundingBoxCoordinates,
|
||||||
|
boundingBoxDimensions,
|
||||||
|
isMaskEnabled,
|
||||||
|
shouldPreserveMaskedArea
|
||||||
|
);
|
||||||
|
|
||||||
if (!canvasBlobsAndImageData) {
|
if (!canvasBlobsAndImageData) {
|
||||||
log.error('Unable to create canvas data');
|
log.error('Unable to create canvas data');
|
||||||
|
@ -0,0 +1,72 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { GenerationMode } from 'features/canvas/store/canvasTypes';
|
||||||
|
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||||
|
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
||||||
|
import { useEffect, useState } from 'react';
|
||||||
|
import { useDebounce } from 'react-use';
|
||||||
|
|
||||||
|
export const useCanvasGenerationMode = () => {
|
||||||
|
const layerState = useAppSelector((state) => state.canvas.layerState);
|
||||||
|
|
||||||
|
const boundingBoxCoordinates = useAppSelector(
|
||||||
|
(state) => state.canvas.boundingBoxCoordinates
|
||||||
|
);
|
||||||
|
const boundingBoxDimensions = useAppSelector(
|
||||||
|
(state) => state.canvas.boundingBoxDimensions
|
||||||
|
);
|
||||||
|
const isMaskEnabled = useAppSelector((state) => state.canvas.isMaskEnabled);
|
||||||
|
|
||||||
|
const shouldPreserveMaskedArea = useAppSelector(
|
||||||
|
(state) => state.canvas.shouldPreserveMaskedArea
|
||||||
|
);
|
||||||
|
const [generationMode, setGenerationMode] = useState<
|
||||||
|
GenerationMode | undefined
|
||||||
|
>();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setGenerationMode(undefined);
|
||||||
|
}, [
|
||||||
|
layerState,
|
||||||
|
boundingBoxCoordinates,
|
||||||
|
boundingBoxDimensions,
|
||||||
|
isMaskEnabled,
|
||||||
|
shouldPreserveMaskedArea,
|
||||||
|
]);
|
||||||
|
|
||||||
|
useDebounce(
|
||||||
|
async () => {
|
||||||
|
// Build canvas blobs
|
||||||
|
const canvasBlobsAndImageData = await getCanvasData(
|
||||||
|
layerState,
|
||||||
|
boundingBoxCoordinates,
|
||||||
|
boundingBoxDimensions,
|
||||||
|
isMaskEnabled,
|
||||||
|
shouldPreserveMaskedArea
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!canvasBlobsAndImageData) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { baseImageData, maskImageData } = canvasBlobsAndImageData;
|
||||||
|
|
||||||
|
// Determine the generation mode
|
||||||
|
const generationMode = getCanvasGenerationMode(
|
||||||
|
baseImageData,
|
||||||
|
maskImageData
|
||||||
|
);
|
||||||
|
|
||||||
|
setGenerationMode(generationMode);
|
||||||
|
},
|
||||||
|
1000,
|
||||||
|
[
|
||||||
|
layerState,
|
||||||
|
boundingBoxCoordinates,
|
||||||
|
boundingBoxDimensions,
|
||||||
|
isMaskEnabled,
|
||||||
|
shouldPreserveMaskedArea,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
return generationMode;
|
||||||
|
};
|
@ -30,7 +30,6 @@ import {
|
|||||||
CanvasState,
|
CanvasState,
|
||||||
CanvasTool,
|
CanvasTool,
|
||||||
Dimensions,
|
Dimensions,
|
||||||
GenerationMode,
|
|
||||||
isCanvasAnyLine,
|
isCanvasAnyLine,
|
||||||
isCanvasBaseImage,
|
isCanvasBaseImage,
|
||||||
isCanvasMaskLine,
|
isCanvasMaskLine,
|
||||||
@ -859,9 +858,6 @@ export const canvasSlice = createSlice({
|
|||||||
state.isMovingBoundingBox = false;
|
state.isMovingBoundingBox = false;
|
||||||
state.isTransformingBoundingBox = false;
|
state.isTransformingBoundingBox = false;
|
||||||
},
|
},
|
||||||
generationModeChanged: (state, action: PayloadAction<GenerationMode>) => {
|
|
||||||
state.generationMode = action.payload;
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(sessionCanceled.pending, (state) => {
|
builder.addCase(sessionCanceled.pending, (state) => {
|
||||||
@ -959,7 +955,6 @@ export const {
|
|||||||
stagingAreaInitialized,
|
stagingAreaInitialized,
|
||||||
canvasSessionIdChanged,
|
canvasSessionIdChanged,
|
||||||
setShouldAntialias,
|
setShouldAntialias,
|
||||||
generationModeChanged,
|
|
||||||
} = canvasSlice.actions;
|
} = canvasSlice.actions;
|
||||||
|
|
||||||
export default canvasSlice.reducer;
|
export default canvasSlice.reducer;
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { CanvasState, isCanvasMaskLine } from '../store/canvasTypes';
|
import { Vector2d } from 'konva/lib/types';
|
||||||
|
import {
|
||||||
|
CanvasLayerState,
|
||||||
|
Dimensions,
|
||||||
|
isCanvasMaskLine,
|
||||||
|
} from '../store/canvasTypes';
|
||||||
import createMaskStage from './createMaskStage';
|
import createMaskStage from './createMaskStage';
|
||||||
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
|
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
|
||||||
import { konvaNodeToBlob } from './konvaNodeToBlob';
|
import { konvaNodeToBlob } from './konvaNodeToBlob';
|
||||||
@ -8,7 +13,13 @@ import { konvaNodeToImageData } from './konvaNodeToImageData';
|
|||||||
/**
|
/**
|
||||||
* Gets Blob and ImageData objects for the base and mask layers
|
* Gets Blob and ImageData objects for the base and mask layers
|
||||||
*/
|
*/
|
||||||
export const getCanvasData = async (canvasState: CanvasState) => {
|
export const getCanvasData = async (
|
||||||
|
layerState: CanvasLayerState,
|
||||||
|
boundingBoxCoordinates: Vector2d,
|
||||||
|
boundingBoxDimensions: Dimensions,
|
||||||
|
isMaskEnabled: boolean,
|
||||||
|
shouldPreserveMaskedArea: boolean
|
||||||
|
) => {
|
||||||
const log = logger('canvas');
|
const log = logger('canvas');
|
||||||
|
|
||||||
const canvasBaseLayer = getCanvasBaseLayer();
|
const canvasBaseLayer = getCanvasBaseLayer();
|
||||||
@ -19,14 +30,6 @@ export const getCanvasData = async (canvasState: CanvasState) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const {
|
|
||||||
layerState: { objects },
|
|
||||||
boundingBoxCoordinates,
|
|
||||||
boundingBoxDimensions,
|
|
||||||
isMaskEnabled,
|
|
||||||
shouldPreserveMaskedArea,
|
|
||||||
} = canvasState;
|
|
||||||
|
|
||||||
const boundingBox = {
|
const boundingBox = {
|
||||||
...boundingBoxCoordinates,
|
...boundingBoxCoordinates,
|
||||||
...boundingBoxDimensions,
|
...boundingBoxDimensions,
|
||||||
@ -57,7 +60,7 @@ export const getCanvasData = async (canvasState: CanvasState) => {
|
|||||||
|
|
||||||
// For the mask layer, use the normal boundingBox
|
// For the mask layer, use the normal boundingBox
|
||||||
const maskStage = await createMaskStage(
|
const maskStage = await createMaskStage(
|
||||||
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
|
isMaskEnabled ? layerState.objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
|
||||||
boundingBox,
|
boundingBox,
|
||||||
shouldPreserveMaskedArea
|
shouldPreserveMaskedArea
|
||||||
);
|
);
|
||||||
|
@ -1,9 +1,5 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box } from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useCanvasGenerationMode } from 'features/canvas/hooks/useCanvasGenerationMode';
|
||||||
import { generationModeChanged } from 'features/canvas/store/canvasSlice';
|
|
||||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
|
||||||
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
|
||||||
import { useDebounce } from 'react-use';
|
|
||||||
|
|
||||||
const GENERATION_MODE_NAME_MAP = {
|
const GENERATION_MODE_NAME_MAP = {
|
||||||
txt2img: 'Text to Image',
|
txt2img: 'Text to Image',
|
||||||
@ -12,38 +8,8 @@ const GENERATION_MODE_NAME_MAP = {
|
|||||||
outpaint: 'Inpaint',
|
outpaint: 'Inpaint',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const useGenerationMode = () => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const canvasState = useAppSelector((state) => state.canvas);
|
|
||||||
|
|
||||||
useDebounce(
|
|
||||||
async () => {
|
|
||||||
// Build canvas blobs
|
|
||||||
const canvasBlobsAndImageData = await getCanvasData(canvasState);
|
|
||||||
|
|
||||||
if (!canvasBlobsAndImageData) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const { baseImageData, maskImageData } = canvasBlobsAndImageData;
|
|
||||||
|
|
||||||
// Determine the generation mode
|
|
||||||
const generationMode = getCanvasGenerationMode(
|
|
||||||
baseImageData,
|
|
||||||
maskImageData
|
|
||||||
);
|
|
||||||
|
|
||||||
dispatch(generationModeChanged(generationMode));
|
|
||||||
},
|
|
||||||
1000,
|
|
||||||
[dispatch, canvasState, generationModeChanged]
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const GenerationModeStatusText = () => {
|
const GenerationModeStatusText = () => {
|
||||||
const generationMode = useAppSelector((state) => state.canvas.generationMode);
|
const generationMode = useCanvasGenerationMode();
|
||||||
|
|
||||||
useGenerationMode();
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box>
|
<Box>
|
||||||
|
Loading…
Reference in New Issue
Block a user