From 80e7e1293a648cf59ed7d91143081b52f61fbd90 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 3 Jul 2024 19:57:07 +1000
Subject: [PATCH] fix(ui): region rendering

---
 .../controlLayers/konva/nodeManager.ts        |   7 +-
 .../konva/renderers/inpaintMask.ts            |   6 +-
 .../controlLayers/konva/renderers/regions.ts  | 262 +++++++++---------
 .../controlLayers/konva/renderers/renderer.ts |   3 +
 .../controlLayers/store/canvasV2Slice.ts      |   1 +
 .../controlLayers/store/regionsReducers.ts    |  29 +-
 6 files changed, 171 insertions(+), 137 deletions(-)

diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts b/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts
index 7dc7415b3e..de5c372498 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts
@@ -181,9 +181,6 @@ export class KonvaNodeManager {
 
   renderRegions() {
     const { entities } = this.stateApi.getRegionsState();
-    const maskOpacity = this.stateApi.getMaskOpacity();
-    const toolState = this.stateApi.getToolState();
-    const selectedEntity = this.stateApi.getSelectedEntity();
 
     // Destroy the konva nodes for nonexistent entities
     for (const canvasRegion of this.regions.values()) {
@@ -196,11 +193,11 @@ export class KonvaNodeManager {
     for (const entity of entities) {
       let adapter = this.regions.get(entity.id);
       if (!adapter) {
-        adapter = new CanvasRegion(entity, this.stateApi.onPosChanged);
+        adapter = new CanvasRegion(entity, this);
         this.regions.set(adapter.id, adapter);
         this.stage.add(adapter.layer);
       }
-      adapter.render(entity, toolState.selected, selectedEntity, maskOpacity);
+      adapter.render(entity);
     }
   }
 
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/inpaintMask.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/inpaintMask.ts
index 4f9c022240..06d95993a3 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/inpaintMask.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/inpaintMask.ts
@@ -22,11 +22,7 @@ export class CanvasInpaintMask {
   constructor(entity: InpaintMaskEntity, manager: KonvaNodeManager) {
     this.id = entity.id;
     this.manager = manager;
-    this.layer = new Konva.Layer({
-      id: entity.id,
-      draggable: true,
-      dragDistance: 0,
-    });
+    this.layer = new Konva.Layer({ id: entity.id });
 
     this.group = new Konva.Group({
       id: getObjectGroupId(this.layer.id(), uuidv4()),
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/regions.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/regions.ts
index 2e8f154a98..0f292dfaca 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/regions.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/regions.ts
@@ -1,42 +1,61 @@
 import { rgbColorToString } from 'common/util/colorCodeTransformers';
 import { getObjectGroupId } from 'features/controlLayers/konva/naming';
-import type { StateApi } from 'features/controlLayers/konva/nodeManager';
+import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
 import { getNodeBboxFast } from 'features/controlLayers/konva/renderers/entityBbox';
 import { KonvaBrushLine, KonvaEraserLine, KonvaRect } from 'features/controlLayers/konva/renderers/objects';
 import { mapId } from 'features/controlLayers/konva/util';
-import type { CanvasEntityIdentifier, RegionEntity, Tool } from 'features/controlLayers/store/types';
+import {
+  isDrawingTool,
+  type RegionEntity,
+} from 'features/controlLayers/store/types';
 import Konva from 'konva';
 import { assert } from 'tsafe';
 import { v4 as uuidv4 } from 'uuid';
 
 export class CanvasRegion {
   id: string;
+  manager: KonvaNodeManager;
   layer: Konva.Layer;
   group: Konva.Group;
+  objectsGroup: Konva.Group;
   compositingRect: Konva.Rect;
+  transformer: Konva.Transformer;
   objects: Map<string, KonvaBrushLine | KonvaEraserLine | KonvaRect>;
 
-  constructor(entity: RegionEntity, onPosChanged: StateApi['onPosChanged']) {
+  constructor(entity: RegionEntity, manager: KonvaNodeManager) {
     this.id = entity.id;
+    this.manager = manager;
+    this.layer = new Konva.Layer({ id: entity.id });
 
-    this.layer = new Konva.Layer({
-      id: entity.id,
-      draggable: true,
-      dragDistance: 0,
-    });
-
-    // When a drag on the layer finishes, update the layer's position in state. During the drag, konva handles changing
-    // the position - we do not need to call this on the `dragmove` event.
-    this.layer.on('dragend', function (e) {
-      onPosChanged({ id: entity.id, x: Math.floor(e.target.x()), y: Math.floor(e.target.y()) }, 'regional_guidance');
-    });
     this.group = new Konva.Group({
       id: getObjectGroupId(this.layer.id(), uuidv4()),
       listening: false,
     });
+    this.objectsGroup = new Konva.Group({});
+    this.group.add(this.objectsGroup);
     this.layer.add(this.group);
+
+    this.transformer = new Konva.Transformer({
+      shouldOverdrawWholeArea: true,
+      draggable: true,
+      dragDistance: 0,
+      enabledAnchors: ['top-left', 'top-right', 'bottom-left', 'bottom-right'],
+      rotateEnabled: false,
+      flipEnabled: false,
+    });
+    this.transformer.on('transformend', () => {
+      this.manager.stateApi.onScaleChanged(
+        { id: this.id, scale: this.group.scaleX(), x: this.group.x(), y: this.group.y() },
+        'regional_guidance'
+      );
+    });
+    this.transformer.on('dragend', () => {
+      this.manager.stateApi.onPosChanged({ id: this.id, x: this.group.x(), y: this.group.y() }, 'regional_guidance');
+    });
+    this.layer.add(this.transformer);
+
     this.compositingRect = new Konva.Rect({ listening: false });
-    this.layer.add(this.compositingRect);
+    this.group.add(this.compositingRect);
     this.objects = new Map();
   }
 
@@ -44,24 +63,16 @@ export class CanvasRegion {
     this.layer.destroy();
   }
 
-  async render(
-    regionState: RegionEntity,
-    selectedTool: Tool,
-    selectedEntityIdentifier: CanvasEntityIdentifier | null,
-    maskOpacity: number
-  ) {
+  async render(regionState: RegionEntity) {
     // Update the layer's position and listening state
-    this.layer.setAttrs({
-      listening: selectedTool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
-      x: Math.floor(regionState.x),
-      y: Math.floor(regionState.y),
+    this.group.setAttrs({
+      x: regionState.x,
+      y: regionState.y,
+      scaleX: 1,
+      scaleY: 1,
     });
 
-    // Convert the color to a string, stripping the alpha - the object group will handle opacity.
-    const rgbColor = rgbColorToString(regionState.fill);
-
-    // We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
-    let groupNeedsCache = false;
+    let didDraw = false;
 
     const objectIds = regionState.objects.map(mapId);
     // Destroy any objects that are no longer in state
@@ -69,7 +80,7 @@ export class CanvasRegion {
       if (!objectIds.includes(object.id)) {
         this.objects.delete(object.id);
         object.destroy();
-        groupNeedsCache = true;
+        didDraw = true;
       }
     }
 
@@ -81,13 +92,12 @@ export class CanvasRegion {
         if (!brushLine) {
           brushLine = new KonvaBrushLine(obj);
           this.objects.set(brushLine.id, brushLine);
-          this.group.add(brushLine.konvaLineGroup);
-          groupNeedsCache = true;
-        }
-
-        if (obj.points.length !== brushLine.konvaLine.points().length) {
-          brushLine.konvaLine.points(obj.points);
-          groupNeedsCache = true;
+          this.objectsGroup.add(brushLine.konvaLineGroup);
+          didDraw = true;
+        } else {
+          if (brushLine.update(obj)) {
+            didDraw = true;
+          }
         }
       } else if (obj.type === 'eraser_line') {
         let eraserLine = this.objects.get(obj.id);
@@ -96,13 +106,12 @@ export class CanvasRegion {
         if (!eraserLine) {
           eraserLine = new KonvaEraserLine(obj);
           this.objects.set(eraserLine.id, eraserLine);
-          this.group.add(eraserLine.konvaLineGroup);
-          groupNeedsCache = true;
-        }
-
-        if (obj.points.length !== eraserLine.konvaLine.points().length) {
-          eraserLine.konvaLine.points(obj.points);
-          groupNeedsCache = true;
+          this.objectsGroup.add(eraserLine.konvaLineGroup);
+          didDraw = true;
+        } else {
+          if (eraserLine.update(obj)) {
+            didDraw = true;
+          }
         }
       } else if (obj.type === 'rect_shape') {
         let rect = this.objects.get(obj.id);
@@ -111,8 +120,12 @@ export class CanvasRegion {
         if (!rect) {
           rect = new KonvaRect(obj);
           this.objects.set(rect.id, rect);
-          this.group.add(rect.konvaRect);
-          groupNeedsCache = true;
+          this.objectsGroup.add(rect.konvaRect);
+          didDraw = true;
+        } else {
+          if (rect.update(obj)) {
+            didDraw = true;
+          }
         }
       }
     }
@@ -120,92 +133,91 @@ export class CanvasRegion {
     // Only update layer visibility if it has changed.
     if (this.layer.visible() !== regionState.isEnabled) {
       this.layer.visible(regionState.isEnabled);
-      groupNeedsCache = true;
     }
 
-    if (this.objects.size === 0) {
-      // No objects - clear the cache to reset the previous pixel data
-      this.group.clearCache();
-      return;
-    }
-
-    // We must clear the cache first so Konva will re-draw the group with the new compositing rect
-    if (this.group.isCached()) {
-      this.group.clearCache();
-    }
     // The user is allowed to reduce mask opacity to 0, but we need the opacity for the compositing rect to work
     this.group.opacity(1);
 
-    this.compositingRect.setAttrs({
-      // The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already
-      ...(!regionState.bboxNeedsUpdate && regionState.bbox ? regionState.bbox : getNodeBboxFast(this.layer)),
-      fill: rgbColor,
-      opacity: maskOpacity,
-      // Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes)
-      globalCompositeOperation: 'source-in',
-      visible: true,
-      // This rect must always be on top of all other shapes
-      zIndex: this.objects.size + 1,
-    });
+    if (didDraw) {
+      // Convert the color to a string, stripping the alpha - the object group will handle opacity.
+      const rgbColor = rgbColorToString(regionState.fill);
+      const maskOpacity = this.manager.stateApi.getMaskOpacity();
 
-    // const isSelected = selectedEntityIdentifier?.id === regionState.id;
+      this.compositingRect.setAttrs({
+        // The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already
+        ...getNodeBboxFast(this.objectsGroup),
+        fill: rgbColor,
+        opacity: maskOpacity,
+        // Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes)
+        globalCompositeOperation: 'source-in',
+        visible: true,
+        // This rect must always be on top of all other shapes
+        zIndex: this.objects.size + 1,
+      });
+    }
 
-    // /**
-    //  * When the group is selected, we use a rect of the selected preview color, composited over the shapes. This allows
-    //  * shapes to render as a "raster" layer with all pixels drawn at the same color and opacity.
-    //  *
-    //  * Without this special handling, each shape is drawn individually with the given opacity, atop the other shapes. The
-    //  * effect is like if you have a Photoshop Group consisting of many shapes, each of which has the given opacity.
-    //  * Overlapping shapes will have their colors blended together, and the final color is the result of all the shapes.
-    //  *
-    //  * Instead, with the special handling, the effect is as if you drew all the shapes at 100% opacity, flattened them to
-    //  * a single raster image, and _then_ applied the 50% opacity.
-    //  */
-    // if (isSelected && selectedTool !== 'move') {
-    //   // We must clear the cache first so Konva will re-draw the group with the new compositing rect
-    //   if (this.konvaObjectGroup.isCached()) {
-    //     this.konvaObjectGroup.clearCache();
-    //   }
-    //   // The user is allowed to reduce mask opacity to 0, but we need the opacity for the compositing rect to work
-    //   this.konvaObjectGroup.opacity(1);
+    this.updateGroup(didDraw);
+  }
 
-    //   this.compositingRect.setAttrs({
-    //     // The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already
-    //     ...(!regionState.bboxNeedsUpdate && regionState.bbox ? regionState.bbox : getLayerBboxFast(this.konvaLayer)),
-    //     fill: rgbColor,
-    //     opacity: maskOpacity,
-    //     // Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes)
-    //     globalCompositeOperation: 'source-in',
-    //     visible: true,
-    //     // This rect must always be on top of all other shapes
-    //     zIndex: this.objects.size + 1,
-    //   });
-    // } else {
-    //   // The compositing rect should only be shown when the layer is selected.
-    //   this.compositingRect.visible(false);
-    //   // Cache only if needed - or if we are on this code path and _don't_ have a cache
-    //   if (groupNeedsCache || !this.konvaObjectGroup.isCached()) {
-    //     this.konvaObjectGroup.cache();
-    //   }
-    //   // Updating group opacity does not require re-caching
-    //   this.konvaObjectGroup.opacity(maskOpacity);
-    // }
+  updateGroup(didDraw: boolean) {
+    const isSelected = this.manager.stateApi.getIsSelected(this.id);
+    const selectedTool = this.manager.stateApi.getToolState().selected;
 
-    // const bboxRect =
-    //   regionMap.konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(rg, regionMap.konvaLayer);
-    // if (rg.bbox) {
-    //   const active = !rg.bboxNeedsUpdate && isSelected && tool === 'move';
-    //   bboxRect.setAttrs({
-    //     visible: active,
-    //     listening: active,
-    //     x: rg.bbox.x,
-    //     y: rg.bbox.y,
-    //     width: rg.bbox.width,
-    //     height: rg.bbox.height,
-    //     stroke: isSelected ? BBOX_SELECTED_STROKE : '',
-    //   });
-    // } else {
-    //   bboxRect.visible(false);
-    // }
+    if (this.objects.size === 0) {
+      // If the layer is totally empty, reset the cache and bail out.
+      this.layer.listening(false);
+      this.transformer.nodes([]);
+      if (this.group.isCached()) {
+        this.group.clearCache();
+      }
+      return;
+    }
+
+    if (isSelected && selectedTool === 'move') {
+      // When the layer is selected and being moved, we should always cache it.
+      // We should update the cache if we drew to the layer.
+      if (!this.group.isCached() || didDraw) {
+        this.group.cache();
+      }
+      // Activate the transformer
+      this.layer.listening(true);
+      this.transformer.nodes([this.group]);
+      this.transformer.forceUpdate();
+      return;
+    }
+
+    if (isSelected && selectedTool !== 'move') {
+      // If the layer is selected but not using the move tool, we don't want the layer to be listening.
+      this.layer.listening(false);
+      // The transformer also does not need to be active.
+      this.transformer.nodes([]);
+      if (isDrawingTool(selectedTool)) {
+        // We are using a drawing tool (brush, eraser, rect). These tools change the layer's rendered appearance, so we
+        // should never be cached.
+        if (this.group.isCached()) {
+          this.group.clearCache();
+        }
+      } else {
+        // We are using a non-drawing tool (move, view, bbox), so we should cache the layer.
+        // We should update the cache if we drew to the layer.
+        if (!this.group.isCached() || didDraw) {
+          this.group.cache();
+        }
+      }
+      return;
+    }
+
+    if (!isSelected) {
+      // Unselected layers should not be listening
+      this.layer.listening(false);
+      // The transformer also does not need to be active.
+      this.transformer.nodes([]);
+      // Update the layer's cache if it's not already cached or we drew to it.
+      if (!this.group.isCached() || didDraw) {
+        this.group.cache();
+      }
+
+      return;
+    }
   }
 }
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts
index 1ecd2621b8..997d564c88 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts
@@ -37,6 +37,7 @@ import {
   rgImageCacheChanged,
   rgLinePointAdded,
   rgRectAdded,
+  rgScaled,
   rgTranslated,
   toolBufferChanged,
   toolChanged,
@@ -110,6 +111,8 @@ export const initializeRenderer = (
       dispatch(layerScaled(arg));
     } else if (entityType === 'inpaint_mask') {
       dispatch(imScaled(arg));
+    } else if (entityType === 'regional_guidance') {
+      dispatch(rgScaled(arg));
     }
   };
   const onBboxChanged = (arg: BboxChangedArg, entityType: CanvasEntity['type']) => {
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts
index e9dabfafee..0b63056b3d 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts
@@ -280,6 +280,7 @@ export const {
   rgEraserLineAdded,
   rgLinePointAdded,
   rgRectAdded,
+  rgScaled,
   // Compositing
   setInfillMethod,
   setInfillTileSize,
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts
index 59ad0d5314..9c7fb6e1e8 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts
@@ -1,8 +1,8 @@
 import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
 import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
 import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming';
-import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
-import { imageDTOToImageObject, imageDTOToImageWithDims,RGBA_RED } from 'features/controlLayers/store/types';
+import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2, ScaleChangedArg } from 'features/controlLayers/store/types';
+import { imageDTOToImageObject, imageDTOToImageWithDims, RGBA_RED } from 'features/controlLayers/store/types';
 import { zModelIdentifierField } from 'features/nodes/types/common';
 import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
 import type { IRect } from 'konva/lib/types';
@@ -107,6 +107,31 @@ export const regionsReducers = {
       rg.y = y;
     }
   },
+  rgScaled: (state, action: PayloadAction<ScaleChangedArg>) => {
+    const { id, scale, x, y } = action.payload;
+    const rg = selectRG(state, id);
+    if (!rg) {
+      return;
+    }
+    for (const obj of rg.objects) {
+      if (obj.type === 'brush_line') {
+        obj.points = obj.points.map((point) => point * scale);
+        obj.strokeWidth *= scale;
+      } else if (obj.type === 'eraser_line') {
+        obj.points = obj.points.map((point) => point * scale);
+        obj.strokeWidth *= scale;
+      } else if (obj.type === 'rect_shape') {
+        obj.x *= scale;
+        obj.y *= scale;
+        obj.height *= scale;
+        obj.width *= scale;
+      }
+    }
+    rg.x = x;
+    rg.y = y;
+    rg.bboxNeedsUpdate = true;
+    state.layers.imageCache = null;
+  },
   rgBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => {
     const { id, bbox } = action.payload;
     const rg = selectRG(state, id);