From 5c15458e15db4bb69d3ee04b443758b1948ccb59 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 5 Jul 2024 00:41:26 +1000 Subject: [PATCH] perf(ui): buffered drawing (wip) --- .../components/StageComponent.tsx | 1 + .../controlLayers/konva/CanvasInpaintMask.ts | 124 +++++--- .../controlLayers/konva/CanvasLayer.ts | 174 +++++++---- .../controlLayers/konva/CanvasManager.ts | 20 +- .../controlLayers/konva/CanvasRegion.ts | 123 +++++--- .../controlLayers/konva/CanvasStateApi.ts | 44 ++- .../features/controlLayers/konva/events.ts | 280 ++++++++++++------ .../controlLayers/store/canvasV2Slice.ts | 6 + .../store/inpaintMaskReducers.ts | 20 +- .../controlLayers/store/layersReducers.ts | 24 ++ .../controlLayers/store/regionsReducers.ts | 31 +- .../src/features/controlLayers/store/types.ts | 14 + 12 files changed, 605 insertions(+), 256 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx b/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx index 6efb54b9ca..1db97cce03 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/StageComponent.tsx @@ -27,6 +27,7 @@ const useStageRenderer = (stage: Konva.Stage, container: HTMLDivElement | null, const manager = new CanvasManager(stage, container, store); setCanvasManager(manager); + console.log(manager); const cleanup = manager.initialize(); return cleanup; }, [asPreview, container, stage, store]); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasInpaintMask.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasInpaintMask.ts index 59ef9e7759..b22c66a6fd 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasInpaintMask.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasInpaintMask.ts @@ -6,7 +6,8 @@ import { CanvasRect } from 'features/controlLayers/konva/CanvasRect'; import { getNodeBboxFast } from 'features/controlLayers/konva/entityBbox'; import { getObjectGroupId, INPAINT_MASK_LAYER_ID } from 'features/controlLayers/konva/naming'; import { mapId } from 'features/controlLayers/konva/util'; -import { type InpaintMaskEntity, isDrawingTool } from 'features/controlLayers/store/types'; +import type { BrushLine, EraserLine, InpaintMaskEntity } from 'features/controlLayers/store/types'; +import { isDrawingTool, RGBA_RED } from 'features/controlLayers/store/types'; import Konva from 'konva'; import { assert } from 'tsafe'; import { v4 as uuidv4 } from 'uuid'; @@ -20,8 +21,10 @@ export class CanvasInpaintMask { compositingRect: Konva.Rect; transformer: Konva.Transformer; objects: Map; + private drawingBuffer: BrushLine | EraserLine | null; + private prevInpaintMaskState: InpaintMaskEntity; - constructor(manager: CanvasManager) { + constructor(entity: InpaintMaskEntity, manager: CanvasManager) { this.id = INPAINT_MASK_LAYER_ID; this.manager = manager; this.layer = new Konva.Layer({ id: INPAINT_MASK_LAYER_ID }); @@ -56,12 +59,42 @@ export class CanvasInpaintMask { this.compositingRect = new Konva.Rect({ listening: false }); this.group.add(this.compositingRect); this.objects = new Map(); + this.drawingBuffer = null; + this.prevInpaintMaskState = entity; } destroy(): void { this.layer.destroy(); } + getDrawingBuffer() { + return this.drawingBuffer; + } + + async setDrawingBuffer(obj: BrushLine | EraserLine | null) { + this.drawingBuffer = obj; + if (this.drawingBuffer) { + if (this.drawingBuffer.type === 'brush_line') { + this.drawingBuffer.color = RGBA_RED; + } + + await this.renderObject(this.drawingBuffer, true); + this.updateGroup(true, this.prevInpaintMaskState); + } + } + + finalizeDrawingBuffer() { + if (!this.drawingBuffer) { + return; + } + if (this.drawingBuffer.type === 'brush_line') { + this.manager.stateApi.onBrushLineAdded2({ id: this.id, brushLine: this.drawingBuffer }, 'inpaint_mask'); + } else if (this.drawingBuffer.type === 'eraser_line') { + this.manager.stateApi.onEraserLineAdded2({ id: this.id, eraserLine: this.drawingBuffer }, 'inpaint_mask'); + } + this.setDrawingBuffer(null); + } + async render(inpaintMaskState: InpaintMaskEntity) { // Update the layer's position and listening state this.group.setAttrs({ @@ -84,51 +117,62 @@ export class CanvasInpaintMask { } for (const obj of inpaintMaskState.objects) { - if (obj.type === 'brush_line') { - let brushLine = this.objects.get(obj.id); - assert(brushLine instanceof CanvasBrushLine || brushLine === undefined); + didDraw = await this.renderObject(obj); + } - if (!brushLine) { - brushLine = new CanvasBrushLine(obj); - this.objects.set(brushLine.id, brushLine); - this.objectsGroup.add(brushLine.konvaLineGroup); - didDraw = true; - } else { - if (brushLine.update(obj)) { - didDraw = true; - } + this.updateGroup(didDraw, inpaintMaskState); + this.prevInpaintMaskState = inpaintMaskState; + } + + private async renderObject(obj: InpaintMaskEntity['objects'][number], force = false): Promise { + if (obj.type === 'brush_line') { + let brushLine = this.objects.get(obj.id); + assert(brushLine instanceof CanvasBrushLine || brushLine === undefined); + + if (!brushLine) { + brushLine = new CanvasBrushLine(obj); + this.objects.set(brushLine.id, brushLine); + this.objectsGroup.add(brushLine.konvaLineGroup); + return true; + } else { + if (brushLine.update(obj, force)) { + return true; } - } else if (obj.type === 'eraser_line') { - let eraserLine = this.objects.get(obj.id); - assert(eraserLine instanceof CanvasEraserLine || eraserLine === undefined); + } + } else if (obj.type === 'eraser_line') { + let eraserLine = this.objects.get(obj.id); + assert(eraserLine instanceof CanvasEraserLine || eraserLine === undefined); - if (!eraserLine) { - eraserLine = new CanvasEraserLine(obj); - this.objects.set(eraserLine.id, eraserLine); - this.objectsGroup.add(eraserLine.konvaLineGroup); - didDraw = true; - } else { - if (eraserLine.update(obj)) { - didDraw = true; - } + if (!eraserLine) { + eraserLine = new CanvasEraserLine(obj); + this.objects.set(eraserLine.id, eraserLine); + this.objectsGroup.add(eraserLine.konvaLineGroup); + return true; + } else { + if (eraserLine.update(obj, force)) { + return true; } - } else if (obj.type === 'rect_shape') { - let rect = this.objects.get(obj.id); - assert(rect instanceof CanvasRect || rect === undefined); + } + } else if (obj.type === 'rect_shape') { + let rect = this.objects.get(obj.id); + assert(rect instanceof CanvasRect || rect === undefined); - if (!rect) { - rect = new CanvasRect(obj); - this.objects.set(rect.id, rect); - this.objectsGroup.add(rect.konvaRect); - didDraw = true; - } else { - if (rect.update(obj)) { - didDraw = true; - } + if (!rect) { + rect = new CanvasRect(obj); + this.objects.set(rect.id, rect); + this.objectsGroup.add(rect.konvaRect); + return true; + } else { + if (rect.update(obj, force)) { + return true; } } } + return false; + } + + updateGroup(didDraw: boolean, inpaintMaskState: InpaintMaskEntity) { // Only update layer visibility if it has changed. if (this.layer.visible() !== inpaintMaskState.isEnabled) { this.layer.visible(inpaintMaskState.isEnabled); @@ -155,10 +199,6 @@ export class CanvasInpaintMask { }); } - this.updateGroup(didDraw); - } - - updateGroup(didDraw: boolean) { const isSelected = this.manager.stateApi.getIsSelected(this.id); const selectedTool = this.manager.stateApi.getToolState().selected; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasLayer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasLayer.ts index 3288e710a4..f2a861a07d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasLayer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasLayer.ts @@ -5,7 +5,8 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { CanvasRect } from 'features/controlLayers/konva/CanvasRect'; import { getObjectGroupId } from 'features/controlLayers/konva/naming'; import { mapId } from 'features/controlLayers/konva/util'; -import { isDrawingTool, type LayerEntity } from 'features/controlLayers/store/types'; +import type { BrushLine, EraserLine, LayerEntity } from 'features/controlLayers/store/types'; +import { isDrawingTool } from 'features/controlLayers/store/types'; import Konva from 'konva'; import { assert } from 'tsafe'; import { v4 as uuidv4 } from 'uuid'; @@ -15,8 +16,11 @@ export class CanvasLayer { manager: CanvasManager; layer: Konva.Layer; group: Konva.Group; + objectsGroup: Konva.Group; transformer: Konva.Transformer; objects: Map; + private drawingBuffer: BrushLine | EraserLine | null; + private prevLayerState: LayerEntity; constructor(entity: LayerEntity, manager: CanvasManager) { this.id = entity.id; @@ -30,6 +34,8 @@ export class CanvasLayer { id: getObjectGroupId(this.layer.id(), uuidv4()), listening: false, }); + this.objectsGroup = new Konva.Group({}); + this.group.add(this.objectsGroup); this.layer.add(this.group); this.transformer = new Konva.Transformer({ @@ -52,12 +58,40 @@ export class CanvasLayer { this.layer.add(this.transformer); this.objects = new Map(); + this.drawingBuffer = null; + this.prevLayerState = entity; } destroy(): void { this.layer.destroy(); } + getDrawingBuffer() { + return this.drawingBuffer; + } + + async setDrawingBuffer(obj: BrushLine | EraserLine | null) { + if (obj) { + this.drawingBuffer = obj; + await this.renderObject(this.drawingBuffer, true); + this.updateGroup(true, this.prevLayerState); + } else { + this.drawingBuffer = null; + } + } + + finalizeDrawingBuffer() { + if (!this.drawingBuffer) { + return; + } + if (this.drawingBuffer.type === 'brush_line') { + this.manager.stateApi.onBrushLineAdded2({ id: this.id, brushLine: this.drawingBuffer }, 'layer'); + } else if (this.drawingBuffer.type === 'eraser_line') { + this.manager.stateApi.onEraserLineAdded2({ id: this.id, eraserLine: this.drawingBuffer }, 'layer'); + } + this.setDrawingBuffer(null); + } + async render(layerState: LayerEntity) { // Update the layer's position and listening state this.group.setAttrs({ @@ -72,7 +106,7 @@ export class CanvasLayer { const objectIds = layerState.objects.map(mapId); // Destroy any objects that are no longer in state for (const object of this.objects.values()) { - if (!objectIds.includes(object.id)) { + if (!objectIds.includes(object.id) && object.id !== this.drawingBuffer?.id) { this.objects.delete(object.id); object.destroy(); didDraw = true; @@ -80,67 +114,11 @@ export class CanvasLayer { } for (const obj of layerState.objects) { - if (obj.type === 'brush_line') { - let brushLine = this.objects.get(obj.id); - assert(brushLine instanceof CanvasBrushLine || brushLine === undefined); + didDraw = await this.renderObject(obj); + } - if (!brushLine) { - brushLine = new CanvasBrushLine(obj); - this.objects.set(brushLine.id, brushLine); - this.group.add(brushLine.konvaLineGroup); - didDraw = true; - } else { - if (brushLine.update(obj)) { - didDraw = true; - } - } - } else if (obj.type === 'eraser_line') { - let eraserLine = this.objects.get(obj.id); - assert(eraserLine instanceof CanvasEraserLine || eraserLine === undefined); - - if (!eraserLine) { - eraserLine = new CanvasEraserLine(obj); - this.objects.set(eraserLine.id, eraserLine); - this.group.add(eraserLine.konvaLineGroup); - didDraw = true; - } else { - if (eraserLine.update(obj)) { - didDraw = true; - } - } - } else if (obj.type === 'rect_shape') { - let rect = this.objects.get(obj.id); - assert(rect instanceof CanvasRect || rect === undefined); - - if (!rect) { - rect = new CanvasRect(obj); - this.objects.set(rect.id, rect); - this.group.add(rect.konvaRect); - didDraw = true; - } else { - if (rect.update(obj)) { - didDraw = true; - } - } - } else if (obj.type === 'image') { - let image = this.objects.get(obj.id); - assert(image instanceof CanvasImage || image === undefined); - - if (!image) { - image = await new CanvasImage(obj, { - onLoad: () => { - this.updateGroup(true); - }, - }); - this.objects.set(image.id, image); - this.group.add(image.konvaImageGroup); - await image.updateImageSource(obj.image.name); - } else { - if (await image.update(obj)) { - didDraw = true; - } - } - } + if (this.drawingBuffer) { + didDraw = await this.renderObject(this.drawingBuffer); } // Only update layer visibility if it has changed. @@ -151,10 +129,78 @@ export class CanvasLayer { this.group.opacity(layerState.opacity); // The layer only listens when using the move tool - otherwise the stage is handling mouse events - this.updateGroup(didDraw); + this.updateGroup(didDraw, this.prevLayerState); + + this.prevLayerState = layerState; } - updateGroup(didDraw: boolean) { + private async renderObject(obj: LayerEntity['objects'][number], force = false): Promise { + if (obj.type === 'brush_line') { + let brushLine = this.objects.get(obj.id); + assert(brushLine instanceof CanvasBrushLine || brushLine === undefined); + + if (!brushLine) { + brushLine = new CanvasBrushLine(obj); + this.objects.set(brushLine.id, brushLine); + this.objectsGroup.add(brushLine.konvaLineGroup); + return true; + } else { + if (brushLine.update(obj, force)) { + return true; + } + } + } else if (obj.type === 'eraser_line') { + let eraserLine = this.objects.get(obj.id); + assert(eraserLine instanceof CanvasEraserLine || eraserLine === undefined); + + if (!eraserLine) { + eraserLine = new CanvasEraserLine(obj); + this.objects.set(eraserLine.id, eraserLine); + this.objectsGroup.add(eraserLine.konvaLineGroup); + return true; + } else { + if (eraserLine.update(obj, force)) { + return true; + } + } + } else if (obj.type === 'rect_shape') { + let rect = this.objects.get(obj.id); + assert(rect instanceof CanvasRect || rect === undefined); + + if (!rect) { + rect = new CanvasRect(obj); + this.objects.set(rect.id, rect); + this.objectsGroup.add(rect.konvaRect); + return true; + } else { + if (rect.update(obj, force)) { + return true; + } + } + } else if (obj.type === 'image') { + let image = this.objects.get(obj.id); + assert(image instanceof CanvasImage || image === undefined); + + if (!image) { + image = await new CanvasImage(obj, { + onLoad: () => { + this.updateGroup(true, this.prevLayerState); + }, + }); + this.objects.set(image.id, image); + this.objectsGroup.add(image.konvaImageGroup); + await image.updateImageSource(obj.image.name); + } else { + if (await image.update(obj, force)) { + return true; + } + } + } + + return false; + } + + updateGroup(didDraw: boolean, _: LayerEntity) { const isSelected = this.manager.stateApi.getIsSelected(this.id); const selectedTool = this.manager.stateApi.getToolState().selected; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts index 56af6548fe..72640e7d71 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts @@ -95,7 +95,7 @@ export class CanvasManager { this.background = new CanvasBackground(this); this.stage.add(this.background.layer); - this.inpaintMask = new CanvasInpaintMask(this); + this.inpaintMask = new CanvasInpaintMask(this.stateApi.getInpaintMaskState(), this); this.stage.add(this.inpaintMask.layer); this.layers = new Map(); @@ -346,6 +346,24 @@ export class CanvasManager { }; }; + getSelectedEntityAdapter = (): CanvasLayer | CanvasRegion | CanvasControlAdapter | CanvasInpaintMask | null => { + const state = this.stateApi.getState(); + const identifier = state.selectedEntityIdentifier; + if (!identifier) { + return null; + } else if (identifier.type === 'layer') { + return this.layers.get(identifier.id) ?? null; + } else if (identifier.type === 'control_adapter') { + return this.controlAdapters.get(identifier.id) ?? null; + } else if (identifier.type === 'regional_guidance') { + return this.regions.get(identifier.id) ?? null; + } else if (identifier.type === 'inpaint_mask') { + return this.inpaintMask; + } else { + return null; + } + }; + getGenerationMode() { return getGenerationMode({ manager: this }); } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasRegion.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasRegion.ts index 57bda0a1ef..c3a05f36cc 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasRegion.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasRegion.ts @@ -6,7 +6,8 @@ import { CanvasRect } from 'features/controlLayers/konva/CanvasRect'; import { getNodeBboxFast } from 'features/controlLayers/konva/entityBbox'; import { getObjectGroupId } from 'features/controlLayers/konva/naming'; import { mapId } from 'features/controlLayers/konva/util'; -import { isDrawingTool, type RegionEntity } from 'features/controlLayers/store/types'; +import type { BrushLine, EraserLine, RegionEntity } from 'features/controlLayers/store/types'; +import { isDrawingTool, RGBA_RED } from 'features/controlLayers/store/types'; import Konva from 'konva'; import { assert } from 'tsafe'; import { v4 as uuidv4 } from 'uuid'; @@ -20,6 +21,8 @@ export class CanvasRegion { compositingRect: Konva.Rect; transformer: Konva.Transformer; objects: Map; + private drawingBuffer: BrushLine | EraserLine | null; + private prevRegionState: RegionEntity; constructor(entity: RegionEntity, manager: CanvasManager) { this.id = entity.id; @@ -56,12 +59,41 @@ export class CanvasRegion { this.compositingRect = new Konva.Rect({ listening: false }); this.group.add(this.compositingRect); this.objects = new Map(); + this.drawingBuffer = null; + this.prevRegionState = entity; } destroy(): void { this.layer.destroy(); } + getDrawingBuffer() { + return this.drawingBuffer; + } + + async setDrawingBuffer(obj: BrushLine | EraserLine | null) { + this.drawingBuffer = obj; + if (this.drawingBuffer) { + if (this.drawingBuffer.type === 'brush_line') { + this.drawingBuffer.color = RGBA_RED; + } + await this.renderObject(this.drawingBuffer, true); + this.updateGroup(true, this.prevRegionState); + } + } + + finalizeDrawingBuffer() { + if (!this.drawingBuffer) { + return; + } + if (this.drawingBuffer.type === 'brush_line') { + this.manager.stateApi.onBrushLineAdded2({ id: this.id, brushLine: this.drawingBuffer }, 'regional_guidance'); + } else if (this.drawingBuffer.type === 'eraser_line') { + this.manager.stateApi.onEraserLineAdded2({ id: this.id, eraserLine: this.drawingBuffer }, 'regional_guidance'); + } + this.setDrawingBuffer(null); + } + async render(regionState: RegionEntity) { // Update the layer's position and listening state this.group.setAttrs({ @@ -84,51 +116,62 @@ export class CanvasRegion { } for (const obj of regionState.objects) { - if (obj.type === 'brush_line') { - let brushLine = this.objects.get(obj.id); - assert(brushLine instanceof CanvasBrushLine || brushLine === undefined); + didDraw = await this.renderObject(obj); + } - if (!brushLine) { - brushLine = new CanvasBrushLine(obj); - this.objects.set(brushLine.id, brushLine); - this.objectsGroup.add(brushLine.konvaLineGroup); - didDraw = true; - } else { - if (brushLine.update(obj)) { - didDraw = true; - } + this.updateGroup(didDraw, regionState); + this.prevRegionState = regionState; + } + + private async renderObject(obj: RegionEntity['objects'][number], force = false): Promise { + if (obj.type === 'brush_line') { + let brushLine = this.objects.get(obj.id); + assert(brushLine instanceof CanvasBrushLine || brushLine === undefined); + + if (!brushLine) { + brushLine = new CanvasBrushLine(obj); + this.objects.set(brushLine.id, brushLine); + this.objectsGroup.add(brushLine.konvaLineGroup); + return true; + } else { + if (brushLine.update(obj, force)) { + return true; } - } else if (obj.type === 'eraser_line') { - let eraserLine = this.objects.get(obj.id); - assert(eraserLine instanceof CanvasEraserLine || eraserLine === undefined); + } + } else if (obj.type === 'eraser_line') { + let eraserLine = this.objects.get(obj.id); + assert(eraserLine instanceof CanvasEraserLine || eraserLine === undefined); - if (!eraserLine) { - eraserLine = new CanvasEraserLine(obj); - this.objects.set(eraserLine.id, eraserLine); - this.objectsGroup.add(eraserLine.konvaLineGroup); - didDraw = true; - } else { - if (eraserLine.update(obj)) { - didDraw = true; - } + if (!eraserLine) { + eraserLine = new CanvasEraserLine(obj); + this.objects.set(eraserLine.id, eraserLine); + this.objectsGroup.add(eraserLine.konvaLineGroup); + return true; + } else { + if (eraserLine.update(obj, force)) { + return true; } - } else if (obj.type === 'rect_shape') { - let rect = this.objects.get(obj.id); - assert(rect instanceof CanvasRect || rect === undefined); + } + } else if (obj.type === 'rect_shape') { + let rect = this.objects.get(obj.id); + assert(rect instanceof CanvasRect || rect === undefined); - if (!rect) { - rect = new CanvasRect(obj); - this.objects.set(rect.id, rect); - this.objectsGroup.add(rect.konvaRect); - didDraw = true; - } else { - if (rect.update(obj)) { - didDraw = true; - } + if (!rect) { + rect = new CanvasRect(obj); + this.objects.set(rect.id, rect); + this.objectsGroup.add(rect.konvaRect); + return true; + } else { + if (rect.update(obj, force)) { + return true; } } } + return false; + } + + updateGroup(didDraw: boolean, regionState: RegionEntity) { // Only update layer visibility if it has changed. if (this.layer.visible() !== regionState.isEnabled) { this.layer.visible(regionState.isEnabled); @@ -141,7 +184,6 @@ export class CanvasRegion { // Convert the color to a string, stripping the alpha - the object group will handle opacity. const rgbColor = rgbColorToString(regionState.fill); const maskOpacity = this.manager.stateApi.getMaskOpacity(); - this.compositingRect.setAttrs({ // The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already ...getNodeBboxFast(this.objectsGroup), @@ -149,16 +191,11 @@ export class CanvasRegion { opacity: maskOpacity, // Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes) globalCompositeOperation: 'source-in', - visible: true, // This rect must always be on top of all other shapes zIndex: this.objects.size + 1, }); } - this.updateGroup(didDraw); - } - - updateGroup(didDraw: boolean) { const isSelected = this.manager.stateApi.getIsSelected(this.id); const selectedTool = this.manager.stateApi.getToolState().selected; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts index 083a882c44..735fdb635e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts @@ -19,7 +19,9 @@ import { eraserWidthChanged, imBboxChanged, imBrushLineAdded, + imBrushLineAdded2, imEraserLineAdded, + imEraserLineAdded2, imImageCacheChanged, imLinePointAdded, imRectAdded, @@ -27,7 +29,9 @@ import { imTranslated, layerBboxChanged, layerBrushLineAdded, + layerBrushLineAdded2, layerEraserLineAdded, + layerEraserLineAdded2, layerImageCacheChanged, layerLinePointAdded, layerRectAdded, @@ -35,7 +39,9 @@ import { layerTranslated, rgBboxChanged, rgBrushLineAdded, + rgBrushLineAdded2, rgEraserLineAdded, + rgEraserLineAdded2, rgImageCacheChanged, rgLinePointAdded, rgRectAdded, @@ -46,8 +52,10 @@ import { } from 'features/controlLayers/store/canvasV2Slice'; import type { BboxChangedArg, + BrushLine, BrushLineAddedArg, CanvasEntity, + EraserLine, EraserLineAddedArg, PointAddedToLineArg, PosChangedArg, @@ -127,6 +135,26 @@ export class CanvasStateApi { this.store.dispatch(imEraserLineAdded(arg)); } }; + onBrushLineAdded2 = (arg: { id: string; brushLine: BrushLine }, entityType: CanvasEntity['type']) => { + log.debug('Brush line added'); + if (entityType === 'layer') { + this.store.dispatch(layerBrushLineAdded2(arg)); + } else if (entityType === 'regional_guidance') { + this.store.dispatch(rgBrushLineAdded2(arg)); + } else if (entityType === 'inpaint_mask') { + this.store.dispatch(imBrushLineAdded2(arg)); + } + }; + onEraserLineAdded2 = (arg: { id: string; eraserLine: EraserLine }, entityType: CanvasEntity['type']) => { + log.debug('Eraser line added'); + if (entityType === 'layer') { + this.store.dispatch(layerEraserLineAdded2(arg)); + } else if (entityType === 'regional_guidance') { + this.store.dispatch(rgEraserLineAdded2(arg)); + } else if (entityType === 'inpaint_mask') { + this.store.dispatch(imEraserLineAdded2(arg)); + } + }; onPointAddedToLine = (arg: PointAddedToLineArg, entityType: CanvasEntity['type']) => { log.debug('Point added to line'); if (entityType === 'layer') { @@ -183,23 +211,21 @@ export class CanvasStateApi { getSelectedEntity = (): CanvasEntity | null => { const state = this.getState(); const identifier = state.selectedEntityIdentifier; - let selectedEntity: CanvasEntity | null = null; if (!identifier) { - selectedEntity = null; + return null; } else if (identifier.type === 'layer') { - selectedEntity = state.layers.entities.find((i) => i.id === identifier.id) ?? null; + return state.layers.entities.find((i) => i.id === identifier.id) ?? null; } else if (identifier.type === 'control_adapter') { - selectedEntity = state.controlAdapters.entities.find((i) => i.id === identifier.id) ?? null; + return state.controlAdapters.entities.find((i) => i.id === identifier.id) ?? null; } else if (identifier.type === 'ip_adapter') { - selectedEntity = state.ipAdapters.entities.find((i) => i.id === identifier.id) ?? null; + return state.ipAdapters.entities.find((i) => i.id === identifier.id) ?? null; } else if (identifier.type === 'regional_guidance') { - selectedEntity = state.regions.entities.find((i) => i.id === identifier.id) ?? null; + return state.regions.entities.find((i) => i.id === identifier.id) ?? null; } else if (identifier.type === 'inpaint_mask') { - selectedEntity = state.inpaintMask; + return state.inpaintMask; } else { - selectedEntity = null; + return null; } - return selectedEntity; }; getCurrentFill = () => { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/events.ts b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts index 4967982604..53fb797222 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/events.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts @@ -1,19 +1,14 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getScaledCursorPosition } from 'features/controlLayers/konva/util'; -import type { CanvasEntity } from 'features/controlLayers/store/types'; +import type { CanvasEntity, CanvasV2State, Position } from 'features/controlLayers/store/types'; +import { isDrawableEntity, isDrawableEntityAdapter } from 'features/controlLayers/store/types'; import type Konva from 'konva'; import type { Vector2d } from 'konva/lib/types'; import { clamp } from 'lodash-es'; +import { v4 as uuidv4 } from 'uuid'; -import { - BRUSH_SPACING_TARGET_SCALE, - CANVAS_SCALE_BY, - MAX_BRUSH_SPACING_PX, - MAX_CANVAS_SCALE, - MIN_BRUSH_SPACING_PX, - MIN_CANVAS_SCALE, -} from './constants'; -import { PREVIEW_TOOL_GROUP_ID } from './naming'; +import { BRUSH_SPACING_TARGET_SCALE, CANVAS_SCALE_BY, MAX_CANVAS_SCALE, MIN_CANVAS_SCALE } from './constants'; +import { getBrushLineId, PREVIEW_TOOL_GROUP_ID } from './naming'; /** * Updates the last cursor position atom with the current cursor position, returning the new position or `null` if the @@ -21,10 +16,7 @@ import { PREVIEW_TOOL_GROUP_ID } from './naming'; * @param stage The konva stage * @param setLastCursorPos The callback to store the cursor pos */ -const updateLastCursorPos = ( - stage: Konva.Stage, - setLastCursorPos: CanvasManager['stateApi']['setLastCursorPos'] -) => { +const updateLastCursorPos = (stage: Konva.Stage, setLastCursorPos: CanvasManager['stateApi']['setLastCursorPos']) => { const pos = getScaledCursorPosition(stage); if (!pos) { return null; @@ -61,24 +53,18 @@ const maybeAddNextPoint = ( setLastAddedPoint: CanvasManager['stateApi']['setLastAddedPoint'], onPointAddedToLine: CanvasManager['stateApi']['onPointAddedToLine'] ) => { - const isDrawableEntity = - selectedEntity?.type === 'regional_guidance' || - selectedEntity?.type === 'layer' || - selectedEntity?.type === 'inpaint_mask'; - - if (!isDrawableEntity) { + if (!isDrawableEntity(selectedEntity)) { return; } + // Continue the last line const lastAddedPoint = getLastAddedPoint(); const toolState = getToolState(); - const minSpacingPx = clamp( + const minSpacingPx = toolState.selected === 'brush' ? toolState.brush.width * BRUSH_SPACING_TARGET_SCALE - : toolState.eraser.width * BRUSH_SPACING_TARGET_SCALE, - MIN_BRUSH_SPACING_PX, - MAX_BRUSH_SPACING_PX - ); + : toolState.eraser.width * BRUSH_SPACING_TARGET_SCALE; + if (lastAddedPoint) { // Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number if (Math.hypot(lastAddedPoint.x - currentPos.x, lastAddedPoint.y - currentPos.y) < minSpacingPx) { @@ -95,8 +81,29 @@ const maybeAddNextPoint = ( ); }; +const getNextPoint = ( + currentPos: Position, + toolState: CanvasV2State['tool'], + lastAddedPoint: Position | null +): Position | null => { + // Continue the last line + const minSpacingPx = + toolState.selected === 'brush' + ? toolState.brush.width * BRUSH_SPACING_TARGET_SCALE + : toolState.eraser.width * BRUSH_SPACING_TARGET_SCALE; + + if (lastAddedPoint) { + // Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number + if (Math.hypot(lastAddedPoint.x - currentPos.x, lastAddedPoint.y - currentPos.y) < minSpacingPx) { + return null; + } + } + + return currentPos; +}; + export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { - const { stage, stateApi } = manager; + const { stage, stateApi, getSelectedEntityAdapter } = manager; const { getToolState, getCurrentFill, @@ -132,17 +139,21 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { }); //#region mousedown - stage.on('mousedown', (e) => { + stage.on('mousedown', async (e) => { setIsMouseDown(true); const toolState = getToolState(); const pos = updateLastCursorPos(stage, setLastCursorPos); const selectedEntity = getSelectedEntity(); - const isDrawableEntity = - selectedEntity?.type === 'regional_guidance' || - selectedEntity?.type === 'layer' || - selectedEntity?.type === 'inpaint_mask'; + const selectedEntityAdapter = getSelectedEntityAdapter(); - if (pos && selectedEntity && isDrawableEntity && !getSpaceKey()) { + if ( + pos && + selectedEntity && + isDrawableEntity(selectedEntity) && + selectedEntityAdapter && + isDrawableEntityAdapter(selectedEntityAdapter) && + !getSpaceKey() + ) { setIsDrawing(true); setLastMouseDownPos(pos); @@ -180,21 +191,37 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { ); } } else { - onBrushLineAdded( - { - id: selectedEntity.id, - points: [ - pos.x - selectedEntity.x, - pos.y - selectedEntity.y, - pos.x - selectedEntity.x, - pos.y - selectedEntity.y, - ], - color: getCurrentFill(), - width: toolState.brush.width, - clip, - }, - selectedEntity.type - ); + if (selectedEntityAdapter.getDrawingBuffer()) { + selectedEntityAdapter.finalizeDrawingBuffer(); + } + await selectedEntityAdapter.setDrawingBuffer({ + id: getBrushLineId(selectedEntityAdapter.id, uuidv4()), + type: 'brush_line', + points: [ + pos.x - selectedEntity.x, + pos.y - selectedEntity.y, + pos.x - selectedEntity.x, + pos.y - selectedEntity.y, + ], + strokeWidth: toolState.brush.width, + color: getCurrentFill(), + clip, + }); + // onBrushLineAdded( + // { + // id: selectedEntity.id, + // points: [ + // pos.x - selectedEntity.x, + // pos.y - selectedEntity.y, + // pos.x - selectedEntity.x, + // pos.y - selectedEntity.y, + // ], + // color: getCurrentFill(), + // width: toolState.brush.width, + // clip, + // }, + // selectedEntity.type + // ); } setLastAddedPoint(pos); } @@ -231,20 +258,36 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { ); } } else { - onEraserLineAdded( - { - id: selectedEntity.id, - points: [ - pos.x - selectedEntity.x, - pos.y - selectedEntity.y, - pos.x - selectedEntity.x, - pos.y - selectedEntity.y, - ], - width: toolState.eraser.width, - clip, - }, - selectedEntity.type - ); + if (selectedEntityAdapter.getDrawingBuffer()) { + selectedEntityAdapter.finalizeDrawingBuffer(); + } + await selectedEntityAdapter.setDrawingBuffer({ + id: getBrushLineId(selectedEntityAdapter.id, uuidv4()), + type: 'eraser_line', + points: [ + pos.x - selectedEntity.x, + pos.y - selectedEntity.y, + pos.x - selectedEntity.x, + pos.y - selectedEntity.y, + ], + strokeWidth: toolState.eraser.width, + clip, + }); + + // onEraserLineAdded( + // { + // id: selectedEntity.id, + // points: [ + // pos.x - selectedEntity.x, + // pos.y - selectedEntity.y, + // pos.x - selectedEntity.x, + // pos.y - selectedEntity.y, + // ], + // width: toolState.eraser.width, + // clip, + // }, + // selectedEntity.type + // ); } setLastAddedPoint(pos); } @@ -253,18 +296,40 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { }); //#region mouseup - stage.on('mouseup', () => { + stage.on('mouseup', async () => { setIsMouseDown(false); const pos = getLastCursorPos(); const selectedEntity = getSelectedEntity(); - const isDrawableEntity = - selectedEntity?.type === 'regional_guidance' || - selectedEntity?.type === 'layer' || - selectedEntity?.type === 'inpaint_mask'; + const selectedEntityAdapter = getSelectedEntityAdapter(); - if (pos && selectedEntity && isDrawableEntity && !getSpaceKey()) { + if ( + pos && + selectedEntity && + isDrawableEntity(selectedEntity) && + selectedEntityAdapter && + isDrawableEntityAdapter(selectedEntityAdapter) && + !getSpaceKey() + ) { const toolState = getToolState(); + if (toolState.selected === 'brush') { + const drawingBuffer = selectedEntityAdapter.getDrawingBuffer(); + if (drawingBuffer?.type === 'brush_line') { + selectedEntityAdapter.finalizeDrawingBuffer(); + } else { + await selectedEntityAdapter.setDrawingBuffer(null); + } + } + + if (toolState.selected === 'eraser') { + const drawingBuffer = selectedEntityAdapter.getDrawingBuffer(); + if (drawingBuffer?.type === 'eraser_line') { + selectedEntityAdapter.finalizeDrawingBuffer(); + } else { + await selectedEntityAdapter.setDrawingBuffer(null); + } + } + if (toolState.selected === 'rect') { const lastMouseDownPos = getLastMouseDownPos(); if (lastMouseDownPos) { @@ -292,32 +357,48 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { }); //#region mousemove - stage.on('mousemove', () => { + stage.on('mousemove', async () => { const toolState = getToolState(); const pos = updateLastCursorPos(stage, setLastCursorPos); const selectedEntity = getSelectedEntity(); + const selectedEntityAdapter = getSelectedEntityAdapter(); stage .findOne(`#${PREVIEW_TOOL_GROUP_ID}`) ?.visible(toolState.selected === 'brush' || toolState.selected === 'eraser'); - const isDrawableEntity = - selectedEntity?.type === 'regional_guidance' || - selectedEntity?.type === 'layer' || - selectedEntity?.type === 'inpaint_mask'; - - if (pos && selectedEntity && isDrawableEntity && !getSpaceKey() && getIsMouseDown()) { + if ( + pos && + selectedEntity && + isDrawableEntity(selectedEntity) && + selectedEntityAdapter && + isDrawableEntityAdapter(selectedEntityAdapter) && + !getSpaceKey() && + getIsMouseDown() + ) { if (toolState.selected === 'brush') { if (getIsDrawing()) { + const drawingBuffer = selectedEntityAdapter.getDrawingBuffer(); + if (drawingBuffer?.type === 'brush_line') { + const lastAddedPoint = getLastAddedPoint(); + const nextPoint = getNextPoint(pos, toolState, lastAddedPoint); + if (nextPoint) { + drawingBuffer.points.push(nextPoint.x - selectedEntity.x, nextPoint.y - selectedEntity.y); + await selectedEntityAdapter.setDrawingBuffer(drawingBuffer); + setLastAddedPoint(nextPoint); + } + } else { + await selectedEntityAdapter.setDrawingBuffer(null); + } // Continue the last line - maybeAddNextPoint( - selectedEntity, - pos, - getToolState, - getLastAddedPoint, - setLastAddedPoint, - onPointAddedToLine - ); + // maybeAddNextPoint( + // selectedEntity, + // pos, + // getToolState, + // getLastAddedPoint, + // setLastAddedPoint, + // onPointAddedToLine + // ); } else { const bbox = getBbox(); const settings = getSettings(); @@ -353,15 +434,28 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { if (toolState.selected === 'eraser') { if (getIsDrawing()) { + const drawingBuffer = selectedEntityAdapter.getDrawingBuffer(); + if (drawingBuffer?.type === 'eraser_line') { + const lastAddedPoint = getLastAddedPoint(); + const nextPoint = getNextPoint(pos, toolState, lastAddedPoint); + if (nextPoint) { + drawingBuffer.points.push(nextPoint.x - selectedEntity.x, nextPoint.y - selectedEntity.y); + await selectedEntityAdapter.setDrawingBuffer(drawingBuffer); + setLastAddedPoint(nextPoint); + } + } else { + await selectedEntityAdapter.setDrawingBuffer(null); + } + // Continue the last line - maybeAddNextPoint( - selectedEntity, - pos, - getToolState, - getLastAddedPoint, - setLastAddedPoint, - onPointAddedToLine - ); + // maybeAddNextPoint( + // selectedEntity, + // pos, + // getToolState, + // getLastAddedPoint, + // setLastAddedPoint, + // onPointAddedToLine + // ); } else { const bbox = getBbox(); const settings = getSettings(); @@ -407,12 +501,8 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { const toolState = getToolState(); stage.findOne(`#${PREVIEW_TOOL_GROUP_ID}`)?.visible(false); - const isDrawableEntity = - selectedEntity?.type === 'regional_guidance' || - selectedEntity?.type === 'layer' || - selectedEntity?.type === 'inpaint_mask'; - if (pos && selectedEntity && isDrawableEntity && !getSpaceKey() && getIsMouseDown()) { + if (pos && selectedEntity && isDrawableEntity(selectedEntity) && !getSpaceKey() && getIsMouseDown()) { if (getIsMouseDown()) { if (toolState.selected === 'brush') { onPointAddedToLine({ id: selectedEntity.id, point: [pos.x, pos.y] }, selectedEntity.type); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts index c341e5e152..104a6d13b2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts @@ -347,6 +347,12 @@ export const { stagingAreaCanceledStaging, stagingAreaNextImageSelected, stagingAreaPreviousImageSelected, + layerBrushLineAdded2, + layerEraserLineAdded2, + rgBrushLineAdded2, + rgEraserLineAdded2, + imBrushLineAdded2, + imEraserLineAdded2, } = canvasV2Slice.actions; export const selectCanvasV2Slice = (state: RootState) => state.canvasV2; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/inpaintMaskReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/inpaintMaskReducers.ts index 58343c3ebf..9dc78b6105 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/inpaintMaskReducers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/inpaintMaskReducers.ts @@ -1,6 +1,12 @@ import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming'; -import type { CanvasV2State, InpaintMaskEntity, ScaleChangedArg } from 'features/controlLayers/store/types'; +import type { + BrushLine, + CanvasV2State, + EraserLine, + InpaintMaskEntity, + ScaleChangedArg, +} from 'features/controlLayers/store/types'; import { imageDTOToImageWithDims, RGBA_RED } from 'features/controlLayers/store/types'; import type { IRect } from 'konva/lib/types'; import type { ImageDTO } from 'services/api/types'; @@ -81,6 +87,18 @@ export const inpaintMaskReducers = { payload: { ...payload, lineId: uuidv4() }, }), }, + imBrushLineAdded2: (state, action: PayloadAction<{ brushLine: BrushLine }>) => { + const { brushLine } = action.payload; + state.inpaintMask.objects.push(brushLine); + state.inpaintMask.bboxNeedsUpdate = true; + state.layers.imageCache = null; + }, + imEraserLineAdded2: (state, action: PayloadAction<{ eraserLine: EraserLine }>) => { + const { eraserLine } = action.payload; + state.inpaintMask.objects.push(eraserLine); + state.inpaintMask.bboxNeedsUpdate = true; + state.layers.imageCache = null; + }, imEraserLineAdded: { reducer: (state, action: PayloadAction & { lineId: string }>) => { const { points, lineId, width, clip } = action.payload; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts index c39a70fe01..1199780818 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts @@ -7,8 +7,10 @@ import { assert } from 'tsafe'; import { v4 as uuidv4 } from 'uuid'; import type { + BrushLine, BrushLineAddedArg, CanvasV2State, + EraserLine, EraserLineAddedArg, ImageObjectAddedArg, LayerEntity, @@ -152,6 +154,28 @@ export const layersReducers = { moveToStart(state.layers.entities, layer); state.layers.imageCache = null; }, + layerBrushLineAdded2: (state, action: PayloadAction<{ id: string; brushLine: BrushLine }>) => { + const { id, brushLine } = action.payload; + const layer = selectLayer(state, id); + if (!layer) { + return; + } + + layer.objects.push(brushLine); + layer.bboxNeedsUpdate = true; + state.layers.imageCache = null; + }, + layerEraserLineAdded2: (state, action: PayloadAction<{ id: string; eraserLine: EraserLine }>) => { + const { id, eraserLine } = action.payload; + const layer = selectLayer(state, id); + if (!layer) { + return; + } + + layer.objects.push(eraserLine); + layer.bboxNeedsUpdate = true; + state.layers.imageCache = null; + }, layerBrushLineAdded: { reducer: (state, action: PayloadAction) => { const { id, points, lineId, color, width, clip } = action.payload; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts index 9c7fb6e1e8..b033d6f813 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts @@ -1,7 +1,14 @@ import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils'; import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming'; -import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2, ScaleChangedArg } from 'features/controlLayers/store/types'; +import type { + BrushLine, + CanvasV2State, + CLIPVisionModelV2, + EraserLine, + IPMethodV2, + ScaleChangedArg, +} from 'features/controlLayers/store/types'; import { imageDTOToImageObject, imageDTOToImageWithDims, RGBA_RED } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas'; @@ -354,6 +361,28 @@ export const regionsReducers = { payload: { ...payload, lineId: uuidv4() }, }), }, + rgBrushLineAdded2: (state, action: PayloadAction<{ id: string; brushLine: BrushLine }>) => { + const { id, brushLine } = action.payload; + const rg = selectRG(state, id); + if (!rg) { + return; + } + + rg.objects.push(brushLine); + rg.bboxNeedsUpdate = true; + state.layers.imageCache = null; + }, + rgEraserLineAdded2: (state, action: PayloadAction<{ id: string; eraserLine: EraserLine }>) => { + const { id, eraserLine } = action.payload; + const rg = selectRG(state, id); + if (!rg) { + return; + } + + rg.objects.push(eraserLine); + rg.bboxNeedsUpdate = true; + state.layers.imageCache = null; + }, rgEraserLineAdded: { reducer: (state, action: PayloadAction) => { const { id, points, lineId, width, clip } = action.payload; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 4693654147..872afbe56b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -1,3 +1,7 @@ +import type { CanvasControlAdapter } from 'features/controlLayers/konva/CanvasControlAdapter'; +import { CanvasInpaintMask } from 'features/controlLayers/konva/CanvasInpaintMask'; +import { CanvasLayer } from 'features/controlLayers/konva/CanvasLayer'; +import { CanvasRegion } from 'features/controlLayers/konva/CanvasRegion'; import { getImageObjectId } from 'features/controlLayers/konva/naming'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; @@ -924,3 +928,13 @@ export type RemoveIndexString = { }; export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'; + +export function isDrawableEntity(entity: CanvasEntity): entity is LayerEntity | RegionEntity | InpaintMaskEntity { + return entity.type === 'layer' || entity.type === 'regional_guidance' || entity.type === 'inpaint_mask'; +} + +export function isDrawableEntityAdapter( + adapter: CanvasLayer | CanvasRegion | CanvasControlAdapter | CanvasInpaintMask +): adapter is CanvasLayer | CanvasRegion | CanvasInpaintMask { + return adapter instanceof CanvasLayer || adapter instanceof CanvasRegion || adapter instanceof CanvasInpaintMask; +}