feat(ui): update staging handling to work w/ cropped mask

This commit is contained in:
psychedelicious 2024-07-16 21:31:57 +10:00
parent 0008617348
commit e2e02f31b6
12 changed files with 97 additions and 84 deletions

View File

@ -1,8 +1,7 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { import {
layerAdded, layerAddedFromStagingArea,
layerImageAdded,
sessionStagingAreaImageAccepted, sessionStagingAreaImageAccepted,
sessionStagingAreaReset, sessionStagingAreaReset,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
@ -49,33 +48,13 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
actionCreator: sessionStagingAreaImageAccepted, actionCreator: sessionStagingAreaImageAccepted,
effect: async (action, api) => { effect: async (action, api) => {
const { index } = action.payload; const { index } = action.payload;
const { layers, selectedEntityIdentifier } = api.getState().canvasV2; const state = api.getState();
let layer = layers.entities.find((layer) => layer.id === selectedEntityIdentifier?.id); const stagingAreaImage = state.canvasV2.session.stagedImages[index];
if (!layer) { assert(stagingAreaImage, 'No staged image found to accept');
layer = layers.entities[0]; const { x, y } = state.canvasV2.bbox.rect;
}
if (!layer) { api.dispatch(layerAddedFromStagingArea({ stagingAreaImage, pos: { x, y } }));
// We need to create a new layer to add the accepted image
api.dispatch(layerAdded());
layer = api.getState().canvasV2.layers.entities[0];
}
const stagedImage = api.getState().canvasV2.session.stagedImages[index];
assert(stagedImage, 'No staged image found to accept');
assert(layer, 'No layer found to stage image');
const { id } = layer;
api.dispatch(
layerImageAdded({
id,
imageDTO: stagedImage.imageDTO,
pos: { x: stagedImage.rect.x - layer.x, y: stagedImage.rect.y - layer.y },
})
);
api.dispatch(sessionStagingAreaReset()); api.dispatch(sessionStagingAreaReset());
}, },
}); });

View File

@ -59,12 +59,20 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
imageDTORequest.unsubscribe(); imageDTORequest.unsubscribe();
// handle tab-specific logic // handle tab-specific logic
if (data.origin === 'canvas' && data.result.type === 'canvas_v2_mask_and_crop_output') { if (data.origin === 'canvas' && data.invocation_source_id === 'canvas_output') {
const { x, y, width, height } = data.result; if (data.result.type === 'canvas_v2_mask_and_crop_output') {
if (canvasV2.session.isStaging) { const { offset_x, offset_y } = data.result;
dispatch(sessionImageStaged({ imageDTO, rect: { x, y, width, height } })); if (canvasV2.session.isStaging) {
} else if (!canvasV2.session.isActive) { dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } }));
$lastProgressEvent.set(null); } else if (!canvasV2.session.isActive) {
$lastProgressEvent.set(null);
}
} else if (data.result.type === 'image_output') {
if (canvasV2.session.isStaging) {
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
} else if (!canvasV2.session.isActive) {
$lastProgressEvent.set(null);
}
} }
} }

View File

