feat(ui): wip regional prompting UI

This commit is contained in:
psychedelicious 2024-04-08 17:06:08 +10:00 committed by Kent Keirsey
parent f87eee810b
commit 83d359b681
18 changed files with 712 additions and 92 deletions

View File

@ -0,0 +1,19 @@
import type { FlexProps } from '@invoke-ai/ui-library';
import { Flex, forwardRef } from '@invoke-ai/ui-library';
import { useMemo } from 'react';
import type { RgbaColor, RgbColor } from 'react-colorful';
type Props = FlexProps & {
previewColor: RgbColor | RgbaColor;
};
export const ColorPreview = forwardRef((props: Props, ref) => {
const { previewColor, ...rest } = props;
const colorString = useMemo(() => {
if ('a' in previewColor) {
return `rgba(${previewColor.r}, ${previewColor.g}, ${previewColor.b}, ${previewColor.a ?? 1})`;
}
return `rgba(${previewColor.r}, ${previewColor.g}, ${previewColor.b}, 1)`;
}, [previewColor]);
return <Flex ref={ref} w="full" h="full" borderRadius="base" backgroundColor={colorString} {...rest} />;
});

View File

@ -0,0 +1,84 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { CSSProperties } from 'react';
import { memo, useCallback } from 'react';
import { RgbColorPicker as ColorfulRgbColorPicker } from 'react-colorful';
import type { ColorPickerBaseProps, RgbColor } from 'react-colorful/dist/types';
import { useTranslation } from 'react-i18next';
type RgbColorPickerProps = ColorPickerBaseProps<RgbColor> & {
withNumberInput?: boolean;
};
const colorPickerPointerStyles: NonNullable<ChakraProps['sx']> = {
width: 6,
height: 6,
borderColor: 'base.100',
};
const sx: ChakraProps['sx'] = {
'.react-colorful__hue-pointer': colorPickerPointerStyles,
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
gap: 5,
flexDir: 'column',
};
const colorPickerStyles: CSSProperties = { width: '100%' };
const numberInputWidth: ChakraProps['w'] = '4.2rem';
const RgbColorPicker = (props: RgbColorPickerProps) => {
const { color, onChange, withNumberInput, ...rest } = props;
const { t } = useTranslation();
const handleChangeR = useCallback((r: number) => onChange({ ...color, r }), [color, onChange]);
const handleChangeG = useCallback((g: number) => onChange({ ...color, g }), [color, onChange]);
const handleChangeB = useCallback((b: number) => onChange({ ...color, b }), [color, onChange]);
return (
<Flex sx={sx}>
<ColorfulRgbColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
{withNumberInput && (
<Flex gap={5}>
<FormControl gap={0}>
<FormLabel>{t('common.red')}</FormLabel>
<CompositeNumberInput
value={color.r}
onChange={handleChangeR}
min={0}
max={255}
step={1}
w={numberInputWidth}
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.green')}</FormLabel>
<CompositeNumberInput
value={color.g}
onChange={handleChangeG}
min={0}
max={255}
step={1}
w={numberInputWidth}
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.blue')}</FormLabel>
<CompositeNumberInput
value={color.b}
onChange={handleChangeB}
min={0}
max={255}
step={1}
w={numberInputWidth}
defaultValue={255}
/>
</FormControl>
</Flex>
)}
</Flex>
);
};
export default memo(RgbColorPicker);

View File

@ -1,6 +1,11 @@
import type { RgbaColor } from 'react-colorful';
import type { RgbaColor, RgbColor } from 'react-colorful';
export const rgbaColorToString = (color: RgbaColor): string => {
const { r, g, b, a } = color;
return `rgba(${r}, ${g}, ${b}, ${a})`;
};
export const rgbColorToString = (color: RgbColor): string => {
const { r, g, b } = color;
return `rgba(${r}, ${g}, ${b})`;
};

View File

