feat(ui): tidy up atoms

This commit is contained in:
psychedelicious 2024-08-08 18:52:56 +10:00
parent d98d35a8a8
commit b88d14b3df
7 changed files with 137 additions and 137 deletions

View File

@ -19,7 +19,7 @@ export const TransformToolButton = memo(() => {
if (!canvasManager) {
return;
}
return canvasManager.$transformingEntity.listen((newValue) => {
return canvasManager.stateApi.$transformingEntity.listen((newValue) => {
setIsTransforming(Boolean(newValue));
});
}, [canvasManager]);

View File

@ -18,24 +18,16 @@ import {
} from 'features/controlLayers/konva/util';
import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker';
import type {
CanvasControlAdapterState,
CanvasEntityIdentifier,
CanvasInpaintMaskState,
CanvasLayerState,
CanvasRegionalGuidanceState,
CanvasV2State,
Coordinate,
Dimensions,
GenerationMode,
GetLoggingContext,
Rect,
RgbaColor,
} from 'features/controlLayers/store/types';
import { RGBA_RED } from 'features/controlLayers/store/types';
import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers';
import type Konva from 'konva';
import { clamp } from 'lodash-es';
import type { WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
@ -51,32 +43,6 @@ import { CanvasStagingArea } from './CanvasStagingArea';
import { CanvasStateApi } from './CanvasStateApi';
import { setStageEventHandlers } from './events';
type EntityStateAndAdapter =
| {
id: string;
type: CanvasLayerState['type'];
state: CanvasLayerState;
adapter: CanvasLayerAdapter;
}
| {
id: string;
type: CanvasInpaintMaskState['type'];
state: CanvasInpaintMaskState;
adapter: CanvasMaskAdapter;
}
// | {
// id: string;
// type: CanvasControlAdapterState['type'];
// state: CanvasControlAdapterState;
// adapter: CanvasControlAdapter;
// }
| {
id: string;
type: CanvasRegionalGuidanceState['type'];
state: CanvasRegionalGuidanceState;
adapter: CanvasMaskAdapter;
};
export const $canvasManager = atom<CanvasManager | null>(null);
export class CanvasManager {
@ -101,12 +67,6 @@ export class CanvasManager {
_worker: Worker = new Worker(new URL('./worker.ts', import.meta.url), { type: 'module', name: 'worker' });
_tasks: Map<string, { task: GetBboxTask; onComplete: (extents: Extents | null) => void }> = new Map();
$transformingEntity: WritableAtom<CanvasEntityIdentifier | null> = atom();
$toolState: WritableAtom<CanvasV2State['tool']> = atom();
$currentFill: WritableAtom<RgbaColor> = atom();
$selectedEntity: WritableAtom<EntityStateAndAdapter | null> = atom();
$selectedEntityIdentifier: WritableAtom<CanvasEntityIdentifier | null> = atom();
constructor(stage: Konva.Stage, container: HTMLDivElement, store: Store<RootState>) {
this.stage = stage;
this.container = container;
@ -160,11 +120,11 @@ export class CanvasManager {
this.log.error('Worker message error');
};
this.$transformingEntity.set(null);
this.$toolState.set(this.stateApi.getToolState());
this.$selectedEntityIdentifier.set(this.stateApi.getState().selectedEntityIdentifier);
this.$currentFill.set(this.getCurrentFill());
this.$selectedEntity.set(this.getSelectedEntity());
this.stateApi.$transformingEntity.set(null);
this.stateApi.$toolState.set(this.stateApi.getToolState());
this.stateApi.$selectedEntityIdentifier.set(this.stateApi.getState().selectedEntityIdentifier);
this.stateApi.$currentFill.set(this.stateApi.getCurrentFill());
this.stateApi.$selectedEntity.set(this.stateApi.getSelectedEntity());
this.inpaintMask = new CanvasMaskAdapter(this.stateApi.getInpaintMaskState(), this);
this.stage.add(this.inpaintMask.konva.layer);
@ -270,80 +230,8 @@ export class CanvasManager {
});
}
getEntity(identifier: CanvasEntityIdentifier): EntityStateAndAdapter | null {
const state = this.stateApi.getState();
let entityState:
| CanvasLayerState
| CanvasControlAdapterState
| CanvasRegionalGuidanceState
| CanvasInpaintMaskState
| null = null;
let entityAdapter: CanvasLayerAdapter | CanvasControlAdapter | CanvasMaskAdapter | null = null;
if (identifier.type === 'layer') {
entityState = state.layers.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.layers.get(identifier.id) ?? null;
} else if (identifier.type === 'control_adapter') {
entityState = state.controlAdapters.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.controlAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'regional_guidance') {
entityState = state.regions.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.regions.get(identifier.id) ?? null;
} else if (identifier.type === 'inpaint_mask') {
entityState = state.inpaintMask;
entityAdapter = this.inpaintMask;
}
if (entityState && entityAdapter && entityState.type === entityAdapter.type) {
return {
id: entityState.id,
type: entityState.type,
state: entityState,
adapter: entityAdapter,
} as EntityStateAndAdapter; // TODO(psyche): make TS happy w/o this cast
}
return null;
}
getSelectedEntity = () => {
const state = this.stateApi.getState();
if (state.selectedEntityIdentifier) {
return this.getEntity(state.selectedEntityIdentifier);
}
return null;
};
getCurrentFill = () => {
const state = this.stateApi.getState();
let currentFill: RgbaColor = state.tool.fill;
const selectedEntity = this.getSelectedEntity();
if (selectedEntity) {
// These two entity types use a compositing rect for opacity. Their fill is always white.
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
currentFill = RGBA_RED;
// currentFill = RGBA_WHITE;
}
}
return currentFill;
};
getBrushPreviewFill = () => {
const state = this.stateApi.getState();
let currentFill: RgbaColor = state.tool.fill;
const selectedEntity = this.getSelectedEntity();
if (selectedEntity) {
// The brush should use the mask opacity for these entity types
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
currentFill = { ...selectedEntity.state.fill, a: this.stateApi.getSettings().maskOpacity };
}
}
return currentFill;
};
getTransformingLayer() {
const transformingEntity = this.$transformingEntity.get();
const transformingEntity = this.stateApi.$transformingEntity.get();
if (!transformingEntity) {
return null;
}
@ -362,21 +250,21 @@ export class CanvasManager {
}
getIsTransforming() {
return Boolean(this.$transformingEntity.get());
return Boolean(this.stateApi.$transformingEntity.get());
}
startTransform() {
if (this.getIsTransforming()) {
return;
}
const entity = this.getSelectedEntity();
const entity = this.stateApi.getSelectedEntity();
if (!entity) {
this.log.warn('No entity selected to transform');
return;
}
// TODO(psyche): Support other entity types
entity.adapter.transformer.startTransform();
this.$transformingEntity.set({ id: entity.id, type: entity.type });
this.stateApi.$transformingEntity.set({ id: entity.id, type: entity.type });
}
async applyTransform() {
@ -384,7 +272,7 @@ export class CanvasManager {
if (layer) {
await layer.transformer.applyTransform();
}
this.$transformingEntity.set(null);
this.stateApi.$transformingEntity.set(null);
}
cancelTransform() {
@ -392,7 +280,7 @@ export class CanvasManager {
if (layer) {
layer.transformer.stopTransform();
}
this.$transformingEntity.set(null);
this.stateApi.$transformingEntity.set(null);
}
render = async () => {
@ -485,10 +373,10 @@ export class CanvasManager {
await this.renderControlAdapters();
}
this.$toolState.set(state.tool);
this.$selectedEntityIdentifier.set(state.selectedEntityIdentifier);
this.$selectedEntity.set(this.getSelectedEntity());
this.$currentFill.set(this.getCurrentFill());
this.stateApi.$toolState.set(state.tool);
this.stateApi.$selectedEntityIdentifier.set(state.selectedEntityIdentifier);
this.stateApi.$selectedEntity.set(this.stateApi.getSelectedEntity());
this.stateApi.$currentFill.set(this.stateApi.getCurrentFill());
if (
this._isFirstRender ||
@ -709,7 +597,7 @@ export class CanvasManager {
};
getRegionMaskImageDTO = async (id: string, rect?: Rect): Promise<ImageDTO> => {
const region = this.getEntity({ id, type: 'regional_guidance' });
const region = this.stateApi.getEntity({ id, type: 'regional_guidance' });
assert(region?.type === 'regional_guidance');
if (region.state.imageCache) {
const imageDTO = await getImageDTO(region.state.imageCache);

View File

@ -116,7 +116,7 @@ export class CanvasObjectRenderer {
}
this.subscriptions.add(
this.manager.$toolState.listen((newVal, oldVal) => {
this.manager.stateApi.$toolState.listen((newVal, oldVal) => {
if (newVal.selected !== oldVal.selected) {
this.commitBuffer();
}

View File

@ -2,7 +2,9 @@ import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Store } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter';
import {
$isDrawing,
$isMouseDown,
@ -30,6 +32,11 @@ import {
toolChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import type {
CanvasEntityIdentifier,
CanvasInpaintMaskState,
CanvasLayerState,
CanvasRegionalGuidanceState,
CanvasV2State,
EntityBrushLineAddedPayload,
EntityEraserLineAddedPayload,
EntityIdentifierPayload,
@ -37,10 +44,40 @@ import type {
EntityRasterizedPayload,
EntityRectAddedPayload,
Rect,
RgbaColor,
Tool,
} from 'features/controlLayers/store/types';
import { RGBA_RED } from 'features/controlLayers/store/types';
import type { WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import type { ImageDTO } from 'services/api/types';
type EntityStateAndAdapter =
| {
id: string;
type: CanvasLayerState['type'];
state: CanvasLayerState;
adapter: CanvasLayerAdapter;
}
| {
id: string;
type: CanvasInpaintMaskState['type'];
state: CanvasInpaintMaskState;
adapter: CanvasMaskAdapter;
}
// | {
// id: string;
// type: CanvasControlAdapterState['type'];
// state: CanvasControlAdapterState;
// adapter: CanvasControlAdapter;
// }
| {
id: string;
type: CanvasRegionalGuidanceState['type'];
state: CanvasRegionalGuidanceState;
adapter: CanvasMaskAdapter;
};
const log = logger('canvas');
export class CanvasStateApi {
@ -152,6 +189,79 @@ export class CanvasStateApi {
return this._store.getState().system.consoleLogLevel;
};
getEntity(identifier: CanvasEntityIdentifier): EntityStateAndAdapter | null {
const state = this.getState();
let entityState: EntityStateAndAdapter['state'] | null = null;
let entityAdapter: EntityStateAndAdapter['adapter'] | null = null;
if (identifier.type === 'layer') {
entityState = state.layers.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.layers.get(identifier.id) ?? null;
} else if (identifier.type === 'control_adapter') {
entityState = state.controlAdapters.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.controlAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'regional_guidance') {
entityState = state.regions.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.regions.get(identifier.id) ?? null;
} else if (identifier.type === 'inpaint_mask') {
entityState = state.inpaintMask;
entityAdapter = this.manager.inpaintMask;
}
if (entityState && entityAdapter && entityState.type === entityAdapter.type) {
return {
id: entityState.id,
type: entityState.type,
state: entityState,
adapter: entityAdapter,
} as EntityStateAndAdapter; // TODO(psyche): make TS happy w/o this cast
}
return null;
}
getSelectedEntity = () => {
const state = this.getState();
if (state.selectedEntityIdentifier) {
return this.getEntity(state.selectedEntityIdentifier);
}
return null;
};
getCurrentFill = () => {
const state = this.getState();
let currentFill: RgbaColor = state.tool.fill;
const selectedEntity = this.getSelectedEntity();
if (selectedEntity) {
// These two entity types use a compositing rect for opacity. Their fill is always white.
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
currentFill = RGBA_RED;
// currentFill = RGBA_WHITE;
}
}
return currentFill;
};
getBrushPreviewFill = () => {
const state = this.getState();
let currentFill: RgbaColor = state.tool.fill;
const selectedEntity = this.getSelectedEntity();
if (selectedEntity) {
// The brush should use the mask opacity for these entity types
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
currentFill = { ...selectedEntity.state.fill, a: this.getSettings().maskOpacity };
}
}
return currentFill;
};
$transformingEntity: WritableAtom<CanvasEntityIdentifier | null> = atom();
$toolState: WritableAtom<CanvasV2State['tool']> = atom();
$currentFill: WritableAtom<RgbaColor> = atom();
$selectedEntity: WritableAtom<EntityStateAndAdapter | null> = atom();
$selectedEntityIdentifier: WritableAtom<CanvasEntityIdentifier | null> = atom();
// Read-write state, ephemeral interaction state
$isDrawing = $isDrawing;
$isMouseDown = $isMouseDown;

View File

@ -118,7 +118,7 @@ export class CanvasTool {
);
this.subscriptions.add(
this.manager.$toolState.listen(() => {
this.manager.stateApi.$toolState.listen(() => {
this.render();
})
);
@ -154,7 +154,7 @@ export class CanvasTool {
const stage = this.manager.stage;
const renderedEntityCount: number = 1; // TODO(psyche): this.manager should be renderable entity count
const toolState = this.manager.stateApi.getToolState();
const selectedEntity = this.manager.getSelectedEntity();
const selectedEntity = this.manager.stateApi.getSelectedEntity();
const cursorPos = this.manager.stateApi.$lastCursorPos.get();
const isDrawing = this.manager.stateApi.$isDrawing.get();
const isMouseDown = this.manager.stateApi.$isMouseDown.get();
@ -175,7 +175,7 @@ export class CanvasTool {
} else if (!isDrawableEntity) {
// Non-drawable layers don't have tools
stage.container().style.cursor = 'not-allowed';
} else if (tool === 'move' || Boolean(this.manager.$transformingEntity.get())) {
} else if (tool === 'move' || Boolean(this.manager.stateApi.$transformingEntity.get())) {
// Move tool gets a pointer
stage.container().style.cursor = 'default';
} else if (tool === 'rect') {
@ -198,7 +198,7 @@ export class CanvasTool {
// No need to render the brush preview if the cursor position or color is missing
if (cursorPos && tool === 'brush') {
const brushPreviewFill = this.manager.getBrushPreviewFill();
const brushPreviewFill = this.manager.stateApi.getBrushPreviewFill();
const alignedCursorPos = alignCoordForTool(cursorPos, toolState.brush.width);
const scale = stage.scaleX();
// Update the fill circle

View File

@ -384,7 +384,7 @@ export class CanvasTransformer {
// When the selected tool changes, we need to update the transformer's interaction state.
this.subscriptions.add(
this.manager.$toolState.listen((newVal, oldVal) => {
this.manager.stateApi.$toolState.listen((newVal, oldVal) => {
if (newVal.selected !== oldVal.selected) {
this.syncInteractionState();
}
@ -393,7 +393,7 @@ export class CanvasTransformer {
// When the selected entity changes, we need to update the transformer's interaction state.
this.subscriptions.add(
this.manager.$selectedEntityIdentifier.listen(() => {
this.manager.stateApi.$selectedEntityIdentifier.listen(() => {
this.syncInteractionState();
})
);

View File

@ -115,7 +115,7 @@ const getLastPointOfLastLineOfEntity = (
};
export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
const { stage, stateApi, getCurrentFill, getSelectedEntity } = manager;
const { stage, stateApi } = manager;
const {
getToolState,
setTool,
@ -130,6 +130,8 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
getSettings,
setBrushWidth,
setEraserWidth,
getCurrentFill,
getSelectedEntity,
} = stateApi;
function getIsPrimaryMouseDown(e: KonvaEventObject<MouseEvent>) {