feat(ui): tidy up atoms

This commit is contained in:
psychedelicious 2024-08-08 18:52:56 +10:00
parent c90d3f3bb9
commit 11059ee2d4
7 changed files with 137 additions and 137 deletions

View File

@ -19,7 +19,7 @@ export const TransformToolButton = memo(() => {
if (!canvasManager) { if (!canvasManager) {
return; return;
} }
return canvasManager.$transformingEntity.listen((newValue) => { return canvasManager.stateApi.$transformingEntity.listen((newValue) => {
setIsTransforming(Boolean(newValue)); setIsTransforming(Boolean(newValue));
}); });
}, [canvasManager]); }, [canvasManager]);

View File

@ -18,24 +18,16 @@ import {
} from 'features/controlLayers/konva/util'; } from 'features/controlLayers/konva/util';
import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker'; import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker';
import type { import type {
CanvasControlAdapterState,
CanvasEntityIdentifier,
CanvasInpaintMaskState,
CanvasLayerState,
CanvasRegionalGuidanceState,
CanvasV2State, CanvasV2State,
Coordinate, Coordinate,
Dimensions, Dimensions,
GenerationMode, GenerationMode,
GetLoggingContext, GetLoggingContext,
Rect, Rect,
RgbaColor,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { RGBA_RED } from 'features/controlLayers/store/types';
import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers'; import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers';
import type Konva from 'konva'; import type Konva from 'konva';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import type { WritableAtom } from 'nanostores';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import type { Logger } from 'roarr'; import type { Logger } from 'roarr';
import { getImageDTO, uploadImage } from 'services/api/endpoints/images'; import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
@ -51,32 +43,6 @@ import { CanvasStagingArea } from './CanvasStagingArea';
import { CanvasStateApi } from './CanvasStateApi'; import { CanvasStateApi } from './CanvasStateApi';
import { setStageEventHandlers } from './events'; import { setStageEventHandlers } from './events';
type EntityStateAndAdapter =
| {
id: string;
type: CanvasLayerState['type'];
state: CanvasLayerState;
adapter: CanvasLayerAdapter;
}
| {
id: string;
type: CanvasInpaintMaskState['type'];
state: CanvasInpaintMaskState;
adapter: CanvasMaskAdapter;
}
// | {
// id: string;
// type: CanvasControlAdapterState['type'];
// state: CanvasControlAdapterState;
// adapter: CanvasControlAdapter;
// }
| {
id: string;
type: CanvasRegionalGuidanceState['type'];
state: CanvasRegionalGuidanceState;
adapter: CanvasMaskAdapter;
};
export const $canvasManager = atom<CanvasManager | null>(null); export const $canvasManager = atom<CanvasManager | null>(null);
export class CanvasManager { export class CanvasManager {
@ -101,12 +67,6 @@ export class CanvasManager {
_worker: Worker = new Worker(new URL('./worker.ts', import.meta.url), { type: 'module', name: 'worker' }); _worker: Worker = new Worker(new URL('./worker.ts', import.meta.url), { type: 'module', name: 'worker' });
_tasks: Map<string, { task: GetBboxTask; onComplete: (extents: Extents | null) => void }> = new Map(); _tasks: Map<string, { task: GetBboxTask; onComplete: (extents: Extents | null) => void }> = new Map();
$transformingEntity: WritableAtom<CanvasEntityIdentifier | null> = atom();
$toolState: WritableAtom<CanvasV2State['tool']> = atom();
$currentFill: WritableAtom<RgbaColor> = atom();
$selectedEntity: WritableAtom<EntityStateAndAdapter | null> = atom();
$selectedEntityIdentifier: WritableAtom<CanvasEntityIdentifier | null> = atom();
constructor(stage: Konva.Stage, container: HTMLDivElement, store: Store<RootState>) { constructor(stage: Konva.Stage, container: HTMLDivElement, store: Store<RootState>) {
this.stage = stage; this.stage = stage;
this.container = container; this.container = container;
@ -160,11 +120,11 @@ export class CanvasManager {
this.log.error('Worker message error'); this.log.error('Worker message error');
}; };
this.$transformingEntity.set(null); this.stateApi.$transformingEntity.set(null);
this.$toolState.set(this.stateApi.getToolState()); this.stateApi.$toolState.set(this.stateApi.getToolState());
this.$selectedEntityIdentifier.set(this.stateApi.getState().selectedEntityIdentifier); this.stateApi.$selectedEntityIdentifier.set(this.stateApi.getState().selectedEntityIdentifier);
this.$currentFill.set(this.getCurrentFill()); this.stateApi.$currentFill.set(this.stateApi.getCurrentFill());
this.$selectedEntity.set(this.getSelectedEntity()); this.stateApi.$selectedEntity.set(this.stateApi.getSelectedEntity());
this.inpaintMask = new CanvasMaskAdapter(this.stateApi.getInpaintMaskState(), this); this.inpaintMask = new CanvasMaskAdapter(this.stateApi.getInpaintMaskState(), this);
this.stage.add(this.inpaintMask.konva.layer); this.stage.add(this.inpaintMask.konva.layer);
@ -270,80 +230,8 @@ export class CanvasManager {
}); });
} }
getEntity(identifier: CanvasEntityIdentifier): EntityStateAndAdapter | null {
const state = this.stateApi.getState();
let entityState:
| CanvasLayerState
| CanvasControlAdapterState
| CanvasRegionalGuidanceState
| CanvasInpaintMaskState
| null = null;
let entityAdapter: CanvasLayerAdapter | CanvasControlAdapter | CanvasMaskAdapter | null = null;
if (identifier.type === 'layer') {
entityState = state.layers.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.layers.get(identifier.id) ?? null;
} else if (identifier.type === 'control_adapter') {
entityState = state.controlAdapters.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.controlAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'regional_guidance') {
entityState = state.regions.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.regions.get(identifier.id) ?? null;
} else if (identifier.type === 'inpaint_mask') {
entityState = state.inpaintMask;
entityAdapter = this.inpaintMask;
}
if (entityState && entityAdapter && entityState.type === entityAdapter.type) {
return {
id: entityState.id,
type: entityState.type,
state: entityState,
adapter: entityAdapter,
} as EntityStateAndAdapter; // TODO(psyche): make TS happy w/o this cast
}
return null;
}
getSelectedEntity = () => {
const state = this.stateApi.getState();
if (state.selectedEntityIdentifier) {
return this.getEntity(state.selectedEntityIdentifier);
}
return null;
};
getCurrentFill = () => {
const state = this.stateApi.getState();
let currentFill: RgbaColor = state.tool.fill;
const selectedEntity = this.getSelectedEntity();
if (selectedEntity) {
// These two entity types use a compositing rect for opacity. Their fill is always white.
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
currentFill = RGBA_RED;
// currentFill = RGBA_WHITE;
}
}
return currentFill;
};
getBrushPreviewFill = () => {
const state = this.stateApi.getState();
let currentFill: RgbaColor = state.tool.fill;
const selectedEntity = this.getSelectedEntity();
if (selectedEntity) {
// The brush should use the mask opacity for these entity types
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
currentFill = { ...selectedEntity.state.fill, a: this.stateApi.getSettings().maskOpacity };
}
}
return currentFill;
};
getTransformingLayer() { getTransformingLayer() {
const transformingEntity = this.$transformingEntity.get(); const transformingEntity = this.stateApi.$transformingEntity.get();
if (!transformingEntity) { if (!transformingEntity) {
return null; return null;
} }
@ -362,21 +250,21 @@ export class CanvasManager {
} }
getIsTransforming() { getIsTransforming() {
return Boolean(this.$transformingEntity.get()); return Boolean(this.stateApi.$transformingEntity.get());
} }
startTransform() { startTransform() {
if (this.getIsTransforming()) { if (this.getIsTransforming()) {
return; return;
} }
const entity = this.getSelectedEntity(); const entity = this.stateApi.getSelectedEntity();
if (!entity) { if (!entity) {
this.log.warn('No entity selected to transform'); this.log.warn('No entity selected to transform');
return; return;
} }
// TODO(psyche): Support other entity types // TODO(psyche): Support other entity types
entity.adapter.transformer.startTransform(); entity.adapter.transformer.startTransform();
this.$transformingEntity.set({ id: entity.id, type: entity.type }); this.stateApi.$transformingEntity.set({ id: entity.id, type: entity.type });
} }
async applyTransform() { async applyTransform() {
@ -384,7 +272,7 @@ export class CanvasManager {
if (layer) { if (layer) {
await layer.transformer.applyTransform(); await layer.transformer.applyTransform();
} }
this.$transformingEntity.set(null); this.stateApi.$transformingEntity.set(null);
} }
cancelTransform() { cancelTransform() {
@ -392,7 +280,7 @@ export class CanvasManager {
if (layer) { if (layer) {
layer.transformer.stopTransform(); layer.transformer.stopTransform();
} }
this.$transformingEntity.set(null); this.stateApi.$transformingEntity.set(null);
} }
render = async () => { render = async () => {
@ -485,10 +373,10 @@ export class CanvasManager {
await this.renderControlAdapters(); await this.renderControlAdapters();
} }
this.$toolState.set(state.tool); this.stateApi.$toolState.set(state.tool);
this.$selectedEntityIdentifier.set(state.selectedEntityIdentifier); this.stateApi.$selectedEntityIdentifier.set(state.selectedEntityIdentifier);
this.$selectedEntity.set(this.getSelectedEntity()); this.stateApi.$selectedEntity.set(this.stateApi.getSelectedEntity());
this.$currentFill.set(this.getCurrentFill()); this.stateApi.$currentFill.set(this.stateApi.getCurrentFill());
if ( if (
this._isFirstRender || this._isFirstRender ||
@ -709,7 +597,7 @@ export class CanvasManager {
}; };
getRegionMaskImageDTO = async (id: string, rect?: Rect): Promise<ImageDTO> => { getRegionMaskImageDTO = async (id: string, rect?: Rect): Promise<ImageDTO> => {
const region = this.getEntity({ id, type: 'regional_guidance' }); const region = this.stateApi.getEntity({ id, type: 'regional_guidance' });
assert(region?.type === 'regional_guidance'); assert(region?.type === 'regional_guidance');
if (region.state.imageCache) { if (region.state.imageCache) {
const imageDTO = await getImageDTO(region.state.imageCache); const imageDTO = await getImageDTO(region.state.imageCache);

View File

@ -116,7 +116,7 @@ export class CanvasObjectRenderer {
} }
this.subscriptions.add( this.subscriptions.add(
this.manager.$toolState.listen((newVal, oldVal) => { this.manager.stateApi.$toolState.listen((newVal, oldVal) => {
if (newVal.selected !== oldVal.selected) { if (newVal.selected !== oldVal.selected) {
this.commitBuffer(); this.commitBuffer();
} }

View File

@ -2,7 +2,9 @@ import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Store } from '@reduxjs/toolkit'; import type { Store } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter';
import { import {
$isDrawing, $isDrawing,
$isMouseDown, $isMouseDown,
@ -30,6 +32,11 @@ import {
toolChanged, toolChanged,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import type { import type {
CanvasEntityIdentifier,
CanvasInpaintMaskState,
CanvasLayerState,
CanvasRegionalGuidanceState,
CanvasV2State,
EntityBrushLineAddedPayload, EntityBrushLineAddedPayload,
EntityEraserLineAddedPayload, EntityEraserLineAddedPayload,
EntityIdentifierPayload, EntityIdentifierPayload,
@ -37,10 +44,40 @@ import type {
EntityRasterizedPayload, EntityRasterizedPayload,
EntityRectAddedPayload, EntityRectAddedPayload,
Rect, Rect,
RgbaColor,
Tool, Tool,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { RGBA_RED } from 'features/controlLayers/store/types';
import type { WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import type { ImageDTO } from 'services/api/types'; import type { ImageDTO } from 'services/api/types';
type EntityStateAndAdapter =
| {
id: string;
type: CanvasLayerState['type'];
state: CanvasLayerState;
adapter: CanvasLayerAdapter;
}
| {
id: string;
type: CanvasInpaintMaskState['type'];
state: CanvasInpaintMaskState;
adapter: CanvasMaskAdapter;
}
// | {
// id: string;
// type: CanvasControlAdapterState['type'];
// state: CanvasControlAdapterState;
// adapter: CanvasControlAdapter;
// }
| {
id: string;
type: CanvasRegionalGuidanceState['type'];
state: CanvasRegionalGuidanceState;
adapter: CanvasMaskAdapter;
};
const log = logger('canvas'); const log = logger('canvas');
export class CanvasStateApi { export class CanvasStateApi {
@ -152,6 +189,79 @@ export class CanvasStateApi {
return this._store.getState().system.consoleLogLevel; return this._store.getState().system.consoleLogLevel;
}; };
getEntity(identifier: CanvasEntityIdentifier): EntityStateAndAdapter | null {
const state = this.getState();
let entityState: EntityStateAndAdapter['state'] | null = null;
let entityAdapter: EntityStateAndAdapter['adapter'] | null = null;
if (identifier.type === 'layer') {
entityState = state.layers.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.layers.get(identifier.id) ?? null;
} else if (identifier.type === 'control_adapter') {
entityState = state.controlAdapters.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.controlAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'regional_guidance') {
entityState = state.regions.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.regions.get(identifier.id) ?? null;
} else if (identifier.type === 'inpaint_mask') {
entityState = state.inpaintMask;
entityAdapter = this.manager.inpaintMask;
}
if (entityState && entityAdapter && entityState.type === entityAdapter.type) {
return {
id: entityState.id,
type: entityState.type,
state: entityState,
adapter: entityAdapter,
} as EntityStateAndAdapter; // TODO(psyche): make TS happy w/o this cast
}
return null;
}
getSelectedEntity = () => {
const state = this.getState();
if (state.selectedEntityIdentifier) {
return this.getEntity(state.selectedEntityIdentifier);
}
return null;
};
getCurrentFill = () => {
const state = this.getState();
let currentFill: RgbaColor = state.tool.fill;
const selectedEntity = this.getSelectedEntity();
if (selectedEntity) {
// These two entity types use a compositing rect for opacity. Their fill is always white.
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
currentFill = RGBA_RED;
// currentFill = RGBA_WHITE;
}
}
return currentFill;
};
getBrushPreviewFill = () => {
const state = this.getState();
let currentFill: RgbaColor = state.tool.fill;
const selectedEntity = this.getSelectedEntity();
if (selectedEntity) {
// The brush should use the mask opacity for these entity types
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
currentFill = { ...selectedEntity.state.fill, a: this.getSettings().maskOpacity };
}
}
return currentFill;
};
$transformingEntity: WritableAtom<CanvasEntityIdentifier | null> = atom();
$toolState: WritableAtom<CanvasV2State['tool']> = atom();
$currentFill: WritableAtom<RgbaColor> = atom();
$selectedEntity: WritableAtom<EntityStateAndAdapter | null> = atom();
$selectedEntityIdentifier: WritableAtom<CanvasEntityIdentifier | null> = atom();
// Read-write state, ephemeral interaction state // Read-write state, ephemeral interaction state
$isDrawing = $isDrawing; $isDrawing = $isDrawing;
$isMouseDown = $isMouseDown; $isMouseDown = $isMouseDown;

View File

@ -118,7 +118,7 @@ export class CanvasTool {
); );
this.subscriptions.add( this.subscriptions.add(
this.manager.$toolState.listen(() => { this.manager.stateApi.$toolState.listen(() => {
this.render(); this.render();
}) })
); );
@ -154,7 +154,7 @@ export class CanvasTool {
const stage = this.manager.stage; const stage = this.manager.stage;
const renderedEntityCount: number = 1; // TODO(psyche): this.manager should be renderable entity count const renderedEntityCount: number = 1; // TODO(psyche): this.manager should be renderable entity count
const toolState = this.manager.stateApi.getToolState(); const toolState = this.manager.stateApi.getToolState();
const selectedEntity = this.manager.getSelectedEntity(); const selectedEntity = this.manager.stateApi.getSelectedEntity();
const cursorPos = this.manager.stateApi.$lastCursorPos.get(); const cursorPos = this.manager.stateApi.$lastCursorPos.get();
const isDrawing = this.manager.stateApi.$isDrawing.get(); const isDrawing = this.manager.stateApi.$isDrawing.get();
const isMouseDown = this.manager.stateApi.$isMouseDown.get(); const isMouseDown = this.manager.stateApi.$isMouseDown.get();
@ -175,7 +175,7 @@ export class CanvasTool {
} else if (!isDrawableEntity) { } else if (!isDrawableEntity) {
// Non-drawable layers don't have tools // Non-drawable layers don't have tools
stage.container().style.cursor = 'not-allowed'; stage.container().style.cursor = 'not-allowed';
} else if (tool === 'move' || Boolean(this.manager.$transformingEntity.get())) { } else if (tool === 'move' || Boolean(this.manager.stateApi.$transformingEntity.get())) {
// Move tool gets a pointer // Move tool gets a pointer
stage.container().style.cursor = 'default'; stage.container().style.cursor = 'default';
} else if (tool === 'rect') { } else if (tool === 'rect') {
@ -198,7 +198,7 @@ export class CanvasTool {
// No need to render the brush preview if the cursor position or color is missing // No need to render the brush preview if the cursor position or color is missing
if (cursorPos && tool === 'brush') { if (cursorPos && tool === 'brush') {
const brushPreviewFill = this.manager.getBrushPreviewFill(); const brushPreviewFill = this.manager.stateApi.getBrushPreviewFill();
const alignedCursorPos = alignCoordForTool(cursorPos, toolState.brush.width); const alignedCursorPos = alignCoordForTool(cursorPos, toolState.brush.width);
const scale = stage.scaleX(); const scale = stage.scaleX();
// Update the fill circle // Update the fill circle

View File

@ -384,7 +384,7 @@ export class CanvasTransformer {
// When the selected tool changes, we need to update the transformer's interaction state. // When the selected tool changes, we need to update the transformer's interaction state.
this.subscriptions.add( this.subscriptions.add(
this.manager.$toolState.listen((newVal, oldVal) => { this.manager.stateApi.$toolState.listen((newVal, oldVal) => {
if (newVal.selected !== oldVal.selected) { if (newVal.selected !== oldVal.selected) {
this.syncInteractionState(); this.syncInteractionState();
} }
@ -393,7 +393,7 @@ export class CanvasTransformer {
// When the selected entity changes, we need to update the transformer's interaction state. // When the selected entity changes, we need to update the transformer's interaction state.
this.subscriptions.add( this.subscriptions.add(
this.manager.$selectedEntityIdentifier.listen(() => { this.manager.stateApi.$selectedEntityIdentifier.listen(() => {
this.syncInteractionState(); this.syncInteractionState();
}) })
); );

View File

@ -115,7 +115,7 @@ const getLastPointOfLastLineOfEntity = (
}; };
export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
const { stage, stateApi, getCurrentFill, getSelectedEntity } = manager; const { stage, stateApi } = manager;
const { const {
getToolState, getToolState,
setTool, setTool,
@ -130,6 +130,8 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
getSettings, getSettings,
setBrushWidth, setBrushWidth,
setEraserWidth, setEraserWidth,
getCurrentFill,
getSelectedEntity,
} = stateApi; } = stateApi;
function getIsPrimaryMouseDown(e: KonvaEventObject<MouseEvent>) { function getIsPrimaryMouseDown(e: KonvaEventObject<MouseEvent>) {