@ -0,0 +1,13 @@
import { Button } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { layerAdded } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useCallback } from 'react';
export const AddLayerButton = () => {
const dispatch = useAppDispatch();
const onClick = useCallback(() => {
dispatch(layerAdded('promptRegionLayer'));
}, [dispatch]);
return <Button onClick={onClick}>Add Layer</Button>;
};

View File

@ -0,0 +1,23 @@
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { rgbColorToString } from 'features/canvas/util/colorToString';
import { $cursorPosition } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { Circle } from 'react-konva';
export const BrushPreview = () => {
const brushSize = useAppSelector((s) => s.regionalPrompts.brushSize);
const color = useAppSelector((s) => {
const _color = s.regionalPrompts.layers.find((l) => l.id === s.regionalPrompts.selectedLayer)?.color;
if (!_color) {
return null;
}
return rgbColorToString(_color);
});
const pos = useStore($cursorPosition);
if (!brushSize || !color || !pos) {
return null;
}
return <Circle x={pos.x} y={pos.y} radius={brushSize / 2} fill={color} />;
};

View File

@ -0,0 +1,22 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { brushSizeChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useCallback } from 'react';
export const BrushSize = () => {
const dispatch = useAppDispatch();
const brushSize = useAppSelector((s) => s.regionalPrompts.brushSize);
const onChange = useCallback(
(v: number) => {
dispatch(brushSizeChanged(v));
},
[dispatch]
);
return (
<FormControl orientation="vertical">
<FormLabel>Brush Size</FormLabel>
<CompositeSlider min={1} max={100} value={brushSize} onChange={onChange} />
<CompositeNumberInput min={1} max={500} value={brushSize} onChange={onChange} />
</FormControl>
);
};

View File

@ -0,0 +1,17 @@
import { Button } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { layerDeleted } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useCallback } from 'react';
type Props = {
id: string;
};
export const DeleteLayerButton = ({ id }: Props) => {
const dispatch = useAppDispatch();
const onClick = useCallback(() => {
dispatch(layerDeleted(id));
}, [dispatch, id]);
return <Button onClick={onClick} flexShrink={0}>Delete</Button>;
};

View File

@ -0,0 +1,35 @@
import { Popover, PopoverBody, PopoverContent, PopoverTrigger } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { ColorPreview } from 'common/components/ColorPreview';
import RgbColorPicker from 'common/components/RgbColorPicker';
import { useLayer } from 'features/regionalPrompts/hooks/useLayer';
import { promptRegionLayerColorChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useCallback } from 'react';
import type { RgbColor } from 'react-colorful';
type Props = {
id: string;
};
export const LayerColorPicker = ({ id }: Props) => {
const layer = useLayer(id);
const dispatch = useAppDispatch();
const onColorChange = useCallback(
(color: RgbColor) => {
dispatch(promptRegionLayerColorChanged({ layerId: id, color }));
},
[dispatch, id]
);
return (
<Popover isLazy>
<PopoverTrigger>
<ColorPreview previewColor={layer.color} />
</PopoverTrigger>
<PopoverContent>
<PopoverBody w={64} h={64}>
<RgbColorPicker color={layer.color} onChange={onColorChange} withNumberInput />
</PopoverBody>
</PopoverContent>
</Popover>
);
};

View File

