feat(ui): get region and base layer canvas to blob logic working

This commit is contained in:
psychedelicious 2024-06-20 21:03:33 +10:00
parent 3b864921ac
commit 0c9cf73702
8 changed files with 204 additions and 76 deletions

View File

@ -216,7 +216,7 @@ export const updateBboxes = (
onBboxChanged({ id: entityState.id, bbox: getLayerBboxPixels(konvaLayer, filterLayerChildren) }, 'layer');
}
} else if (entityState.type === 'control_adapter') {
if (!entityState.image && !entityState.processedImage) {
if (!entityState.imageObject && !entityState.processedImageObject) {
// No objects - no bbox to calculate
onBboxChanged({ id: entityState.id, bbox: null }, 'control_adapter');
} else {

View File

@ -8,9 +8,9 @@ import {
} from 'features/controlLayers/konva/naming';
import type {
BrushLineObjectRecord,
KonvaEntityAdapter,
EraserLineObjectRecord,
ImageObjectRecord,
KonvaEntityAdapter,
RectShapeObjectRecord,
} from 'features/controlLayers/konva/nodeManager';
import type {

View File

@ -53,8 +53,11 @@ import type {
import type Konva from 'konva';
import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es';
import { atom } from 'nanostores';
import type { RgbaColor } from 'react-colorful';
export const $nodeManager = atom<KonvaNodeManager | null>(null);
/**
* Initializes the canvas renderer. It subscribes to the redux store and listens for changes directly, bypassing the
* react rendering cycle entirely, improving canvas performance.
@ -249,6 +252,8 @@ export const initializeRenderer = (
};
const manager = new KonvaNodeManager(stage, getBbox, onBboxTransformed, $shift.get, $ctrl.get, $meta.get, $alt.get);
console.log(manager);
$nodeManager.set(manager);
const cleanupListeners = setStageEventHandlers({
manager,
@ -344,7 +349,7 @@ export const initializeRenderer = (
canvasV2.controlAdapters !== prevCanvasV2.controlAdapters ||
canvasV2.regions !== prevCanvasV2.regions
) {
logIfDebugging('Updating entity bboxes');
// logIfDebugging('Updating entity bboxes');
// debouncedUpdateBboxes(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions, onBboxChanged);
}

View File

@ -16,9 +16,10 @@ import { toolReducers } from 'features/controlLayers/store/toolReducers';
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import { atom } from 'nanostores';
import type { ImageDTO } from 'services/api/types';
import type { CanvasEntityIdentifier, CanvasV2State, StageAttrs } from './types';
import { DEFAULT_RGBA_COLOR } from './types';
import { DEFAULT_RGBA_COLOR, imageDTOToImageWithDims } from './types';
const initialState: CanvasV2State = {
_version: 3,
@ -119,6 +120,7 @@ const initialState: CanvasV2State = {
refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8,
},
baseLayerImageCache: null,
};
export const canvasV2Slice = createSlice({
@ -164,6 +166,10 @@ export const canvasV2Slice = createSlice({
state.layers = [];
state.ipAdapters = [];
state.controlAdapters = [];
state.baseLayerImageCache = null;
},
baseLayerImageCacheChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.baseLayerImageCache = action.payload ? imageDTOToImageWithDims(action.payload) : null;
},
},
});
@ -185,6 +191,7 @@ export const {
scaledBboxChanged,
bboxScaleMethodChanged,
clipToBboxChanged,
baseLayerImageCacheChanged,
// layers
layerAdded,
layerRecalled,

View File

@ -39,6 +39,7 @@ export const layersReducers = {
y: 0,
});
state.selectedEntityIdentifier = { type: 'layer', id };
state.baseLayerImageCache = null;
},
prepare: () => ({ payload: { id: uuidv4() } }),
},
@ -46,6 +47,7 @@ export const layersReducers = {
const { data } = action.payload;
state.layers.push(data);
state.selectedEntityIdentifier = { type: 'layer', id: data.id };
state.baseLayerImageCache = null;
},
layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -54,6 +56,7 @@ export const layersReducers = {
return;
}
layer.isEnabled = !layer.isEnabled;
state.baseLayerImageCache = null;
},
layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => {
const { id, x, y } = action.payload;
@ -63,6 +66,7 @@ export const layersReducers = {
}
layer.x = x;
layer.y = y;
state.baseLayerImageCache = null;
},
layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => {
const { id, bbox } = action.payload;
@ -88,13 +92,16 @@ export const layersReducers = {
layer.objects = [];
layer.bbox = null;
layer.bboxNeedsUpdate = false;
state.baseLayerImageCache = null;
},
layerDeleted: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.layers = state.layers.filter((l) => l.id !== id);
state.baseLayerImageCache = null;
},
layerAllDeleted: (state) => {
state.layers = [];
state.baseLayerImageCache = null;
},
layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => {
const { id, opacity } = action.payload;
@ -103,6 +110,7 @@ export const layersReducers = {
return;
}
layer.opacity = opacity;
state.baseLayerImageCache = null;
},
layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -111,6 +119,7 @@ export const layersReducers = {
return;
}
moveOneToEnd(state.layers, layer);
state.baseLayerImageCache = null;
},
layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -119,6 +128,7 @@ export const layersReducers = {
return;
}
moveToEnd(state.layers, layer);
state.baseLayerImageCache = null;
},
layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -127,6 +137,7 @@ export const layersReducers = {
return;
}
moveOneToStart(state.layers, layer);
state.baseLayerImageCache = null;
},
layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -135,6 +146,7 @@ export const layersReducers = {
return;
}
moveToStart(state.layers, layer);
state.baseLayerImageCache = null;
},
layerBrushLineAdded: {
reducer: (state, action: PayloadAction<BrushLineAddedArg & { lineId: string }>) => {
@ -153,6 +165,7 @@ export const layersReducers = {
clip,
});
layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
},
prepare: (payload: BrushLineAddedArg) => ({
payload: { ...payload, lineId: uuidv4() },
@ -174,6 +187,7 @@ export const layersReducers = {
clip,
});
layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
},
prepare: (payload: EraserLineAddedArg) => ({
payload: { ...payload, lineId: uuidv4() },
@ -191,6 +205,7 @@ export const layersReducers = {
}
lastObject.points.push(...point);
layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
},
layerRectAdded: {
reducer: (state, action: PayloadAction<RectShapeAddedArg & { rectId: string }>) => {
@ -210,6 +225,7 @@ export const layersReducers = {
color,
});
layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
},
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
},
@ -222,6 +238,7 @@ export const layersReducers = {
}
layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO));
layer.bboxNeedsUpdate = true;
state.baseLayerImageCache = null;
},
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }),
},

View File

@ -872,6 +872,7 @@ export type CanvasV2State = {
refinerNegativeAestheticScore: number;
refinerStart: number;
};
baseLayerImageCache: ImageWithDims | null;
};
export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number };

View File

@ -0,0 +1,96 @@
import { getStore } from 'app/store/nanostores/store';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import { baseLayerImageCacheChanged } from 'features/controlLayers/store/canvasV2Slice';
import type { LayerEntity } from 'features/controlLayers/store/types';
import type Konva from 'konva';
import type { IRect } from 'konva/lib/types';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
const isValidLayer = (entity: LayerEntity) => {
return (
entity.isEnabled &&
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
entity.objects.length > 0
);
};
/**
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
* @param preview Whether to open a new tab displaying each layer.
* @returns A map of layer IDs to blobs.
*/
const getBaseLayer = async (layers: LayerEntity[], bbox: IRect, preview: boolean = false): Promise<Blob> => {
const manager = $nodeManager.get();
assert(manager, 'Node manager is null');
const stage = manager.stage.clone();
stage.scaleX(1);
stage.scaleY(1);
stage.x(0);
stage.y(0);
const validLayers = layers.filter(isValidLayer);
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
// to delete in a separate array and then destroy them.
// TODO(psyche): Maybe report this?
const toDelete: Konva.Layer[] = [];
for (const konvaLayer of stage.getLayers()) {
const layer = validLayers.find((l) => l.id === konvaLayer.id());
if (!layer) {
toDelete.push(konvaLayer);
}
}
for (const konvaLayer of toDelete) {
konvaLayer.destroy();
}
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
}
stage.destroy();
return blob;
};
export const getBaseLayerImage = async (): Promise<ImageDTO> => {
const { dispatch, getState } = getStore();
const state = getState();
if (state.canvasV2.baseLayerImageCache) {
const imageDTO = await getImageDTO(state.canvasV2.baseLayerImageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const blob = await getBaseLayer(state.canvasV2.layers, state.canvasV2.bbox, true);
const file = new File([blob], 'image.png', { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'general', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(baseLayerImageCacheChanged(imageDTO));
return imageDTO;
};

View File

@ -1,8 +1,8 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import { renderRegions } from 'features/controlLayers/konva/renderers/regions';
import type { KonvaEntityAdapter } from 'features/controlLayers/konva/nodeManager';
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types';
@ -15,9 +15,7 @@ import {
} from 'features/nodes/util/graph/constants';
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import Konva from 'konva';
import type { IRect } from 'konva/lib/types';
import { size } from 'lodash-es';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
@ -50,38 +48,34 @@ export const addRegions = async (
const isSDXL = base === 'sdxl';
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
const blobs = await getRGMaskBlobs(validRegions, documentSize, bbox);
assert(size(blobs) === size(validRegions), 'Mismatch between layer IDs and blobs');
for (const rg of validRegions) {
const blob = blobs[rg.id];
assert(blob, `Blob for layer ${rg.id} not found`);
for (const region of validRegions) {
// Upload the mask image, or get the cached image if it exists
const { image_name } = await getMaskImage(rg, blob);
const { image_name } = await getRegionMaskImage(region, bbox, true);
// The main mask-to-tensor node
const maskToTensor = g.addNode({
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${rg.id}`,
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${region.id}`,
type: 'alpha_mask_to_tensor',
image: {
image_name,
},
});
if (rg.positivePrompt) {
if (region.positivePrompt) {
// The main positive conditioning node
const regionalPosCond = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${rg.id}`,
prompt: rg.positivePrompt,
style: rg.positivePrompt, // TODO: Should we put the positive prompt in both fields?
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`,
prompt: region.positivePrompt,
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${rg.id}`,
prompt: rg.positivePrompt,
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`,
prompt: region.positivePrompt,
}
);
// Connect the mask to the conditioning
@ -106,20 +100,20 @@ export const addRegions = async (
}
}
if (rg.negativePrompt) {
if (region.negativePrompt) {
// The main negative conditioning node
const regionalNegCond = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${rg.id}`,
prompt: rg.negativePrompt,
style: rg.negativePrompt,
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`,
prompt: region.negativePrompt,
style: region.negativePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${rg.id}`,
prompt: rg.negativePrompt,
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`,
prompt: region.negativePrompt,
}
);
// Connect the mask to the conditioning
@ -143,10 +137,10 @@ export const addRegions = async (
}
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (rg.autoNegative === 'invert' && rg.positivePrompt) {
if (region.autoNegative === 'invert' && region.positivePrompt) {
// We re-use the mask image, but invert it when converting to tensor
const invertTensorMask = g.addNode({
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${rg.id}`,
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${region.id}`,
type: 'invert_tensor_mask',
});
// Connect the OG mask image to the inverted mask-to-tensor node
@ -156,14 +150,14 @@ export const addRegions = async (
isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${rg.id}`,
prompt: rg.positivePrompt,
style: rg.positivePrompt,
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`,
prompt: region.positivePrompt,
style: region.positivePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${rg.id}`,
prompt: rg.positivePrompt,
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`,
prompt: region.positivePrompt,
}
);
// Connect the inverted mask to the conditioning
@ -186,7 +180,7 @@ export const addRegions = async (
}
}
const validRGIPAdapters: IPAdapterEntity[] = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
const validRGIPAdapters: IPAdapterEntity[] = region.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
for (const ipa of validRGIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
@ -245,6 +239,20 @@ export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageD
return imageDTO;
};
export const uploadMaskImage = async ({ id }: RegionEntity, blob: Blob): Promise<ImageDTO> => {
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgMaskImageUploaded({ id, imageDTO }));
return imageDTO;
};
/**
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
@ -252,33 +260,32 @@ export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageD
* @returns A map of layer IDs to blobs.
*/
export const getRGMaskBlobs = async (
regions: RegionEntity[],
documentSize: Dimensions,
export const getRegionMaskImage = async (
region: RegionEntity,
bbox: IRect,
preview: boolean = false
): Promise<Record<string, Blob>> => {
const container = document.createElement('div');
const stage = new Konva.Stage({ container, ...documentSize });
const manager = new KonvaNodeManager(stage);
renderRegions(manager, regions, 1, 'brush', null);
const adapters = manager.getAll();
const blobs: Record<string, Blob> = {};
): Promise<ImageDTO> => {
const manager = $nodeManager.get();
assert(manager, 'Node manager is null');
// First remove all layers
for (const adapter of adapters) {
adapter.konvaLayer.remove();
// TODO(psyche): Why do I need to annotate this? TS must have some kind of circular ref w/ this type but I can't figure it out...
const adapter: KonvaEntityAdapter | undefined = manager.get(region.id);
assert(adapter, `Adapter for region ${region.id} not found`);
if (region.imageCache) {
const imageDTO = await getImageDTO(region.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layer = adapter.konvaLayer.clone();
const objectGroup = adapter.konvaObjectGroup.clone();
layer.destroyChildren();
layer.add(objectGroup);
objectGroup.opacity(1);
objectGroup.cache();
// Next render each layer to a blob
for (const adapter of adapters) {
const region = regions.find((l) => l.id === adapter.id);
if (!region) {
continue;
}
stage.add(adapter.konvaLayer);
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
@ -289,16 +296,11 @@ export const getRGMaskBlobs = async (
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([
{
base64,
caption: `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`,
},
]);
}
adapter.konvaLayer.remove();
blobs[adapter.id] = blob;
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
openBase64ImageInTab([{ base64, caption }]);
}
return blobs;
layer.destroy();
return await uploadMaskImage(region, blob);
};