feat(ui): layer bbox calc in worker

This commit is contained in:
psychedelicious 2024-07-18 19:21:40 +10:00
parent e70339ff3e
commit 778ee2c679
11 changed files with 324 additions and 14 deletions

View File

@ -29,7 +29,8 @@ export type LoggerNamespace =
| 'dnd'
| 'controlLayers'
| 'metadata'
| 'konva';
| 'konva'
| 'worker';
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });

View File

@ -1,4 +1,5 @@
/* eslint-disable i18next/no-literal-string */
import { Button } from '@chakra-ui/react';
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { BrushWidth } from 'features/controlLayers/components/BrushWidth';
@ -9,12 +10,19 @@ import { NewSessionButton } from 'features/controlLayers/components/NewSessionBu
import { ResetCanvasButton } from 'features/controlLayers/components/ResetCanvasButton';
import { ToolChooser } from 'features/controlLayers/components/ToolChooser';
import { UndoRedoButtonGroup } from 'features/controlLayers/components/UndoRedoButtonGroup';
import { getCanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { ViewerToggleMenu } from 'features/gallery/components/ImageViewer/ViewerToggleMenu';
import { memo } from 'react';
import { memo, useCallback } from 'react';
export const ControlLayersToolbar = memo(() => {
const tool = useAppSelector((s) => s.canvasV2.tool.selected);
const bbox = useCallback(() => {
const manager = getCanvasManager();
for (const l of manager.layers.values()) {
l.getBbox();
}
}, []);
return (
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
@ -27,6 +35,7 @@ export const ControlLayersToolbar = memo(() => {
{tool === 'brush' && <BrushWidth />}
{tool === 'eraser' && <EraserWidth />}
</Flex>
<Button onClick={bbox}>bbox</Button>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto" alignItems="center">
<FillColorPicker />

View File

@ -242,7 +242,7 @@ export class CanvasInpaintMask {
// When the layer is selected and being moved, we should always cache it.
// We should update the cache if we drew to the layer.
if (!this.konva.group.isCached() || didDraw) {
this.konva.group.cache();
// this.konva.group.cache();
}
// Activate the transformer
this.konva.layer.listening(true);
@ -266,7 +266,7 @@ export class CanvasInpaintMask {
// We are using a non-drawing tool (move, view, bbox), so we should cache the layer.
// We should update the cache if we drew to the layer.
if (!this.konva.group.isCached() || didDraw) {
this.konva.group.cache();
// this.konva.group.cache();
}
}
return;
@ -279,7 +279,7 @@ export class CanvasInpaintMask {
this.konva.transformer.nodes([]);
// Update the layer's cache if it's not already cached or we drew to it.
if (!this.konva.group.isCached() || didDraw) {
this.konva.group.cache();
// this.konva.group.cache();
}
return;

View File

@ -4,9 +4,10 @@ import { CanvasImage } from 'features/controlLayers/konva/CanvasImage';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasRect } from 'features/controlLayers/konva/CanvasRect';
import { mapId } from 'features/controlLayers/konva/util';
import type { BrushLine, EraserLine, LayerEntity, RectShape } from 'features/controlLayers/store/types';
import type { BrushLine, EraserLine, LayerEntity, Rect, RectShape } from 'features/controlLayers/store/types';
import { isDrawingTool } from 'features/controlLayers/store/types';
import Konva from 'konva';
import { debounce } from 'lodash-es';
import { assert } from 'tsafe';
export class CanvasLayer {
@ -24,18 +25,26 @@ export class CanvasLayer {
konva: {
layer: Konva.Layer;
bbox: Konva.Rect;
group: Konva.Group;
objectGroup: Konva.Group;
transformer: Konva.Transformer;
};
objects: Map<string, CanvasBrushLine | CanvasEraserLine | CanvasRect | CanvasImage>;
bbox: Rect | null;
getBbox = debounce(this._getBbox, 300);
constructor(state: LayerEntity, manager: CanvasManager) {
this.id = state.id;
this.manager = manager;
this.konva = {
layer: new Konva.Layer({ name: CanvasLayer.LAYER_NAME, listening: false }),
group: new Konva.Group({ name: CanvasLayer.GROUP_NAME, listening: false }),
group: new Konva.Group({ name: CanvasLayer.GROUP_NAME, listening: true }),
bbox: new Konva.Rect({
listening: true,
stroke: 'hsl(200deg 76% 59%)', // invokeBlue.400
}),
objectGroup: new Konva.Group({ name: CanvasLayer.OBJECT_GROUP_NAME, listening: false }),
transformer: new Konva.Transformer({
name: CanvasLayer.TRANSFORMER_NAME,
@ -49,6 +58,7 @@ export class CanvasLayer {
};
this.konva.group.add(this.konva.objectGroup);
this.konva.group.add(this.konva.bbox);
this.konva.layer.add(this.konva.group);
this.konva.transformer.on('transformend', () => {
@ -72,6 +82,7 @@ export class CanvasLayer {
this.objects = new Map();
this.drawingBuffer = null;
this.state = state;
this.bbox = null;
}
destroy(): void {
@ -213,6 +224,10 @@ export class CanvasLayer {
return;
}
if (didDraw) {
this.getBbox();
}
this.konva.layer.visible(true);
this.konva.group.opacity(this.state.opacity);
const isSelected = this.manager.stateApi.getIsSelected(this.id);
@ -229,7 +244,7 @@ export class CanvasLayer {
// When the layer is selected and being moved, we should always cache it.
// We should update the cache if we drew to the layer.
if (!this.konva.group.isCached() || didDraw) {
this.konva.group.cache();
// this.konva.group.cache();
}
// Activate the transformer
this.konva.layer.listening(true);
@ -250,7 +265,7 @@ export class CanvasLayer {
// We are using a non-drawing tool (move, view, bbox), so we should cache the layer.
// We should update the cache if we drew to the layer.
if (!this.konva.group.isCached() || didDraw) {
this.konva.group.cache();
// this.konva.group.cache();
}
}
} else if (!isSelected) {
@ -260,8 +275,79 @@ export class CanvasLayer {
this.konva.transformer.nodes([]);
// Update the layer's cache if it's not already cached or we drew to it.
if (!this.konva.group.isCached() || didDraw) {
this.konva.group.cache();
// this.konva.group.cache();
}
}
}
renderBbox() {
if (!this.bbox) {
this.konva.bbox.visible(false);
return;
}
this.konva.bbox.visible(true);
this.konva.bbox.strokeWidth(1 / this.manager.stage.scaleX());
this.konva.bbox.setAttrs(this.bbox);
}
private _getBbox() {
let needsPixelBbox = false;
const rect = this.konva.objectGroup.getClientRect({ skipTransform: true });
// console.log('rect', rect);
// If there are no eraser strokes, we can use the client rect directly
for (const obj of this.objects.values()) {
if (obj instanceof CanvasEraserLine) {
needsPixelBbox = true;
break;
}
}
if (!needsPixelBbox) {
if (rect.width === 0 || rect.height === 0) {
this.bbox = null;
} else {
this.bbox = rect;
}
this.renderBbox();
return;
}
// We have eraser strokes - we must calculate the bbox using pixel data
// const a = window.performance.now();
const clone = this.konva.objectGroup.clone();
// const b = window.performance.now();
// console.log('cloned layer', b - a);
// const c = window.performance.now();
const canvas = clone.toCanvas();
// const d = window.performance.now();
// console.log('got canvas', d - c);
const ctx = canvas.getContext('2d');
if (!ctx) {
return;
}
const imageData = ctx.getImageData(0, 0, rect.width, rect.height);
// const e = window.performance.now();
// console.log('got image data', e - d);
this.manager.requestBbox(
{ buffer: imageData.data.buffer, width: imageData.width, height: imageData.height },
(extents) => {
// console.log('extents', extents);
if (extents) {
this.bbox = {
x: extents.minX + rect.x - Math.floor(this.konva.layer.x()),
y: extents.minY + rect.y - Math.floor(this.konva.layer.y()),
width: extents.maxX - extents.minX,
height: extents.maxY - extents.minY,
};
} else {
this.bbox = null;
}
this.renderBbox();
clone.destroy();
// console.log('bbox', this.bbox);
}
);
// console.log('transferred message', window.performance.now() - e);
}
}

View File

@ -11,6 +11,7 @@ import {
getInpaintMaskImage,
getRegionMaskImage,
} from 'features/controlLayers/konva/util';
import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker';
import { $lastProgressEvent, $shouldShowStagedImage } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasV2State, GenerationMode } from 'features/controlLayers/store/types';
import type Konva from 'konva';
@ -33,6 +34,24 @@ import { setStageEventHandlers } from './events';
const log = logger('canvas');
// type Extents = {
// minX: number;
// minY: number;
// maxX: number;
// maxY: number;
// };
// type GetBboxTask = {
// id: string;
// type: 'get_bbox';
// data: { imageData: ImageData };
// };
// type GetBboxResult = {
// id: string;
// type: 'get_bbox';
// data: { extents: Extents | null };
// };
type Util = {
getImageDTO: (imageName: string) => Promise<ImageDTO | null>;
uploadImage: (
@ -65,9 +84,12 @@ export class CanvasManager {
stateApi: CanvasStateApi;
preview: CanvasPreview;
background: CanvasBackground;
private store: Store<RootState>;
private isFirstRender: boolean;
private prevState: CanvasV2State;
private worker: Worker;
private tasks: Map<string, { task: GetBboxTask; onComplete: (extents: Extents | null) => void }>;
constructor(
stage: Konva.Stage,
@ -108,6 +130,41 @@ export class CanvasManager {
this.initialImage = new CanvasInitialImage(this.stateApi.getInitialImageState(), this);
this.stage.add(this.initialImage.konva.layer);
this.worker = new Worker(new URL('./worker.ts', import.meta.url), { type: 'module', name: 'worker' });
this.tasks = new Map();
this.worker.onmessage = (event: MessageEvent<ExtentsResult | WorkerLogMessage>) => {
const { type, data } = event.data;
if (type === 'log') {
if (data.ctx) {
log[data.level](data.ctx, data.message);
} else {
log[data.level](data.message);
}
} else if (type === 'extents') {
const task = this.tasks.get(data.id);
if (!task) {
return;
}
task.onComplete(data.extents);
}
};
this.worker.onerror = (event) => {
log.error({ message: event.message }, 'Worker error');
};
this.worker.onmessageerror = () => {
log.error('Worker message error');
};
}
requestBbox(data: Omit<GetBboxTask['data'], 'id'>, onComplete: (extents: Extents | null) => void) {
const id = crypto.randomUUID();
const task: GetBboxTask = {
type: 'get_bbox',
data: { ...data, id },
};
this.tasks.set(id, { task, onComplete });
this.worker.postMessage(task, [data.buffer]);
}
async renderInitialImage() {
@ -187,6 +244,12 @@ export class CanvasManager {
}
}
renderBboxes() {
for (const layer of this.layers.values()) {
layer.renderBbox();
}
}
arrangeEntities() {
const { getLayersState, getControlAdaptersState, getRegionsState } = this.stateApi;
const layers = getLayersState().entities;

View File

@ -47,7 +47,7 @@ const GET_CLIENT_RECT_CONFIG = { skipTransform: true };
* @param imageData The ImageData object to get the bounding box of.
* @returns The minimum and maximum x and y values of the image's bounding box, or null if the image has no pixels.
*/
const getImageDataBbox = (imageData: ImageData): Extents | null => {
export const getImageDataBbox = (imageData: ImageData): Extents | null => {
const { data, width, height } = imageData;
let minX = width;
let minY = height;
@ -77,7 +77,7 @@ const getImageDataBbox = (imageData: ImageData): Extents | null => {
}
}
return isEmpty ? null : { minX, minY, maxX, maxY };
return isEmpty ? null : { minX, minY, maxX: maxX + 1, maxY: maxY + 1 };
};
/**

View File

@ -496,6 +496,7 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
scale: newScale,
});
manager.background.render();
manager.renderBboxes();
}
}
manager.preview.tool.render();

View File

@ -142,6 +142,25 @@ export function imageDataToDataURL(imageData: ImageData): string {
return canvas.toDataURL();
}
export function imageDataToBlob(imageData: ImageData): Promise<Blob | null> {
const w = imageData.width;
const h = imageData.height;
const canvas = document.createElement('canvas');
canvas.width = w;
canvas.height = h;
const ctx = canvas.getContext('2d');
if (!ctx) {
return Promise.resolve(null);
}
ctx.putImageData(imageData, 0, 0);
return new Promise<Blob | null>((resolve) => {
canvas.toBlob(resolve);
});
}
/**
* Download a Blob as a file
*/

View File

@ -0,0 +1,131 @@
import type { LogLevel } from 'app/logging/logger';
import type { JsonObject } from 'roarr/dist/types';
export type Extents = {
minX: number;
minY: number;
maxX: number;
maxY: number;
};
/**
* Get the bounding box of an image.
* @param buffer The ArrayBuffer of the image to get the bounding box of.
* @param width The width of the image.
* @param height The height of the image.
* @returns The minimum and maximum x and y values of the image's bounding box, or null if the image has no pixels.
*/
const getImageDataBboxArrayBuffer = (buffer: ArrayBuffer, width: number, height: number): Extents | null => {
let minX = width;
let minY = height;
let maxX = -1;
let maxY = -1;
let alpha = 0;
let isEmpty = true;
const arr = new Uint8ClampedArray(buffer);
for (let y = 0; y < height; y++) {
for (let x = 0; x < width; x++) {
alpha = arr[(y * width + x) * 4 + 3] ?? 0;
if (alpha > 0) {
isEmpty = false;
if (x < minX) {
minX = x;
}
if (x > maxX) {
maxX = x;
}
if (y < minY) {
minY = y;
}
if (y > maxY) {
maxY = y;
}
}
}
}
return isEmpty ? null : { minX, minY, maxX: maxX + 1, maxY: maxY + 1 };
};
export type GetBboxTask = {
type: 'get_bbox';
data: { id: string; buffer: ArrayBuffer; width: number; height: number };
};
type TaskWithTimestamps<T extends Record<string, unknown>> = T & { started: number | null; finished: number | null };
export type ExtentsResult = {
type: 'extents';
data: { id: string; extents: Extents | null };
};
export type WorkerLogMessage = {
type: 'log';
data: { level: LogLevel; message: string; ctx?: JsonObject };
};
// A single worker is used to process tasks in a queue
const queue: TaskWithTimestamps<GetBboxTask>[] = [];
let currentTask: TaskWithTimestamps<GetBboxTask> | null = null;
function postLogMessage(level: LogLevel, message: string, ctx?: JsonObject) {
const data: WorkerLogMessage = {
type: 'log',
data: { level, message, ctx },
};
self.postMessage(data);
}
function processNextTask() {
// Grab the next task
const task = queue.shift();
if (!task) {
// Queue empty - we can clear the current task to allow the worker to resume the queue when another task is posted
currentTask = null;
return;
}
postLogMessage('debug', 'Processing task', { type: task.type, id: task.data.id });
task.started = performance.now();
// Set the current task so we don't process another one
currentTask = task;
// Process the task
if (task.type === 'get_bbox') {
const { buffer, width, height, id } = task.data;
const extents = getImageDataBboxArrayBuffer(buffer, width, height);
const result: ExtentsResult = {
type: 'extents',
data: { id, extents },
};
task.finished = performance.now();
postLogMessage('debug', 'Task complete', {
type: task.type,
id: task.data.id,
started: task.started,
finished: task.finished,
durationMs: task.finished - task.started,
});
self.postMessage(result);
} else {
postLogMessage('error', 'Unknown task type', { type: task.type });
}
// Repeat
processNextTask();
}
self.onmessage = (event: MessageEvent<Omit<GetBboxTask, 'started' | 'finished'>>) => {
const task = event.data;
postLogMessage('debug', 'Received task', { type: task.type, id: task.data.id });
// Add the task to the queue
queue.push({ ...event.data, started: null, finished: null });
// If we are not currently processing a task, process the next one
if (!currentTask) {
processNextTask();
}
};

View File

@ -58,7 +58,7 @@ export const layersReducers = {
type: 'layer',
isEnabled: true,
bbox: null,
bboxNeedsUpdate: false,
bboxNeedsUpdate: true,
objects: [imageObject],
opacity: 1,
position: { x: position.x + offsetX, y: position.y + offsetY },

View File

@ -5,7 +5,7 @@
"lib": ["DOM", "DOM.Iterable", "ESNext"],
"allowJs": false,
"skipLibCheck": true,
"esModuleInterop": false,
"esModuleInterop": true,
"allowSyntheticDefaultImports": true,
"strict": true,
"forceConsistentCasingInFileNames": true,