@ -0,0 +1,32 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { DeleteLayerButton } from 'features/regionalPrompts/components/DeleteLayerButton';
import { LayerColorPicker } from 'features/regionalPrompts/components/LayerColorPicker';
import { RegionalPromptsPrompt } from 'features/regionalPrompts/components/RegionalPromptsPrompt';
import { ResetLayerButton } from 'features/regionalPrompts/components/ResetLayerButton';
import { layerSelected } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useCallback, useMemo } from 'react';
type Props = {
id: string;
};
export const LayerListItem = ({ id }: Props) => {
const dispatch = useAppDispatch();
const selectedLayer = useAppSelector((s) => s.regionalPrompts.selectedLayer);
const border = useMemo(() => (selectedLayer === id ? '1px solid red' : 'none'), [selectedLayer, id]);
const onClickCapture = useCallback(() => {
// Must be capture so that the layer is selected before deleting/resetting/etc
dispatch(layerSelected(id));
}, [dispatch, id]);
return (
<Flex flexDir="column" onClickCapture={onClickCapture} border={border}>
<Flex gap={2}>
<ResetLayerButton id={id} />
<DeleteLayerButton id={id} />
<LayerColorPicker id={id} />
</Flex>
<RegionalPromptsPrompt layerId={id} />
</Flex>
);
};

View File

