From 708facf70773e59bd98782adb18b44880458fcce Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 22 Aug 2024 10:06:58 +1000
Subject: [PATCH] tidy(ui): abstract stage logic into module

---
 .../components/CanvasResetViewButton.tsx      |   4 +-
 .../controlLayers/components/CanvasScale.tsx  |   6 +-
 .../controlLayers/konva/CanvasManager.ts      | 212 ++---------------
 .../controlLayers/konva/CanvasStageModule.ts  | 225 ++++++++++++++++++
 .../controlLayers/konva/CanvasTool.ts         |  24 +-
 .../controlLayers/konva/CanvasTransformer.ts  |  16 +-
 .../features/controlLayers/konva/events.ts    |   8 +-
 7 files changed, 276 insertions(+), 219 deletions(-)
 create mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasStageModule.ts

diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasResetViewButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasResetViewButton.tsx
index 4b380d0c16..756a31e837 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasResetViewButton.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasResetViewButton.tsx
@@ -16,14 +16,14 @@ export const CanvasResetViewButton = memo(() => {
     if (!canvasManager) {
       return;
     }
-    canvasManager.setStageScale(1);
+    canvasManager.stage.setScale(1);
   }, [canvasManager]);
 
   const resetView = useCallback(() => {
     if (!canvasManager) {
       return;
     }
-    canvasManager.resetView();
+    canvasManager.stage.resetView();
   }, [canvasManager]);
 
   const onReset = useCallback(() => {
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasScale.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasScale.tsx
index 1d2d16d61c..5d8d2ee11b 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasScale.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasScale.tsx
@@ -91,7 +91,7 @@ export const CanvasScale = memo(() => {
         snappedScale = snapToNearest(scale, snapCandidates, 2);
       }
       const mappedScale = mapSliderValueToScale(snappedScale);
-      canvasManager.setStageScale(mappedScale / 100);
+      canvasManager.stage.setScale(mappedScale / 100);
     },
     [canvasManager]
   );
@@ -101,11 +101,11 @@ export const CanvasScale = memo(() => {
       return;
     }
     if (isNaN(Number(localScale))) {
-      canvasManager.setStageScale(1);
+      canvasManager.stage.setScale(1);
       setLocalScale(100);
       return;
     }
-    canvasManager.setStageScale(clamp(localScale / 100, MIN_CANVAS_SCALE, MAX_CANVAS_SCALE));
+    canvasManager.stage.setScale(clamp(localScale / 100, MIN_CANVAS_SCALE, MAX_CANVAS_SCALE));
   }, [canvasManager, localScale]);
 
   const onChangeNumberInput = useCallback((valueAsString: string, valueAsNumber: number) => {
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts
index 4a2456ba95..73c4d03d13 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts
@@ -1,22 +1,20 @@
 import type { AppSocket } from 'app/hooks/useSocketIO';
 import { logger } from 'app/logging/logger';
 import type { AppStore } from 'app/store/store';
-import type { JSONObject, SerializableObject } from 'common/types';
+import type { SerializableObject } from 'common/types';
 import { CanvasFilter } from 'features/controlLayers/konva/CanvasFilter';
-import { MAX_CANVAS_SCALE, MIN_CANVAS_SCALE } from 'features/controlLayers/konva/constants';
+import { CanvasStageModule } from 'features/controlLayers/konva/CanvasStageModule';
 import {
   canvasToBlob,
   canvasToImageData,
   getImageDataTransparency,
   getPrefixedId,
-  getRectUnion,
   nanoid,
   previewBlob,
 } from 'features/controlLayers/konva/util';
 import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker';
-import type { CanvasV2State, Coordinate, Dimensions, GenerationMode, Rect } from 'features/controlLayers/store/types';
+import type { CanvasV2State, GenerationMode, Rect } from 'features/controlLayers/store/types';
 import type Konva from 'konva';
-import { clamp } from 'lodash-es';
 import { LRUCache } from 'lru-cache';
 import { atom } from 'nanostores';
 import type { Logger } from 'roarr';
@@ -40,7 +38,6 @@ export class CanvasManager {
 
   id: string;
   path: string[];
-  stage: Konva.Stage;
   container: HTMLDivElement;
   rasterLayerAdapters: Map<string, CanvasLayerAdapter> = new Map();
   controlLayerAdapters: Map<string, CanvasLayerAdapter> = new Map();
@@ -50,6 +47,7 @@ export class CanvasManager {
   preview: CanvasPreview;
   background: CanvasBackground;
   filter: CanvasFilter;
+  stage: CanvasStageModule;
 
   log: Logger;
   socket: AppSocket;
@@ -69,7 +67,6 @@ export class CanvasManager {
   constructor(stage: Konva.Stage, container: HTMLDivElement, store: AppStore, socket: AppSocket) {
     this.id = getPrefixedId(this.type);
     this.path = [this.id];
-    this.stage = stage;
     this.container = container;
     this._store = store;
     this.socket = socket;
@@ -87,11 +84,13 @@ export class CanvasManager {
       };
     });
 
+    this.stage = new CanvasStageModule(stage, container, this);
+
     this.preview = new CanvasPreview(this);
-    this.stage.add(this.preview.getLayer());
+    this.stage.addLayer(this.preview.getLayer());
 
     this.background = new CanvasBackground(this);
-    this.stage.add(this.background.konva.layer);
+    this.stage.addLayer(this.background.konva.layer);
 
     this.filter = new CanvasFilter(this);
 
@@ -169,83 +168,6 @@ export class CanvasManager {
     this.preview.getLayer().zIndex(++zIndex);
   }
 
-  fitStageToContainer() {
-    this.stage.width(this.container.offsetWidth);
-    this.stage.height(this.container.offsetHeight);
-    this.stateApi.$stageAttrs.set({
-      x: this.stage.x(),
-      y: this.stage.y(),
-      width: this.stage.width(),
-      height: this.stage.height(),
-      scale: this.stage.scaleX(),
-    });
-  }
-
-  getVisibleRect = (): Rect => {
-    const rects = [];
-
-    for (const adapter of this.inpaintMaskAdapters.values()) {
-      if (adapter.state.isEnabled) {
-        rects.push(adapter.transformer.getRelativeRect());
-      }
-    }
-
-    for (const adapter of this.rasterLayerAdapters.values()) {
-      if (adapter.state.isEnabled) {
-        rects.push(adapter.transformer.getRelativeRect());
-      }
-    }
-
-    for (const adapter of this.controlLayerAdapters.values()) {
-      if (adapter.state.isEnabled) {
-        rects.push(adapter.transformer.getRelativeRect());
-      }
-    }
-
-    for (const adapter of this.regionalGuidanceAdapters.values()) {
-      if (adapter.state.isEnabled) {
-        rects.push(adapter.transformer.getRelativeRect());
-      }
-    }
-
-    const rectUnion = getRectUnion(...rects);
-
-    if (rectUnion.width === 0 || rectUnion.height === 0) {
-      // fall back to the bbox if there is no content
-      return this.stateApi.getBbox().rect;
-    } else {
-      return rectUnion;
-    }
-  };
-
-  resetView() {
-    const { width, height } = this.getStageSize();
-    const rect = this.getVisibleRect();
-
-    const padding = 20; // Padding in absolute pixels
-
-    const availableWidth = width - padding * 2;
-    const availableHeight = height - padding * 2;
-
-    const scale = Math.min(Math.min(availableWidth / rect.width, availableHeight / rect.height), 1);
-    const x = -rect.x * scale + padding + (availableWidth - rect.width * scale) / 2;
-    const y = -rect.y * scale + padding + (availableHeight - rect.height * scale) / 2;
-
-    this.stage.setAttrs({
-      x,
-      y,
-      scaleX: scale,
-      scaleY: scale,
-    });
-
-    this.stateApi.$stageAttrs.set({
-      ...this.stateApi.$stageAttrs.get(),
-      x,
-      y,
-      scale,
-    });
-  }
-
   getTransformingLayer = (): CanvasLayerAdapter | CanvasMaskAdapter | null => {
     const transformingEntity = this.stateApi.$transformingEntity.get();
     if (!transformingEntity) {
@@ -307,6 +229,10 @@ export class CanvasManager {
     const isFirstRender = this.isFirstRender;
     this.isFirstRender = false;
 
+    if (isFirstRender) {
+      this.log.trace('First render');
+    }
+
     const prevState = this.prevState;
     this.prevState = state;
 
@@ -340,7 +266,7 @@ export class CanvasManager {
         if (!adapter) {
           adapter = new CanvasLayerAdapter(entityState, this);
           this.rasterLayerAdapters.set(adapter.id, adapter);
-          this.stage.add(adapter.konva.layer);
+          this.stage.addLayer(adapter.konva.layer);
         }
         await adapter.update({
           state: entityState,
@@ -371,7 +297,7 @@ export class CanvasManager {
         if (!adapter) {
           adapter = new CanvasLayerAdapter(entityState, this);
           this.controlLayerAdapters.set(adapter.id, adapter);
-          this.stage.add(adapter.konva.layer);
+          this.stage.addLayer(adapter.konva.layer);
         }
         await adapter.update({
           state: entityState,
@@ -408,7 +334,7 @@ export class CanvasManager {
         if (!adapter) {
           adapter = new CanvasMaskAdapter(entityState, this);
           this.regionalGuidanceAdapters.set(adapter.id, adapter);
-          this.stage.add(adapter.konva.layer);
+          this.stage.addLayer(adapter.konva.layer);
         }
         await adapter.update({
           state: entityState,
@@ -445,7 +371,7 @@ export class CanvasManager {
         if (!adapter) {
           adapter = new CanvasMaskAdapter(entityState, this);
           this.inpaintMaskAdapters.set(adapter.id, adapter);
-          this.stage.add(adapter.konva.layer);
+          this.stage.addLayer(adapter.konva.layer);
         }
         await adapter.update({
           state: entityState,
@@ -488,23 +414,15 @@ export class CanvasManager {
   };
 
   initialize = () => {
-    this.log.debug('Initializing renderer');
-    this.stage.container(this.container);
+    this.log.debug('Initializing canvas manager');
 
     const unsubscribeListeners = setStageEventHandlers(this);
 
-    // We can use a resize observer to ensure the stage always fits the container. We also need to re-render the bg and
-    // document bounds overlay when the stage is resized.
-    const resizeObserver = new ResizeObserver(this.fitStageToContainer.bind(this));
-    resizeObserver.observe(this.container);
-    this.fitStageToContainer();
-
+    const cleanupStage = this.stage.initialize();
     const unsubscribeRenderer = this._store.subscribe(this.render);
 
-    this.log.debug('First render of konva stage');
-
     return () => {
-      this.log.debug('Cleaning up konva renderer');
+      this.log.debug('Cleaning up canvas manager');
       const allAdapters = [
         ...this.rasterLayerAdapters.values(),
         ...this.controlLayerAdapters.values(),
@@ -518,96 +436,10 @@ export class CanvasManager {
       this.preview.destroy();
       unsubscribeRenderer();
       unsubscribeListeners();
-      resizeObserver.disconnect();
+      cleanupStage();
     };
   };
 
-  /**
-   * Gets the center of the stage in either absolute or relative coordinates
-   * @param absolute Whether to return the center in absolute coordinates
-   */
-  getStageCenter(absolute = false): Coordinate {
-    const scale = this.getStageScale();
-    const { x, y } = this.getStagePosition();
-    const { width, height } = this.getStageSize();
-
-    const center = {
-      x: (width / 2 - x) / scale,
-      y: (height / 2 - y) / scale,
-    };
-
-    if (!absolute) {
-      return center;
-    }
-
-    return this.stage.getAbsoluteTransform().point(center);
-  }
-
-  /**
-   * Sets the scale of the stage. If center is provided, the stage will zoom in/out on that point.
-   * @param scale The new scale to set
-   * @param center The center of the stage to zoom in/out on
-   */
-  setStageScale(scale: number, center: Coordinate = this.getStageCenter(true)) {
-    const newScale = clamp(Math.round(scale * 100) / 100, MIN_CANVAS_SCALE, MAX_CANVAS_SCALE);
-
-    const { x, y } = this.getStagePosition();
-    const oldScale = this.getStageScale();
-
-    const deltaX = (center.x - x) / oldScale;
-    const deltaY = (center.y - y) / oldScale;
-
-    const newX = center.x - deltaX * newScale;
-    const newY = center.y - deltaY * newScale;
-
-    this.stage.setAttrs({
-      x: newX,
-      y: newY,
-      scaleX: newScale,
-      scaleY: newScale,
-    });
-
-    this.stateApi.$stageAttrs.set({
-      x: Math.floor(this.stage.x()),
-      y: Math.floor(this.stage.y()),
-      width: this.stage.width(),
-      height: this.stage.height(),
-      scale: this.stage.scaleX(),
-    });
-  }
-
-  /**
-   * Gets the scale of the stage. The stage is always scaled uniformly in x and y.
-   */
-  getStageScale(): number {
-    // The stage is never scaled differently in x and y
-    return this.stage.scaleX();
-  }
-
-  /**
-   * Gets the position of the stage.
-   */
-  getStagePosition(): Coordinate {
-    return this.stage.position();
-  }
-
-  /**
-   * Gets the size of the stage.
-   */
-  getStageSize(): Dimensions {
-    return this.stage.size();
-  }
-
-  /**
-   * Scales a number of pixels by the current stage scale. For example, if the stage is scaled by 5, then 10 pixels
-   * would be scaled to 10px / 5 = 2 pixels.
-   * @param pixels The number of pixels to scale
-   * @returns The number of pixels scaled by the current stage scale
-   */
-  getScaledPixels(pixels: number): number {
-    return pixels / this.getStageScale();
-  }
-
   clearCaches = () => {
     this.canvasCache.clear();
     this.imageNameCache.clear();
@@ -820,13 +652,13 @@ export class CanvasManager {
     return generationMode;
   }
 
-  getLoggingContext = (): JSONObject => {
+  getLoggingContext = (): SerializableObject => {
     return {
       path: this.path.join('.'),
     };
   };
 
-  buildLogger = (getContext: () => JSONObject): Logger => {
+  buildLogger = (getContext: () => SerializableObject): Logger => {
     return this.log.child((message) => {
       return {
         ...message,
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStageModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStageModule.ts
new file mode 100644
index 0000000000..85504b527b
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStageModule.ts
@@ -0,0 +1,225 @@
+import type { SerializableObject } from 'common/types';
+import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
+import { getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
+import type { Coordinate, Dimensions, Rect } from 'features/controlLayers/store/types';
+import type Konva from 'konva';
+import { clamp } from 'lodash-es';
+import type { Logger } from 'roarr';
+
+export class CanvasStageModule {
+  static MIN_CANVAS_SCALE = 0.1;
+  static MAX_CANVAS_SCALE = 20;
+
+  id: string;
+  path: string[];
+  konva: { stage: Konva.Stage };
+  manager: CanvasManager;
+  container: HTMLDivElement;
+  log: Logger;
+
+  constructor(stage: Konva.Stage, container: HTMLDivElement, manager: CanvasManager) {
+    this.id = getPrefixedId('stage');
+    this.manager = manager;
+    this.path = this.manager.path.concat(this.id);
+    this.log = this.manager.buildLogger(this.getLoggingContext);
+    this.log.debug('Creating stage module');
+    this.container = container;
+    this.konva = { stage };
+  }
+
+  initialize = () => {
+    this.log.debug('Initializing stage');
+    this.konva.stage.container(this.container);
+    const resizeObserver = new ResizeObserver(this.fitStageToContainer);
+    resizeObserver.observe(this.container);
+    this.fitStageToContainer();
+
+    return () => {
+      this.log.debug('Destroying stage');
+      resizeObserver.disconnect();
+      this.konva.stage.destroy();
+    };
+  };
+
+  fitStageToContainer = () => {
+    this.log.trace('Fitting stage to container');
+    this.konva.stage.width(this.konva.stage.container().offsetWidth);
+    this.konva.stage.height(this.konva.stage.container().offsetHeight);
+    this.manager.stateApi.$stageAttrs.set({
+      x: this.konva.stage.x(),
+      y: this.konva.stage.y(),
+      width: this.konva.stage.width(),
+      height: this.konva.stage.height(),
+      scale: this.konva.stage.scaleX(),
+    });
+  };
+
+  getVisibleRect = (): Rect => {
+    const rects = [];
+
+    for (const adapter of this.manager.inpaintMaskAdapters.values()) {
+      if (adapter.state.isEnabled) {
+        rects.push(adapter.transformer.getRelativeRect());
+      }
+    }
+
+    for (const adapter of this.manager.rasterLayerAdapters.values()) {
+      if (adapter.state.isEnabled) {
+        rects.push(adapter.transformer.getRelativeRect());
+      }
+    }
+
+    for (const adapter of this.manager.controlLayerAdapters.values()) {
+      if (adapter.state.isEnabled) {
+        rects.push(adapter.transformer.getRelativeRect());
+      }
+    }
+
+    for (const adapter of this.manager.regionalGuidanceAdapters.values()) {
+      if (adapter.state.isEnabled) {
+        rects.push(adapter.transformer.getRelativeRect());
+      }
+    }
+
+    const rectUnion = getRectUnion(...rects);
+
+    if (rectUnion.width === 0 || rectUnion.height === 0) {
+      // fall back to the bbox if there is no content
+      return this.manager.stateApi.getBbox().rect;
+    } else {
+      return rectUnion;
+    }
+  };
+
+  resetView() {
+    this.log.trace('Resetting view');
+    const { width, height } = this.getSize();
+    const rect = this.getVisibleRect();
+
+    const padding = 20; // Padding in absolute pixels
+
+    const availableWidth = width - padding * 2;
+    const availableHeight = height - padding * 2;
+
+    const scale = Math.min(Math.min(availableWidth / rect.width, availableHeight / rect.height), 1);
+    const x = -rect.x * scale + padding + (availableWidth - rect.width * scale) / 2;
+    const y = -rect.y * scale + padding + (availableHeight - rect.height * scale) / 2;
+
+    this.konva.stage.setAttrs({
+      x,
+      y,
+      scaleX: scale,
+      scaleY: scale,
+    });
+
+    this.manager.stateApi.$stageAttrs.set({
+      ...this.manager.stateApi.$stageAttrs.get(),
+      x,
+      y,
+      scale,
+    });
+  }
+
+  /**
+   * Gets the center of the stage in either absolute or relative coordinates
+   * @param absolute Whether to return the center in absolute coordinates
+   */
+  getCenter = (absolute = false): Coordinate => {
+    const scale = this.getScale();
+    const { x, y } = this.getPosition();
+    const { width, height } = this.getSize();
+
+    const center = {
+      x: (width / 2 - x) / scale,
+      y: (height / 2 - y) / scale,
+    };
+
+    if (!absolute) {
+      return center;
+    }
+
+    return this.konva.stage.getAbsoluteTransform().point(center);
+  };
+
+  /**
+   * Sets the scale of the stage. If center is provided, the stage will zoom in/out on that point.
+   * @param scale The new scale to set
+   * @param center The center of the stage to zoom in/out on
+   */
+  setScale = (scale: number, center: Coordinate = this.getCenter(true)) => {
+    this.log.trace('Setting scale');
+    const newScale = clamp(
+      Math.round(scale * 100) / 100,
+      CanvasStageModule.MIN_CANVAS_SCALE,
+      CanvasStageModule.MAX_CANVAS_SCALE
+    );
+
+    const { x, y } = this.getPosition();
+    const oldScale = this.getScale();
+
+    const deltaX = (center.x - x) / oldScale;
+    const deltaY = (center.y - y) / oldScale;
+
+    const newX = center.x - deltaX * newScale;
+    const newY = center.y - deltaY * newScale;
+
+    this.konva.stage.setAttrs({
+      x: newX,
+      y: newY,
+      scaleX: newScale,
+      scaleY: newScale,
+    });
+
+    this.manager.stateApi.$stageAttrs.set({
+      x: Math.floor(this.konva.stage.x()),
+      y: Math.floor(this.konva.stage.y()),
+      width: this.konva.stage.width(),
+      height: this.konva.stage.height(),
+      scale: this.konva.stage.scaleX(),
+    });
+  };
+
+  /**
+   * Gets the scale of the stage. The stage is always scaled uniformly in x and y.
+   */
+  getScale = (): number => {
+    // The stage is never scaled differently in x and y
+    return this.konva.stage.scaleX();
+  };
+
+  /**
+   * Gets the position of the stage.
+   */
+  getPosition = (): Coordinate => {
+    return this.konva.stage.position();
+  };
+
+  /**
+   * Gets the size of the stage.
+   */
+  getSize(): Dimensions {
+    return this.konva.stage.size();
+  }
+
+  /**
+   * Scales a number of pixels by the current stage scale. For example, if the stage is scaled by 5, then 10 pixels
+   * would be scaled to 10px / 5 = 2 pixels.
+   * @param pixels The number of pixels to scale
+   * @returns The number of pixels scaled by the current stage scale
+   */
+  getScaledPixels = (pixels: number): number => {
+    return pixels / this.getScale();
+  };
+
+  setIsDraggable = (isDraggable: boolean) => {
+    this.konva.stage.draggable(isDraggable);
+  };
+
+  addLayer = (layer: Konva.Layer) => {
+    this.konva.stage.add(layer);
+  };
+
+  getLoggingContext = (): SerializableObject => {
+    return { ...this.manager.getLoggingContext(), path: this.path.join('.') };
+  };
+}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool.ts
index 28dd794470..94a2071a7d 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool.ts
@@ -159,7 +159,7 @@ export class CanvasTool {
 
   scaleTool = () => {
     const toolState = this.manager.stateApi.getToolState();
-    const scale = this.manager.stage.scaleX();
+    const scale = this.manager.stage.getScale();
 
     const brushRadius = toolState.brush.width / 2;
     this.konva.brush.innerBorderCircle.strokeWidth(BRUSH_ERASER_BORDER_WIDTH / scale);
@@ -199,34 +199,34 @@ export class CanvasTool {
     // Update the stage's pointer style
     if (Boolean(this.manager.stateApi.$transformingEntity.get()) || renderedEntityCount === 0) {
       // We are transforming and/or have no layers, so we should not render any tool
-      stage.container().style.cursor = 'default';
+      stage.container.style.cursor = 'default';
     } else if (tool === 'view') {
       // view tool gets a hand
-      stage.container().style.cursor = isMouseDown ? 'grabbing' : 'grab';
+      stage.container.style.cursor = isMouseDown ? 'grabbing' : 'grab';
       // Bbox tool gets default
     } else if (tool === 'bbox') {
-      stage.container().style.cursor = 'default';
+      stage.container.style.cursor = 'default';
     } else if (tool === 'eyeDropper') {
       // Eyedropper gets none
-      stage.container().style.cursor = 'none';
+      stage.container.style.cursor = 'none';
     } else if (isDrawable) {
       if (tool === 'move') {
         // Move gets default arrow
-        stage.container().style.cursor = 'default';
+        stage.container.style.cursor = 'default';
       } else if (tool === 'rect') {
         // Rect gets a crosshair
-        stage.container().style.cursor = 'crosshair';
+        stage.container.style.cursor = 'crosshair';
       } else if (tool === 'brush' || tool === 'eraser') {
         // Hide the native cursor and use the konva-rendered brush preview
-        stage.container().style.cursor = 'none';
+        stage.container.style.cursor = 'none';
       }
     } else {
       // isDrawable === 'false'
       // Non-drawable layers don't have tools
-      stage.container().style.cursor = 'not-allowed';
+      stage.container.style.cursor = 'not-allowed';
     }
 
-    stage.draggable(tool === 'view');
+    stage.setIsDraggable(tool === 'view');
 
     if (!cursorPos || renderedEntityCount === 0 || !isDrawable) {
       // We can bail early if the mouse isn't over the stage or there are no layers
@@ -238,7 +238,7 @@ export class CanvasTool {
       if (cursorPos && tool === 'brush') {
         const brushPreviewFill = this.manager.stateApi.getBrushPreviewFill();
         const alignedCursorPos = alignCoordForTool(cursorPos, toolState.brush.width);
-        const scale = stage.scaleX();
+        const scale = stage.getScale();
         // Update the fill circle
         const radius = toolState.brush.width / 2;
 
@@ -261,7 +261,7 @@ export class CanvasTool {
       } else if (cursorPos && tool === 'eraser') {
         const alignedCursorPos = alignCoordForTool(cursorPos, toolState.eraser.width);
 
-        const scale = stage.scaleX();
+        const scale = stage.getScale();
         // Update the fill circle
         const radius = toolState.eraser.width / 2;
         this.konva.eraser.fillCircle.setAttrs({
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTransformer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTransformer.ts
index 8d9940af1c..9ea6b5ee24 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTransformer.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTransformer.ts
@@ -174,8 +174,8 @@ export class CanvasTransformer {
           // We need to snap the anchor to the nearest pixel, but the positions provided to this callback are absolute,
           // scaled coordinates. They need to be converted to stage coordinates, snapped, then converted back to absolute
           // before returning them.
-          const stageScale = this.manager.getStageScale();
-          const stagePos = this.manager.getStagePosition();
+          const stageScale = this.manager.stage.getScale();
+          const stagePos = this.manager.stage.getPosition();
 
           // Unscale and round the target position to the nearest pixel.
           const targetX = Math.round(newPos.x / stageScale);
@@ -335,8 +335,8 @@ export class CanvasTransformer {
       // The bbox should be updated to reflect the new position of the interaction rect, taking into account its padding
       // and border
       this.konva.outlineRect.setAttrs({
-        x: this.konva.proxyRect.x() - this.manager.getScaledPixels(CanvasTransformer.OUTLINE_PADDING),
-        y: this.konva.proxyRect.y() - this.manager.getScaledPixels(CanvasTransformer.OUTLINE_PADDING),
+        x: this.konva.proxyRect.x() - this.manager.stage.getScaledPixels(CanvasTransformer.OUTLINE_PADDING),
+        y: this.konva.proxyRect.y() - this.manager.stage.getScaledPixels(CanvasTransformer.OUTLINE_PADDING),
       });
 
       // The object group is translated by the difference between the interaction rect's new and old positions (which is
@@ -407,8 +407,8 @@ export class CanvasTransformer {
    * @param bbox The bounding box of the parent entity
    */
   update = (position: Coordinate, bbox: Rect) => {
-    const onePixel = this.manager.getScaledPixels(1);
-    const bboxPadding = this.manager.getScaledPixels(CanvasTransformer.OUTLINE_PADDING);
+    const onePixel = this.manager.stage.getScaledPixels(1);
+    const bboxPadding = this.manager.stage.getScaledPixels(CanvasTransformer.OUTLINE_PADDING);
 
     this.konva.outlineRect.setAttrs({
       x: position.x + bbox.x - bboxPadding,
@@ -474,8 +474,8 @@ export class CanvasTransformer {
    * Updates the transformer's scale. This is called when the stage is scaled.
    */
   syncScale = () => {
-    const onePixel = this.manager.getScaledPixels(1);
-    const bboxPadding = this.manager.getScaledPixels(CanvasTransformer.OUTLINE_PADDING);
+    const onePixel = this.manager.stage.getScaledPixels(1);
+    const bboxPadding = this.manager.stage.getScaledPixels(CanvasTransformer.OUTLINE_PADDING);
 
     this.konva.outlineRect.setAttrs({
       x: this.konva.proxyRect.x() - bboxPadding,
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/events.ts b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts
index 0e2f79a51b..a514d65e98 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/events.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts
@@ -135,7 +135,7 @@ const getColorUnderCursor = (stage: Konva.Stage): RgbaColor | null => {
 };
 
 export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
-  const { stage, stateApi } = manager;
+  const stage = manager.stage.konva.stage;
   const {
     getToolState,
     setTool,
@@ -152,7 +152,7 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
     setEraserWidth,
     getCurrentFill,
     getSelectedEntity,
-  } = stateApi;
+  } = manager.stateApi;
 
   function getIsPrimaryMouseDown(e: KonvaEventObject<MouseEvent>) {
     return e.evt.buttons === 1;
@@ -496,8 +496,8 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
       if (cursorPos) {
         // When wheeling on trackpad, e.evt.ctrlKey is true - in that case, let's reverse the direction
         const delta = e.evt.ctrlKey ? -e.evt.deltaY : e.evt.deltaY;
-        const scale = manager.getStageScale() * CANVAS_SCALE_BY ** delta;
-        manager.setStageScale(scale, cursorPos);
+        const scale = manager.stage.getScale() * CANVAS_SCALE_BY ** delta;
+        manager.stage.setScale(scale, cursorPos);
       }
     }
     manager.preview.tool.render();