@ -18,55 +18,47 @@ export class CanvasStagingArea {
async render() { async render() {
const session = this.manager.stateApi.getSession(); const session = this.manager.stateApi.getSession();
const bboxRect = this.manager.stateApi.getBbox().rect;
const shouldShowStagedImage = this.manager.stateApi.getShouldShowStagedImage(); const shouldShowStagedImage = this.manager.stateApi.getShouldShowStagedImage();
this.selectedImage = session.stagedImages[session.selectedStagedImageIndex] ?? null; this.selectedImage = session.stagedImages[session.selectedStagedImageIndex] ?? null;
if (this.selectedImage) { if (this.selectedImage) {
const { imageDTO, offsetX, offsetY } = this.selectedImage;
if (this.image) { if (this.image) {
if ( if (!this.image.isLoading && !this.image.isError && this.image.imageName !== imageDTO.image_name) {
!this.image.isLoading && this.image.konvaImageGroup.visible(false);
!this.image.isError && this.image.konvaImage?.width(imageDTO.width);
this.image.imageName !== this.selectedImage.imageDTO.image_name this.image.konvaImage?.height(imageDTO.height);
) { this.image.konvaImageGroup.x(bboxRect.x + offsetX);
await this.image.updateImageSource(this.selectedImage.imageDTO.image_name); this.image.konvaImageGroup.y(bboxRect.y + offsetY);
await this.image.updateImageSource(imageDTO.image_name);
} }
this.image.konvaImageGroup.x(this.selectedImage.rect.x);
this.image.konvaImageGroup.y(this.selectedImage.rect.y);
this.image.konvaImageGroup.visible(shouldShowStagedImage);
} else { } else {
const { image_name } = this.selectedImage.imageDTO; const { image_name, width, height } = imageDTO;
const { x, y, width, height } = this.selectedImage.rect; this.image = new CanvasImage({
this.image = new CanvasImage( id: 'staging-area-image',
{ type: 'image',
id: 'staging-area-image', x: 0,
type: 'image', y: 0,
x, width,
y, height,
filters: [],
image: {
name: image_name,
width, width,
height, height,
filters: [],
image: {
name: image_name,
width,
height,
},
}, },
{ });
onLoad: (konvaImage) => {
if (this.selectedImage) {
konvaImage.width(this.selectedImage.rect.width);
konvaImage.height(this.selectedImage.rect.height);
}
this.manager.stateApi.resetLastProgressEvent();
this.image?.konvaImageGroup.visible(shouldShowStagedImage);
},
}
);
this.group.add(this.image.konvaImageGroup); this.group.add(this.image.konvaImageGroup);
await this.image.updateImageSource(this.selectedImage.imageDTO.image_name); this.image.konvaImage?.width(imageDTO.width);
this.image.konvaImageGroup.visible(shouldShowStagedImage); this.image.konvaImage?.height(imageDTO.height);
this.image.konvaImageGroup.x(bboxRect.x + offsetX);
this.image.konvaImageGroup.y(bboxRect.y + offsetY);
await this.image.updateImageSource(imageDTO.image_name);
} }
this.manager.stateApi.resetLastProgressEvent();
this.image.konvaImageGroup.visible(shouldShowStagedImage);
} else { } else {
this.image?.konvaImageGroup.visible(false); this.image?.konvaImageGroup.visible(false);
} }

View File

@ -15,7 +15,10 @@ import { regionsReducers } from 'features/controlLayers/store/regionsReducers';
import { sessionReducers } from 'features/controlLayers/store/sessionReducers'; import { sessionReducers } from 'features/controlLayers/store/sessionReducers';
import { settingsReducers } from 'features/controlLayers/store/settingsReducers'; import { settingsReducers } from 'features/controlLayers/store/settingsReducers';
import { toolReducers } from 'features/controlLayers/store/toolReducers'; import { toolReducers } from 'features/controlLayers/store/toolReducers';
import { getScaledBoundingBoxDimensions } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { initialAspectRatioState } from 'features/parameters/components/DocumentSize/constants'; import { initialAspectRatioState } from 'features/parameters/components/DocumentSize/constants';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { pick } from 'lodash-es';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import type { InvocationDenoiseProgressEvent } from 'services/events/types'; import type { InvocationDenoiseProgressEvent } from 'services/events/types';
@ -158,6 +161,12 @@ export const canvasV2Slice = createSlice({
}, },
canvasReset: (state) => { canvasReset: (state) => {
state.bbox = deepClone(initialState.bbox); state.bbox = deepClone(initialState.bbox);
const optimalDimension = getOptimalDimension(state.params.model);
state.bbox.rect.width = optimalDimension;
state.bbox.rect.height = optimalDimension;
const size = pick(state.bbox.rect, 'width', 'height');
state.bbox.scaledSize = getScaledBoundingBoxDimensions(size, optimalDimension);
state.controlAdapters = deepClone(initialState.controlAdapters); state.controlAdapters = deepClone(initialState.controlAdapters);
state.ipAdapters = deepClone(initialState.ipAdapters); state.ipAdapters = deepClone(initialState.ipAdapters);
state.layers = deepClone(initialState.layers); state.layers = deepClone(initialState.layers);
@ -195,6 +204,7 @@ export const {
bboxSizeOptimized, bboxSizeOptimized,
// layers // layers
layerAdded, layerAdded,
layerAddedFromStagingArea,
layerRecalled, layerRecalled,
layerDeleted, layerDeleted,
layerReset, layerReset,

View File

@ -11,8 +11,10 @@ import type {
EraserLine, EraserLine,
ImageObjectAddedArg, ImageObjectAddedArg,
LayerEntity, LayerEntity,
Position,
RectShape, RectShape,
ScaleChangedArg, ScaleChangedArg,
StagingAreaImage,
} from './types'; } from './types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from './types'; import { imageDTOToImageObject, imageDTOToImageWithDims } from './types';
@ -43,6 +45,32 @@ export const layersReducers = {
}, },
prepare: () => ({ payload: { id: uuidv4() } }), prepare: () => ({ payload: { id: uuidv4() } }),
}, },
layerAddedFromStagingArea: {
reducer: (
state,
action: PayloadAction<{ id: string; objectId: string; stagingAreaImage: StagingAreaImage; pos: Position }>
) => {
const { id, objectId, stagingAreaImage, pos } = action.payload;
const { imageDTO, offsetX, offsetY } = stagingAreaImage;
const imageObject = imageDTOToImageObject(id, objectId, imageDTO);
state.layers.entities.push({
id,
type: 'layer',
isEnabled: true,
bbox: null,
bboxNeedsUpdate: false,
objects: [imageObject],
opacity: 1,
x: pos.x + offsetX,
y: pos.y + offsetY,
});
state.selectedEntityIdentifier = { type: 'layer', id };
state.layers.imageCache = null;
},
prepare: (payload: { stagingAreaImage: StagingAreaImage; pos: Position }) => ({
payload: { ...payload, id: uuidv4(), objectId: uuidv4() },
}),
},
layerRecalled: (state, action: PayloadAction<{ data: LayerEntity }>) => { layerRecalled: (state, action: PayloadAction<{ data: LayerEntity }>) => {
const { data } = action.payload; const { data } = action.payload;
state.layers.entities.push(data); state.layers.entities.push(data);

View File

@ -60,7 +60,6 @@ export const paramsReducers = {
} }
// Update the bbox size to match the new model's optimal size // Update the bbox size to match the new model's optimal size
// TODO(psyche): Should we change the document size too?
const optimalDimension = getOptimalDimension(model); const optimalDimension = getOptimalDimension(model);
if (!getIsSizeOptimal(state.bbox.rect.width, state.bbox.rect.height, optimalDimension)) { if (!getIsSizeOptimal(state.bbox.rect.width, state.bbox.rect.height, optimalDimension)) {
const bboxDims = calculateNewSize(state.bbox.aspectRatio.value, optimalDimension * optimalDimension); const bboxDims = calculateNewSize(state.bbox.aspectRatio.value, optimalDimension * optimalDimension);

View File

@ -14,9 +14,9 @@ export const sessionReducers = {
state.tool.selectedBuffer = state.tool.selected; state.tool.selectedBuffer = state.tool.selected;
state.tool.selected = 'view'; state.tool.selected = 'view';
}, },
sessionImageStaged: (state, action: PayloadAction<StagingAreaImage>) => { sessionImageStaged: (state, action: PayloadAction<{ stagingAreaImage: StagingAreaImage }>) => {
const { imageDTO, rect } = action.payload; const { stagingAreaImage } = action.payload;
state.session.stagedImages.push({ imageDTO, rect }); state.session.stagedImages.push(stagingAreaImage);
state.session.selectedStagedImageIndex = state.session.stagedImages.length - 1; state.session.selectedStagedImageIndex = state.session.stagedImages.length - 1;
}, },
sessionNextStagedImageSelected: (state) => { sessionNextStagedImageSelected: (state) => {

View File

@ -828,7 +828,8 @@ export type LoRA = {
export type StagingAreaImage = { export type StagingAreaImage = {
imageDTO: ImageDTO; imageDTO: ImageDTO;
rect: Rect; offsetX: number;
offsetY: number;
}; };
export type CanvasV2State = { export type CanvasV2State = {

View File

@ -66,8 +66,7 @@ export const addInpaint = async (
const canvasPasteBack = g.addNode({ const canvasPasteBack = g.addNode({
id: 'canvas_v2_mask_and_crop', id: 'canvas_v2_mask_and_crop',
type: 'canvas_v2_mask_and_crop', type: 'canvas_v2_mask_and_crop',
invert: true, mask_blur: compositing.maskBlur,
crop_visible: true,
}); });
// Resize initial image and mask to scaled size, feed into to gradient mask // Resize initial image and mask to scaled size, feed into to gradient mask
@ -113,8 +112,7 @@ export const addInpaint = async (
const canvasPasteBack = g.addNode({ const canvasPasteBack = g.addNode({
id: 'canvas_v2_mask_and_crop', id: 'canvas_v2_mask_and_crop',
type: 'canvas_v2_mask_and_crop', type: 'canvas_v2_mask_and_crop',
invert: true, mask_blur: compositing.maskBlur,
crop_visible: true,
}); });
g.addEdge(alphaToMask, 'image', createGradientMask, 'mask'); g.addEdge(alphaToMask, 'image', createGradientMask, 'mask');
g.addEdge(i2l, 'latents', denoise, 'latents'); g.addEdge(i2l, 'latents', denoise, 'latents');

View File

@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types';
*/ */
export const addNSFWChecker = ( export const addNSFWChecker = (
g: Graph, g: Graph,
imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'>
): Invocation<'img_nsfw'> => { ): Invocation<'img_nsfw'> => {
const nsfw = g.addNode({ const nsfw = g.addNode({
id: NSFW_CHECKER, id: NSFW_CHECKER,

View File

@ -101,8 +101,7 @@ export const addOutpaint = async (
const canvasPasteBack = g.addNode({ const canvasPasteBack = g.addNode({
id: 'canvas_v2_mask_and_crop', id: 'canvas_v2_mask_and_crop',
type: 'canvas_v2_mask_and_crop', type: 'canvas_v2_mask_and_crop',
invert: true, mask_blur: compositing.maskBlur,
crop_visible: true,
}); });
// Resize initial image and mask to scaled size, feed into to gradient mask // Resize initial image and mask to scaled size, feed into to gradient mask
@ -147,8 +146,7 @@ export const addOutpaint = async (
const canvasPasteBack = g.addNode({ const canvasPasteBack = g.addNode({
id: 'canvas_v2_mask_and_crop', id: 'canvas_v2_mask_and_crop',
type: 'canvas_v2_mask_and_crop', type: 'canvas_v2_mask_and_crop',
invert: true, mask_blur: compositing.maskBlur,
crop_visible: true,
}); });
g.addEdge(maskAlphaToMask, 'image', maskCombine, 'mask1'); g.addEdge(maskAlphaToMask, 'image', maskCombine, 'mask1');
g.addEdge(initialImageAlphaToMask, 'image', maskCombine, 'mask2'); g.addEdge(initialImageAlphaToMask, 'image', maskCombine, 'mask2');

View File

@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types';
*/ */
export const addWatermarker = ( export const addWatermarker = (
g: Graph, g: Graph,
imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'>
): Invocation<'img_watermark'> => { ): Invocation<'img_watermark'> => {
const watermark = g.addNode({ const watermark = g.addNode({
id: WATERMARKER, id: WATERMARKER,