@ -1,13 +1,15 @@
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { rgbColorToString } from 'features/canvas/util/colorToString';
import { useTransform } from 'features/regionalPrompts/hooks/useTransform';
import type { LineObject } from 'features/regionalPrompts/store/regionalPromptsSlice';
import type { RgbColor } from 'react-colorful';
import { Line } from 'react-konva';
type Props = {
line: LineObject;
color: RgbColor;
};
export const LineComponent = ({ line }: Props) => {
export const LineComponent = ({ line, color }: Props) => {
const { shapeRef } = useTransform(line);
return (
@ -15,9 +17,13 @@ export const LineComponent = ({ line }: Props) => {
ref={shapeRef}
key={line.id}
points={line.points}
stroke={rgbaColorToString(line.color)}
strokeWidth={line.strokeWidth}
draggable
stroke={rgbColorToString(color)}
tension={0}
lineCap="round"
lineJoin="round"
shadowForStrokeEnabled={false}
listening={false}
/>
);
};

View File

@ -1,13 +1,15 @@
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { rgbColorToString } from 'features/canvas/util/colorToString';
import { useTransform } from 'features/regionalPrompts/hooks/useTransform';
import type { FillRectObject } from 'features/regionalPrompts/store/regionalPromptsSlice';
import type { RgbColor } from 'react-colorful';
import { Rect } from 'react-konva';
type Props = {
rect: FillRectObject;
color: RgbColor;
};
export const RectComponent = ({ rect }: Props) => {
export const RectComponent = ({ rect, color }: Props) => {
const { shapeRef } = useTransform(rect);
return (
@ -18,8 +20,8 @@ export const RectComponent = ({ rect }: Props) => {
y={rect.y}
width={rect.width}
height={rect.height}
fill={rgbaColorToString(rect.color)}
draggable
fill={rgbColorToString(color)}
listening={false}
/>
);
};

View File

@ -1,23 +1,28 @@
import { Flex } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { AddLayerButton } from 'features/regionalPrompts/components/AddLayerButton';
import { BrushSize } from 'features/regionalPrompts/components/BrushSize';
import { LayerListItem } from 'features/regionalPrompts/components/LayerListItem';
import { RegionalPromptsStage } from 'features/regionalPrompts/components/RegionalPromptsStage';
import { layersSelectors, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
const selectLayers = createSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
layersSelectors.selectAll(regionalPrompts)
const selectLayerIds = createSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
regionalPrompts.layers.map((l) => l.id)
);
export const RegionalPromptsEditor = () => {
const layers = useAppSelector(selectLayers);
const layerIds = useAppSelector(selectLayerIds);
return (
<Flex>
<Flex flexBasis={1}>
{layers.map((layer) => (
<Flex key={layer.id}>{layer.prompt}</Flex>
<Flex gap={4}>
<Flex flexDir="column" w={200}>
<AddLayerButton />
<BrushSize />
{layerIds.map((id) => (
<LayerListItem key={id} id={id} />
))}
</Flex>
<Flex flexBasis={1}>
<Flex>
<RegionalPromptsStage />
</Flex>
</Flex>

View File

@ -0,0 +1,73 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { useLayer } from 'features/regionalPrompts/hooks/useLayer';
import { promptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
import { memo, useCallback, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
type Props = {
layerId: string;
};
export const RegionalPromptsPrompt = memo((props: Props) => {
const layer = useLayer(props.layerId);
const dispatch = useAppDispatch();
const baseModel = useAppSelector((s) => s.generation.model)?.base;
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();
const handleChange = useCallback(
(v: string) => {
dispatch(promptChanged({ layerId: props.layerId, prompt: v }));
},
[dispatch, props.layerId]
);
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
prompt: layer.prompt,
textareaRef: textareaRef,
onChange: handleChange,
});
const focus: HotkeyCallback = useCallback(
(e) => {
onFocus();
e.preventDefault();
},
[onFocus]
);
useHotkeys('alt+a', focus, []);
return (
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
id="prompt"
name="prompt"
ref={textareaRef}
value={layer.prompt}
placeholder={t('parameters.positivePromptPlaceholder')}
onChange={onChange}
minH={28}
minW={64}
onKeyDown={onKeyDown}
variant="darkFilled"
paddingRight={30}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
{baseModel === 'sdxl' && <SDXLConcatButton />}
<ShowDynamicPromptsPreviewButton />
</PromptOverlayButtonWrapper>
</Box>
</PromptPopover>
);
});
RegionalPromptsPrompt.displayName = 'RegionalPromptsPrompt';

View File

@ -1,38 +1,73 @@
import { chakra } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { BrushPreview } from 'features/regionalPrompts/components/BrushPreview';
import { LineComponent } from 'features/regionalPrompts/components/LineComponent';
import { RectComponent } from 'features/regionalPrompts/components/RectComponent';
import {
layerObjectsSelectors,
layersSelectors,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo } from 'react';
useMouseDown,
useMouseEnter,
useMouseLeave,
useMouseMove,
useMouseUp,
} from 'features/regionalPrompts/hooks/useMouseDown';
import { $stage, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import type Konva from 'konva';
import { memo, useCallback, useRef } from 'react';
import { Group, Layer, Stage } from 'react-konva';
const selectLayers = createSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
layersSelectors.selectAll(regionalPrompts)
);
const selectLayers = createSelector(selectRegionalPromptsSlice, (regionalPrompts) => regionalPrompts.layers);
const ChakraStage = chakra(Stage, {
shouldForwardProp: (prop) => !['sx'].includes(prop),
});
const stageSx = {
border: '1px solid green',
};
export const RegionalPromptsStage: React.FC = memo(() => {
const layers = useAppSelector(selectLayers);
const stageRef = useRef<Konva.Stage | null>(null);
const onMouseDown = useMouseDown(stageRef);
const onMouseUp = useMouseUp(stageRef);
const onMouseMove = useMouseMove(stageRef);
const onMouseEnter = useMouseEnter(stageRef);
const onMouseLeave = useMouseLeave(stageRef);
const stageRefCallback = useCallback((el: Konva.Stage) => {
$stage.set(el);
stageRef.current = el;
}, []);
return (
<Stage width={window.innerWidth} height={window.innerHeight}>
<ChakraStage
ref={stageRefCallback}
width={512}
height={512}
onMouseDown={onMouseDown}
onMouseUp={onMouseUp}
onMouseMove={onMouseMove}
onMouseEnter={onMouseEnter}
onMouseLeave={onMouseLeave}
tabIndex={-1}
sx={stageSx}
>
<Layer>
{layers.map((layer) => (
<Group key={layer.id}>
{layerObjectsSelectors.selectAll(layer.objects).map((obj) => {
{layer.objects.map((obj) => {
if (obj.kind === 'line') {
return <LineComponent key={obj.id} line={obj} />;
return <LineComponent key={obj.id} line={obj} color={layer.color} />;
}
if (obj.kind === 'fillRect') {
return <RectComponent key={obj.id} rect={obj} />;
return <RectComponent key={obj.id} rect={obj} color={layer.color} />;
}
})}
</Group>
))}
<BrushPreview />
</Layer>
</Stage>
</ChakraStage>
);
});

View File

@ -0,0 +1,17 @@
import { Button } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { layerReset } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useCallback } from 'react';
type Props = {
id: string;
};
export const ResetLayerButton = ({ id }: Props) => {
const dispatch = useAppDispatch();
const onClick = useCallback(() => {
dispatch(layerReset(id));
}, [dispatch, id]);
return <Button onClick={onClick} flexShrink={0}>Reset</Button>;
};

View File

@ -0,0 +1,18 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useLayer = (layerId: string) => {
const selectLayer = useMemo(
() =>
createSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
regionalPrompts.layers.find((l) => l.id === layerId)
),
[layerId]
);
const layer = useAppSelector(selectLayer);
assert(layer, `Layer ${layerId} doesn't exist!`);
return layer;
};

View File

@ -0,0 +1,139 @@
import { useAppDispatch } from 'app/store/storeHooks';
import getScaledCursorPosition from 'features/canvas/util/getScaledCursorPosition';
import {
$cursorPosition,
$isMouseDown,
$isMouseOver,
$tool,
lineAdded,
pointsAdded,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import type Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { MutableRefObject } from 'react';
import { useCallback } from 'react';
const getIsFocused = (stage: Konva.Stage) => {
return stage.container().contains(document.activeElement);
};
const syncCursorPos = (stage: Konva.Stage) => {
const pos = getScaledCursorPosition(stage);
if (!pos) {
return null;
}
$cursorPosition.set(pos);
return pos;
};
export const useMouseDown = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const dispatch = useAppDispatch();
const onMouseDown = useCallback(
(_e: KonvaEventObject<MouseEvent | TouchEvent>) => {
if (!stageRef.current) {
return;
}
console.log('Mouse down');
const pos = syncCursorPos(stageRef.current);
if (!pos) {
return;
}
$isMouseDown.set(true);
if ($tool.get() === 'brush') {
dispatch(lineAdded([pos.x, pos.y]));
}
},
[dispatch, stageRef]
);
return onMouseDown;
};
export const useMouseUp = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const dispatch = useAppDispatch();
const onMouseUp = useCallback(
(_e: KonvaEventObject<MouseEvent | TouchEvent>) => {
if (!stageRef.current) {
return;
}
console.log('Mouse up');
if ($tool.get() === 'brush' && $isMouseDown.get()) {
// Add another point to the last line.
$isMouseDown.set(false);
const pos = syncCursorPos(stageRef.current);
if (!pos) {
return;
}
dispatch(pointsAdded([pos.x, pos.y]));
}
},
[dispatch, stageRef]
);
return onMouseUp;
};
export const useMouseMove = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const dispatch = useAppDispatch();
const onMouseMove = useCallback(
(_e: KonvaEventObject<MouseEvent | TouchEvent>) => {
if (!stageRef.current) {
return;
}
console.log('Mouse move');
const pos = syncCursorPos(stageRef.current);
if (!pos) {
return;
}
if (getIsFocused(stageRef.current) && $isMouseOver.get() && $isMouseDown.get() && $tool.get() === 'brush') {
dispatch(pointsAdded([pos.x, pos.y]));
}
},
[dispatch, stageRef]
);
return onMouseMove;
};
export const useMouseLeave = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const onMouseLeave = useCallback(
(_e: KonvaEventObject<MouseEvent | TouchEvent>) => {
if (!stageRef.current) {
return;
}
console.log('Mouse leave');
$isMouseOver.set(false);
$isMouseDown.set(false);
$cursorPosition.set(null);
},
[stageRef]
);
return onMouseLeave;
};
export const useMouseEnter = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const dispatch = useAppDispatch();
const onMouseEnter = useCallback(
(e: KonvaEventObject<MouseEvent>) => {
if (!stageRef.current) {
return;
}
console.log('Mouse enter');
$isMouseOver.set(true);
const pos = syncCursorPos(stageRef.current);
if (!pos) {
return;
}
if (!getIsFocused(stageRef.current)) {
return;
}
if (e.evt.buttons !== 1) {
$isMouseDown.set(false);
} else {
$isMouseDown.set(true);
if ($tool.get() === 'brush') {
dispatch(lineAdded([pos.x, pos.y]));
}
}
},
[dispatch, stageRef]
);
return onMouseEnter;
};

View File

@ -1,8 +1,11 @@
import type { EntityState } from '@reduxjs/toolkit';
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RgbaColor } from 'react-colorful';
import type Konva from 'konva';
import type { Vector2d } from 'konva/lib/types';
import { atom } from 'nanostores';
import type { RgbColor } from 'react-colorful';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
type LayerObjectBase = {
@ -23,7 +26,6 @@ export type LineObject = LayerObjectBase & {
kind: 'line';
strokeWidth: number;
points: number[];
color: RgbaColor;
};
export type FillRectObject = LayerObjectBase & {
@ -32,84 +34,150 @@ export type FillRectObject = LayerObjectBase & {
y: number;
width: number;
height: number;
color: RgbaColor;
};
export type LayerObject = ImageObject | LineObject | FillRectObject;
export type PromptRegionLayer = {
id: string;
objects: EntityState<LayerObject, string>;
kind: 'promptRegionLayer';
objects: LayerObject[];
prompt: string;
color: RgbColor;
};
export const layersAdapter = createEntityAdapter<PromptRegionLayer, string>({
selectId: (layer) => layer.id,
});
export const layersSelectors = layersAdapter.getSelectors(undefined, getSelectorsOptions);
export type Layer = PromptRegionLayer;
export const layerObjectsAdapter = createEntityAdapter<LayerObject, string>({
selectId: (obj) => obj.id,
});
export const layerObjectsSelectors = layerObjectsAdapter.getSelectors(undefined, getSelectorsOptions);
export type Tool = 'brush';
const getMockState = () => {
// Mock data
const layer1ID = uuidv4();
const obj1ID = uuidv4();
const obj2ID = uuidv4();
const objectEntities: Record<string, LayerObject> = {
[obj1ID]: {
id: obj1ID,
kind: 'line',
isSelected: false,
color: { r: 255, g: 0, b: 0, a: 1 },
strokeWidth: 5,
points: [20, 20, 100, 100],
},
[obj2ID]: {
id: obj2ID,
kind: 'fillRect',
isSelected: false,
color: { r: 0, g: 255, b: 0, a: 1 },
x: 150,
y: 150,
width: 100,
height: 100,
},
};
const objectsInitialState = layerObjectsAdapter.getInitialState(undefined, objectEntities);
const entities: Record<string, PromptRegionLayer> = {
[layer1ID]: {
id: layer1ID,
prompt: 'strawberries',
objects: objectsInitialState,
},
};
return entities;
export type RegionalPromptsState = {
_version: 1;
selectedLayer: string | null;
layers: PromptRegionLayer[];
brushSize: number;
};
export const initialRegionalPromptsState = layersAdapter.getInitialState(
{ _version: 1, selectedID: null },
getMockState()
);
export const initialRegionalPromptsState: RegionalPromptsState = {
_version: 1,
selectedLayer: null,
brushSize: 40,
layers: [],
};
export type RegionalPromptsState = typeof initialRegionalPromptsState;
const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line';
export const regionalPromptsSlice = createSlice({
name: 'regionalPrompts',
initialState: initialRegionalPromptsState,
reducers: {
layerAdded: layersAdapter.addOne,
layerRemoved: layersAdapter.removeOne,
layerUpdated: layersAdapter.updateOne,
layersReset: layersAdapter.removeAll,
layerAdded: {
reducer: (state, action: PayloadAction<Layer['kind'], string, { id: string }>) => {
const newLayer = buildLayer(action.meta.id, action.payload, state.layers.length);
state.layers.push(newLayer);
state.selectedLayer = newLayer.id;
},
prepare: (payload: Layer['kind']) => ({ payload, meta: { id: uuidv4() } }),
},
layerSelected: (state, action: PayloadAction<string>) => {
state.selectedLayer = action.payload;
},
layerReset: (state, action: PayloadAction<string>) => {
const layer = state.layers.find((l) => l.id === action.payload);
if (!layer) {
return;
}
layer.objects = [];
},
layerDeleted: (state, action: PayloadAction<string>) => {
state.layers = state.layers.filter((l) => l.id !== action.payload);
state.selectedLayer = state.layers[0]?.id ?? null;
},
promptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (!layer) {
return;
}
layer.prompt = prompt;
},
promptRegionLayerColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (!layer || layer.kind !== 'promptRegionLayer') {
return;
}
layer.color = color;
},
lineAdded: {
reducer: (state, action: PayloadAction<number[], string, { id: string }>) => {
const selectedLayer = state.layers.find((l) => l.id === state.selectedLayer);
if (!selectedLayer || selectedLayer.kind !== 'promptRegionLayer') {
return;
}
selectedLayer.objects.push(buildLine(action.meta.id, action.payload, state.brushSize));
},
prepare: (payload: number[]) => ({ payload, meta: { id: uuidv4() } }),
},
pointsAdded: (state, action: PayloadAction<number[]>) => {
const selectedLayer = state.layers.find((l) => l.id === state.selectedLayer);
if (!selectedLayer || selectedLayer.kind !== 'promptRegionLayer') {
return;
}
const lastLine = selectedLayer.objects.findLast(isLine);
if (!lastLine) {
return;
}
lastLine.points.push(...action.payload);
},
brushSizeChanged: (state, action: PayloadAction<number>) => {
state.brushSize = action.payload;
},
},
});
export const { layerAdded, layerRemoved, layerUpdated, layersReset } = regionalPromptsSlice.actions;
const DEFAULT_COLORS = [
{ r: 200, g: 0, b: 0 },
{ r: 0, g: 200, b: 0 },
{ r: 0, g: 0, b: 200 },
{ r: 200, g: 200, b: 0 },
{ r: 0, g: 200, b: 200 },
{ r: 200, g: 0, b: 200 },
];
const buildLayer = (id: string, kind: Layer['kind'], layerCount: number) => {
if (kind === 'promptRegionLayer') {
const color = DEFAULT_COLORS[layerCount % DEFAULT_COLORS.length];
assert(color, 'Color not found');
return {
id,
kind,
prompt: '',
objects: [],
color,
};
}
assert(false, `Unknown layer kind: ${kind}`);
};
const buildLine = (id: string, points: number[], brushSize: number): LineObject => ({
isSelected: false,
kind: 'line',
id,
points,
strokeWidth: brushSize,
});
export const {
layerAdded,
layerSelected,
layerReset,
layerDeleted,
promptChanged,
lineAdded,
pointsAdded,
promptRegionLayerColorChanged,
brushSizeChanged,
} = regionalPromptsSlice.actions;
export const selectRegionalPromptsSlice = (state: RootState) => state.regionalPrompts;
@ -124,3 +192,10 @@ export const regionalPromptsPersistConfig: PersistConfig<RegionalPromptsState> =
migrate: migrateRegionalPromptsState,
persistDenylist: [],
};
export const $isMouseDown = atom(false);
export const $isMouseOver = atom(false);
export const $isFocused = atom(false);
export const $cursorPosition = atom<Vector2d | null>(null);
export const $tool = atom<Tool>('brush');
export const $stage = atom<Konva.Stage | null>(null);