feat(ui): get region and base layer canvas to blob logic working

This commit is contained in:
psychedelicious 2024-06-20 21:03:33 +10:00
parent 3b864921ac
commit 0c9cf73702
8 changed files with 204 additions and 76 deletions

View File

@ -216,7 +216,7 @@ export const updateBboxes = (
onBboxChanged({ id: entityState.id, bbox: getLayerBboxPixels(konvaLayer, filterLayerChildren) }, 'layer'); onBboxChanged({ id: entityState.id, bbox: getLayerBboxPixels(konvaLayer, filterLayerChildren) }, 'layer');
} }
} else if (entityState.type === 'control_adapter') { } else if (entityState.type === 'control_adapter') {
if (!entityState.image && !entityState.processedImage) { if (!entityState.imageObject && !entityState.processedImageObject) {
// No objects - no bbox to calculate // No objects - no bbox to calculate
onBboxChanged({ id: entityState.id, bbox: null }, 'control_adapter'); onBboxChanged({ id: entityState.id, bbox: null }, 'control_adapter');
} else { } else {

View File

@ -8,9 +8,9 @@ import {
} from 'features/controlLayers/konva/naming'; } from 'features/controlLayers/konva/naming';
import type { import type {
BrushLineObjectRecord, BrushLineObjectRecord,
KonvaEntityAdapter,
EraserLineObjectRecord, EraserLineObjectRecord,
ImageObjectRecord, ImageObjectRecord,
KonvaEntityAdapter,
RectShapeObjectRecord, RectShapeObjectRecord,
} from 'features/controlLayers/konva/nodeManager'; } from 'features/controlLayers/konva/nodeManager';
import type { import type {

View File

@ -53,8 +53,11 @@ import type {
import type Konva from 'konva'; import type Konva from 'konva';
import type { IRect, Vector2d } from 'konva/lib/types'; import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es'; import { debounce } from 'lodash-es';
import { atom } from 'nanostores';
import type { RgbaColor } from 'react-colorful'; import type { RgbaColor } from 'react-colorful';
export const $nodeManager = atom<KonvaNodeManager | null>(null);
/** /**
* Initializes the canvas renderer. It subscribes to the redux store and listens for changes directly, bypassing the * Initializes the canvas renderer. It subscribes to the redux store and listens for changes directly, bypassing the
* react rendering cycle entirely, improving canvas performance. * react rendering cycle entirely, improving canvas performance.
@ -249,6 +252,8 @@ export const initializeRenderer = (
}; };
const manager = new KonvaNodeManager(stage, getBbox, onBboxTransformed, $shift.get, $ctrl.get, $meta.get, $alt.get); const manager = new KonvaNodeManager(stage, getBbox, onBboxTransformed, $shift.get, $ctrl.get, $meta.get, $alt.get);
console.log(manager);
$nodeManager.set(manager);
const cleanupListeners = setStageEventHandlers({ const cleanupListeners = setStageEventHandlers({
manager, manager,
@ -344,7 +349,7 @@ export const initializeRenderer = (
canvasV2.controlAdapters !== prevCanvasV2.controlAdapters || canvasV2.controlAdapters !== prevCanvasV2.controlAdapters ||
canvasV2.regions !== prevCanvasV2.regions canvasV2.regions !== prevCanvasV2.regions
) { ) {
logIfDebugging('Updating entity bboxes'); // logIfDebugging('Updating entity bboxes');
// debouncedUpdateBboxes(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions, onBboxChanged); // debouncedUpdateBboxes(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions, onBboxChanged);
} }

View File

@ -16,9 +16,10 @@ import { toolReducers } from 'features/controlLayers/store/toolReducers';
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants'; import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import type { ImageDTO } from 'services/api/types';
import type { CanvasEntityIdentifier, CanvasV2State, StageAttrs } from './types'; import type { CanvasEntityIdentifier, CanvasV2State, StageAttrs } from './types';
import { DEFAULT_RGBA_COLOR } from './types'; import { DEFAULT_RGBA_COLOR, imageDTOToImageWithDims } from './types';
const initialState: CanvasV2State = { const initialState: CanvasV2State = {
_version: 3, _version: 3,
@ -119,6 +120,7 @@ const initialState: CanvasV2State = {
refinerNegativeAestheticScore: 2.5, refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8, refinerStart: 0.8,
}, },
baseLayerImageCache: null,
}; };
export const canvasV2Slice = createSlice({ export const canvasV2Slice = createSlice({
@ -164,6 +166,10 @@ export const canvasV2Slice = createSlice({
state.layers = []; state.layers = [];
state.ipAdapters = []; state.ipAdapters = [];
state.controlAdapters = []; state.controlAdapters = [];
state.baseLayerImageCache = null;
},
baseLayerImageCacheChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.baseLayerImageCache = action.payload ? imageDTOToImageWithDims(action.payload) : null;
}, },
}, },
}); });
@ -185,6 +191,7 @@ export const {
scaledBboxChanged, scaledBboxChanged,
bboxScaleMethodChanged, bboxScaleMethodChanged,
clipToBboxChanged, clipToBboxChanged,
baseLayerImageCacheChanged,
// layers // layers
layerAdded, layerAdded,
layerRecalled, layerRecalled,

View File

@ -39,6 +39,7 @@ export const layersReducers = {
y: 0, y: 0,
}); });
state.selectedEntityIdentifier = { type: 'layer', id }; state.selectedEntityIdentifier = { type: 'layer', id };
state.baseLayerImageCache = null;
}, },
prepare: () => ({ payload: { id: uuidv4() } }), prepare: () => ({ payload: { id: uuidv4() } }),
}, },
@ -46,6 +47,7 @@ export const layersReducers = {
const { data } = action.payload; const { data } = action.payload;
state.layers.push(data); state.layers.push(data);
state.selectedEntityIdentifier = { type: 'layer', id: data.id }; state.selectedEntityIdentifier = { type: 'layer', id: data.id };
state.baseLayerImageCache = null;
}, },
layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => { layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -54,6 +56,7 @@ export const layersReducers = {
return; return;
} }
layer.isEnabled = !layer.isEnabled; layer.isEnabled = !layer.isEnabled;
state.baseLayerImageCache = null;
}, },
layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => { layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => {
const { id, x, y } = action.payload; const { id, x, y } = action.payload;
@ -63,6 +66,7 @@ export const layersReducers = {
} }
layer.x = x; layer.x = x;
layer.y = y; layer.y = y;
state.baseLayerImageCache = null;
}, },
layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => { layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => {
const { id, bbox } = action.payload; const { id, bbox } = action.payload;
@ -88,13 +92,16 @@ export const layersReducers = {
layer.objects = []; layer.objects = [];
layer.bbox = null; layer.bbox = null;
layer.bboxNeedsUpdate = false; layer.bboxNeedsUpdate = false;
state.baseLayerImageCache = null;
}, },
layerDeleted: (state, action: PayloadAction<{ id: string }>) => { layerDeleted: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
state.layers = state.layers.filter((l) => l.id !== id); state.layers = state.layers.filter((l) => l.id !== id);
state.baseLayerImageCache = null;
}, },
layerAllDeleted: (state) => { layerAllDeleted: (state) => {
state.layers = []; state.layers = [];
state.baseLayerImageCache = null;
}, },
layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => { layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => {
const { id, opacity } = action.payload; const { id, opacity } = action.payload;
@ -103,6 +110,7 @@ export const layersReducers = {
return; return;
} }
layer.opacity = opacity; layer.opacity = opacity;
state.baseLayerImageCache = null;
}, },
layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => { layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -111,6 +119,7 @@ export const layersReducers = {
return; return;
} }
moveOneToEnd(state.layers, layer); moveOneToEnd(state.layers, layer);
state.baseLayerImageCache = null;
}, },
layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => { layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -119,6 +128,7 @@ export const layersReducers = {
return; return;
} }
moveToEnd(state.layers, layer); moveToEnd(state.layers, layer);
state.baseLayerImageCache = null;
}, },
layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => { layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -127,6 +137,7 @@ export const layersReducers = {
return; return;
} }
moveOneToStart(state.layers, layer); moveOneToStart(state.layers, layer);
state.baseLayerImageCache = null;
}, },
layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => { layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -135,6 +146,7 @@ export const layersReducers = {
return; return;
} }
moveToStart(state.layers, layer); moveToStart(state.layers, layer);
state.baseLayerImageCache = null;
}, },
layerBrushLineAdded: { layerBrushLineAdded: {
reducer: (state, action: PayloadAction<BrushLineAddedArg & { lineId: string }>) => { reducer: (state, action: PayloadAction<BrushLineAddedArg & { lineId: string }>) => {
@ -153,6 +165,7 @@ export const layersReducers = {
clip, clip,
}); });
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
}, },
prepare: (payload: BrushLineAddedArg) => ({ prepare: (payload: BrushLineAddedArg) => ({
payload: { ...payload, lineId: uuidv4() }, payload: { ...payload, lineId: uuidv4() },
@ -174,6 +187,7 @@ export const layersReducers = {
clip, clip,
}); });
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
}, },
prepare: (payload: EraserLineAddedArg) => ({ prepare: (payload: EraserLineAddedArg) => ({
payload: { ...payload, lineId: uuidv4() }, payload: { ...payload, lineId: uuidv4() },
@ -191,6 +205,7 @@ export const layersReducers = {
} }
lastObject.points.push(...point); lastObject.points.push(...point);
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
}, },
layerRectAdded: { layerRectAdded: {
reducer: (state, action: PayloadAction<RectShapeAddedArg & { rectId: string }>) => { reducer: (state, action: PayloadAction<RectShapeAddedArg & { rectId: string }>) => {
@ -210,6 +225,7 @@ export const layersReducers = {
color, color,
}); });
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
}, },
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }), prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
}, },
@ -222,6 +238,7 @@ export const layersReducers = {
} }
layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO)); layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO));
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
}, },
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }), prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }),
}, },

View File

@ -872,6 +872,7 @@ export type CanvasV2State = {
refinerNegativeAestheticScore: number; refinerNegativeAestheticScore: number;
refinerStart: number; refinerStart: number;
}; };
baseLayerImageCache: ImageWithDims | null;
}; };
export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number }; export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number };

View File

@ -0,0 +1,96 @@
import { getStore } from 'app/store/nanostores/store';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import { baseLayerImageCacheChanged } from 'features/controlLayers/store/canvasV2Slice';
import type { LayerEntity } from 'features/controlLayers/store/types';
import type Konva from 'konva';
import type { IRect } from 'konva/lib/types';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
const isValidLayer = (entity: LayerEntity) => {
return (
entity.isEnabled &&
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
entity.objects.length > 0
);
};
/**
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
* @param preview Whether to open a new tab displaying each layer.
* @returns A map of layer IDs to blobs.
*/
const getBaseLayer = async (layers: LayerEntity[], bbox: IRect, preview: boolean = false): Promise<Blob> => {
const manager = $nodeManager.get();
assert(manager, 'Node manager is null');
const stage = manager.stage.clone();
stage.scaleX(1);
stage.scaleY(1);
stage.x(0);
stage.y(0);
const validLayers = layers.filter(isValidLayer);
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
// to delete in a separate array and then destroy them.
// TODO(psyche): Maybe report this?
const toDelete: Konva.Layer[] = [];
for (const konvaLayer of stage.getLayers()) {
const layer = validLayers.find((l) => l.id === konvaLayer.id());
if (!layer) {
toDelete.push(konvaLayer);
}
}
for (const konvaLayer of toDelete) {
konvaLayer.destroy();
}
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
}
stage.destroy();
return blob;
};
export const getBaseLayerImage = async (): Promise<ImageDTO> => {
const { dispatch, getState } = getStore();
const state = getState();
if (state.canvasV2.baseLayerImageCache) {
const imageDTO = await getImageDTO(state.canvasV2.baseLayerImageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const blob = await getBaseLayer(state.canvasV2.layers, state.canvasV2.bbox, true);
const file = new File([blob], 'image.png', { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'general', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(baseLayerImageCacheChanged(imageDTO));
return imageDTO;
};

View File

@ -1,8 +1,8 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; import type { KonvaEntityAdapter } from 'features/controlLayers/konva/nodeManager';
import { renderRegions } from 'features/controlLayers/konva/renderers/regions'; import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { blobToDataURL } from 'features/controlLayers/konva/util'; import { blobToDataURL } from 'features/controlLayers/konva/util';
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice'; import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types'; import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types';
@ -15,9 +15,7 @@ import {
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters'; import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import Konva from 'konva';
import type { IRect } from 'konva/lib/types'; import type { IRect } from 'konva/lib/types';
import { size } from 'lodash-es';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -50,38 +48,34 @@ export const addRegions = async (
const isSDXL = base === 'sdxl'; const isSDXL = base === 'sdxl';
const validRegions = regions.filter((rg) => isValidRegion(rg, base)); const validRegions = regions.filter((rg) => isValidRegion(rg, base));
const blobs = await getRGMaskBlobs(validRegions, documentSize, bbox);
assert(size(blobs) === size(validRegions), 'Mismatch between layer IDs and blobs');
for (const rg of validRegions) { for (const region of validRegions) {
const blob = blobs[rg.id];
assert(blob, `Blob for layer ${rg.id} not found`);
// Upload the mask image, or get the cached image if it exists // Upload the mask image, or get the cached image if it exists
const { image_name } = await getMaskImage(rg, blob); const { image_name } = await getRegionMaskImage(region, bbox, true);
// The main mask-to-tensor node // The main mask-to-tensor node
const maskToTensor = g.addNode({ const maskToTensor = g.addNode({
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${rg.id}`, id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${region.id}`,
type: 'alpha_mask_to_tensor', type: 'alpha_mask_to_tensor',
image: { image: {
image_name, image_name,
}, },
}); });
if (rg.positivePrompt) { if (region.positivePrompt) {
// The main positive conditioning node // The main positive conditioning node
const regionalPosCond = g.addNode( const regionalPosCond = g.addNode(
isSDXL isSDXL
? { ? {
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${rg.id}`, id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`,
prompt: rg.positivePrompt, prompt: region.positivePrompt,
style: rg.positivePrompt, // TODO: Should we put the positive prompt in both fields? style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
} }
: { : {
type: 'compel', type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${rg.id}`, id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`,
prompt: rg.positivePrompt, prompt: region.positivePrompt,
} }
); );
// Connect the mask to the conditioning // Connect the mask to the conditioning
@ -106,20 +100,20 @@ export const addRegions = async (
} }
} }
if (rg.negativePrompt) { if (region.negativePrompt) {
// The main negative conditioning node // The main negative conditioning node
const regionalNegCond = g.addNode( const regionalNegCond = g.addNode(
isSDXL isSDXL
? { ? {
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${rg.id}`, id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`,
prompt: rg.negativePrompt, prompt: region.negativePrompt,
style: rg.negativePrompt, style: region.negativePrompt,
} }
: { : {
type: 'compel', type: 'compel',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${rg.id}`, id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`,
prompt: rg.negativePrompt, prompt: region.negativePrompt,
} }
); );
// Connect the mask to the conditioning // Connect the mask to the conditioning
@ -143,10 +137,10 @@ export const addRegions = async (
} }
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (rg.autoNegative === 'invert' && rg.positivePrompt) { if (region.autoNegative === 'invert' && region.positivePrompt) {
// We re-use the mask image, but invert it when converting to tensor // We re-use the mask image, but invert it when converting to tensor
const invertTensorMask = g.addNode({ const invertTensorMask = g.addNode({
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${rg.id}`, id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${region.id}`,
type: 'invert_tensor_mask', type: 'invert_tensor_mask',
}); });
// Connect the OG mask image to the inverted mask-to-tensor node // Connect the OG mask image to the inverted mask-to-tensor node
@ -156,14 +150,14 @@ export const addRegions = async (
isSDXL isSDXL
? { ? {
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${rg.id}`, id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`,
prompt: rg.positivePrompt, prompt: region.positivePrompt,
style: rg.positivePrompt, style: region.positivePrompt,
} }
: { : {
type: 'compel', type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${rg.id}`, id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`,
prompt: rg.positivePrompt, prompt: region.positivePrompt,
} }
); );
// Connect the inverted mask to the conditioning // Connect the inverted mask to the conditioning
@ -186,7 +180,7 @@ export const addRegions = async (
} }
} }
const validRGIPAdapters: IPAdapterEntity[] = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)); const validRGIPAdapters: IPAdapterEntity[] = region.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
for (const ipa of validRGIPAdapters) { for (const ipa of validRGIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
@ -245,6 +239,20 @@ export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageD
return imageDTO; return imageDTO;
}; };
export const uploadMaskImage = async ({ id }: RegionEntity, blob: Blob): Promise<ImageDTO> => {
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgMaskImageUploaded({ id, imageDTO }));
return imageDTO;
};
/** /**
* Get the blobs of all regional prompt layers. Only visible layers are returned. * Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used. * @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
@ -252,53 +260,47 @@ export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageD
* @returns A map of layer IDs to blobs. * @returns A map of layer IDs to blobs.
*/ */
export const getRGMaskBlobs = async ( export const getRegionMaskImage = async (
regions: RegionEntity[], region: RegionEntity,
documentSize: Dimensions,
bbox: IRect, bbox: IRect,
preview: boolean = false preview: boolean = false
): Promise<Record<string, Blob>> => { ): Promise<ImageDTO> => {
const container = document.createElement('div'); const manager = $nodeManager.get();
const stage = new Konva.Stage({ container, ...documentSize }); assert(manager, 'Node manager is null');
const manager = new KonvaNodeManager(stage);
renderRegions(manager, regions, 1, 'brush', null);
const adapters = manager.getAll();
const blobs: Record<string, Blob> = {};
// First remove all layers // TODO(psyche): Why do I need to annotate this? TS must have some kind of circular ref w/ this type but I can't figure it out...
for (const adapter of adapters) { const adapter: KonvaEntityAdapter | undefined = manager.get(region.id);
adapter.konvaLayer.remove(); assert(adapter, `Adapter for region ${region.id} not found`);
} if (region.imageCache) {
const imageDTO = await getImageDTO(region.imageCache.name);
// Next render each layer to a blob if (imageDTO) {
for (const adapter of adapters) { return imageDTO;
const region = regions.find((l) => l.id === adapter.id);
if (!region) {
continue;
} }
stage.add(adapter.konvaLayer); }
const blob = await new Promise<Blob>((resolve) => { const layer = adapter.konvaLayer.clone();
stage.toBlob({ const objectGroup = adapter.konvaObjectGroup.clone();
callback: (blob) => { layer.destroyChildren();
assert(blob, 'Blob is null'); layer.add(objectGroup);
resolve(blob); objectGroup.opacity(1);
}, objectGroup.cache();
...bbox,
}); const blob = await new Promise<Blob>((resolve) => {
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
}); });
});
if (preview) { if (preview) {
const base64 = await blobToDataURL(blob); const base64 = await blobToDataURL(blob);
openBase64ImageInTab([ const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
{ openBase64ImageInTab([{ base64, caption }]);
base64,
caption: `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`,
},
]);
}
adapter.konvaLayer.remove();
blobs[adapter.id] = blob;
} }
return blobs; layer.destroy();
return await uploadMaskImage(region, blob);
}; };