refactor(ui): revise types for line and rect objects

- Create separate object types for brush and eraser lines, instead of a single type that has a `tool` field.
- Create new object type for rect shapes.
- Add logic to schemas to migrate old object types to new.
- Update renderers & reducers.
This commit is contained in:
psychedelicious 2024-06-05 16:39:22 +10:00
parent 87261bdbc9
commit 7c5dea6d12
3 changed files with 155 additions and 41 deletions

View File

@ -34,13 +34,14 @@ import {
isRenderableLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import type {
BrushLine,
ControlAdapterLayer,
EraserLine,
InitialImageLayer,
Layer,
RectShape,
RegionalGuidanceLayer,
Tool,
VectorMaskLine,
VectorMaskRect,
} from 'features/controlLayers/store/types';
import { t } from 'i18next';
import Konva from 'konva';
@ -274,33 +275,64 @@ const createRGLayer = (
};
/**
* Creates a konva line from a vector mask line.
* @param vectorMaskLine The vector mask line state
* Creates a konva vector mask brush line from a vector mask line.
* @param brushLine The vector mask line state
* @param layerObjectGroup The konva layer's object group to add the line to
*/
const createVectorMaskLine = (vectorMaskLine: VectorMaskLine, layerObjectGroup: Konva.Group): Konva.Line => {
const createVectorMaskBrushLine = (brushLine: BrushLine, layerObjectGroup: Konva.Group): Konva.Line => {
const konvaLine = new Konva.Line({
id: vectorMaskLine.id,
key: vectorMaskLine.id,
id: brushLine.id,
key: brushLine.id,
name: RG_LAYER_LINE_NAME,
strokeWidth: vectorMaskLine.strokeWidth,
strokeWidth: brushLine.strokeWidth,
tension: 0,
lineCap: 'round',
lineJoin: 'round',
shadowForStrokeEnabled: false,
globalCompositeOperation: vectorMaskLine.tool === 'brush' ? 'source-over' : 'destination-out',
globalCompositeOperation: 'source-over',
listening: false,
});
layerObjectGroup.add(konvaLine);
return konvaLine;
};
/**
* Creates a konva vector mask eraser line from a vector mask line.
* @param eraserLine The vector mask line state
* @param layerObjectGroup The konva layer's object group to add the line to
*/
const createVectorMaskEraserLine = (eraserLine: EraserLine, layerObjectGroup: Konva.Group): Konva.Line => {
const konvaLine = new Konva.Line({
id: eraserLine.id,
key: eraserLine.id,
name: RG_LAYER_LINE_NAME,
strokeWidth: eraserLine.strokeWidth,
tension: 0,
lineCap: 'round',
lineJoin: 'round',
shadowForStrokeEnabled: false,
globalCompositeOperation: 'destination-out',
listening: false,
});
layerObjectGroup.add(konvaLine);
return konvaLine;
};
const createVectorMaskLine = (maskObject: BrushLine | EraserLine, layerObjectGroup: Konva.Group): Konva.Line => {
if (maskObject.type === 'brush_line') {
return createVectorMaskBrushLine(maskObject, layerObjectGroup);
} else {
// maskObject.type === 'eraser_line'
return createVectorMaskEraserLine(maskObject, layerObjectGroup);
}
};
/**
* Creates a konva rect from a vector mask rect.
* @param vectorMaskRect The vector mask rect state
* @param layerObjectGroup The konva layer's object group to add the line to
*/
const createVectorMaskRect = (vectorMaskRect: VectorMaskRect, layerObjectGroup: Konva.Group): Konva.Rect => {
const createVectorMaskRect = (vectorMaskRect: RectShape, layerObjectGroup: Konva.Group): Konva.Rect => {
const konvaRect = new Konva.Rect({
id: vectorMaskRect.id,
key: vectorMaskRect.id,
@ -369,7 +401,7 @@ const renderRGLayer = (
}
for (const maskObject of layerState.maskObjects) {
if (maskObject.type === 'vector_mask_line') {
if (maskObject.type === 'brush_line' || maskObject.type === 'eraser_line') {
const vectorMaskLine =
stage.findOne<Konva.Line>(`#${maskObject.id}`) ?? createVectorMaskLine(maskObject, konvaObjectGroup);
@ -384,7 +416,7 @@ const renderRGLayer = (
vectorMaskLine.stroke(rgbColor);
groupNeedsCache = true;
}
} else if (maskObject.type === 'vector_mask_rect') {
} else if (maskObject.type === 'rect_shape') {
const konvaObject =
stage.findOne<Konva.Rect>(`#${maskObject.id}`) ?? createVectorMaskRect(maskObject, konvaObjectGroup);

View File

@ -47,17 +47,19 @@ import type {
AddLineArg,
AddPointToLineArg,
AddRectArg,
BrushLine,
ControlAdapterLayer,
ControlLayersState,
DrawingTool,
EraserLine,
InitialImageLayer,
IPAdapterLayer,
Layer,
RectShape,
RegionalGuidanceLayer,
Tool,
VectorMaskLine,
VectorMaskRect,
} from './types';
import { DEFAULT_RGBA_COLOR } from './types';
export const initialControlLayersState: ControlLayersState = {
_version: 3,
@ -77,7 +79,8 @@ export const initialControlLayersState: ControlLayersState = {
},
};
const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line';
const isLine = (obj: BrushLine | EraserLine | RectShape): obj is BrushLine | EraserLine =>
obj.type === 'brush_line' || obj.type === 'eraser_line';
export const isRegionalGuidanceLayer = (layer?: Layer): layer is RegionalGuidanceLayer =>
layer?.type === 'regional_guidance_layer';
export const isControlAdapterLayer = (layer?: Layer): layer is ControlAdapterLayer =>
@ -491,15 +494,26 @@ export const controlLayersSlice = createSlice({
const { layerId, points, tool, lineUuid } = action.payload;
const layer = selectRGLayerOrThrow(state, layerId);
const lineId = getRGLayerLineId(layer.id, lineUuid);
layer.maskObjects.push({
type: 'vector_mask_line',
tool: tool,
id: lineId,
// Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener?
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
strokeWidth: state.brushSize,
});
if (tool === 'brush') {
layer.maskObjects.push({
id: lineId,
type: 'brush_line',
// Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener?
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
strokeWidth: state.brushSize,
color: DEFAULT_RGBA_COLOR,
});
} else {
layer.maskObjects.push({
id: lineId,
type: 'eraser_line',
// Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener?
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
strokeWidth: state.brushSize,
});
}
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;
},
@ -530,12 +544,13 @@ export const controlLayersSlice = createSlice({
const layer = selectRGLayerOrThrow(state, layerId);
const id = getRGLayerRectId(layer.id, rectUuid);
layer.maskObjects.push({
type: 'vector_mask_rect',
type: 'rect_shape',
id,
x: rect.x - layer.x,
y: rect.y - layer.y,
width: rect.width,
height: rect.height,
color: DEFAULT_RGBA_COLOR,
});
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;

View File

@ -5,13 +5,15 @@ import {
zT2IAdapterConfigV2,
} from 'features/controlLayers/util/controlAdapters';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import type {
ParameterHeight,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
type ParameterHeight,
type ParameterNegativePrompt,
type ParameterNegativeStylePromptSDXL,
type ParameterPositivePrompt,
type ParameterPositiveStylePromptSDXL,
type ParameterWidth,
zAutoNegative,
zParameterNegativePrompt,
zParameterPositivePrompt,
@ -28,16 +30,15 @@ export type DrawingTool = z.infer<typeof zDrawingTool>;
const zPoints = z.array(z.number()).refine((points) => points.length % 2 === 0, {
message: 'Must have an even number of points',
});
const zVectorMaskLine = z.object({
const zOLD_VectorMaskLine = z.object({
id: z.string(),
type: z.literal('vector_mask_line'),
tool: zDrawingTool,
strokeWidth: z.number().min(1),
points: zPoints,
});
export type VectorMaskLine = z.infer<typeof zVectorMaskLine>;
const zVectorMaskRect = z.object({
const zOLD_VectorMaskRect = z.object({
id: z.string(),
type: z.literal('vector_mask_rect'),
x: z.number(),
@ -45,7 +46,45 @@ const zVectorMaskRect = z.object({
width: z.number().min(1),
height: z.number().min(1),
});
export type VectorMaskRect = z.infer<typeof zVectorMaskRect>;
const zRgbColor = z.object({
r: z.number().int().min(0).max(255),
g: z.number().int().min(0).max(255),
b: z.number().int().min(0).max(255),
});
const zRgbaColor = zRgbColor.extend({
a: z.number().min(0).max(1),
});
type RgbaColor = z.infer<typeof zRgbaColor>;
export const DEFAULT_RGBA_COLOR: RgbaColor = { r: 255, g: 255, b: 255, a: 1 };
const zBrushLine = z.object({
id: z.string(),
type: z.literal('brush_line'),
strokeWidth: z.number().min(1),
points: zPoints,
color: zRgbaColor,
});
export type BrushLine = z.infer<typeof zBrushLine>;
const zEraserline = z.object({
id: z.string(),
type: z.literal('eraser_line'),
strokeWidth: z.number().min(1),
points: zPoints,
});
export type EraserLine = z.infer<typeof zEraserline>;
const zRectShape = z.object({
id: z.string(),
type: z.literal('rect_shape'),
x: z.number(),
y: z.number(),
width: z.number().min(1),
height: z.number().min(1),
color: zRgbaColor,
});
export type RectShape = z.infer<typeof zRectShape>;
const zLayerBase = z.object({
id: z.string(),
@ -80,14 +119,42 @@ const zIPAdapterLayer = zLayerBase.extend({
});
export type IPAdapterLayer = z.infer<typeof zIPAdapterLayer>;
const zRgbColor = z.object({
r: z.number().int().min(0).max(255),
g: z.number().int().min(0).max(255),
b: z.number().int().min(0).max(255),
});
const zMaskObject = z
.discriminatedUnion('type', [zOLD_VectorMaskLine, zOLD_VectorMaskRect, zBrushLine, zEraserline, zRectShape])
.transform((val) => {
// Migrate old vector mask objects to new format
if (val.type === 'vector_mask_line') {
const { tool, ...rest } = val;
if (tool === 'brush') {
const asBrushline: BrushLine = {
...rest,
type: 'brush_line',
color: { r: 255, g: 255, b: 255, a: 1 },
};
return asBrushline;
} else if (tool === 'eraser') {
const asEraserLine: EraserLine = {
...rest,
type: 'eraser_line',
};
return asEraserLine;
}
} else if (val.type === 'vector_mask_rect') {
const asRectShape: RectShape = {
...val,
type: 'rect_shape',
color: { r: 255, g: 255, b: 255, a: 1 },
};
return asRectShape;
} else {
return val;
}
})
.pipe(z.discriminatedUnion('type', [zBrushLine, zEraserline, zRectShape]));
const zRegionalGuidanceLayer = zRenderableLayerBase.extend({
type: z.literal('regional_guidance_layer'),
maskObjects: z.array(z.discriminatedUnion('type', [zVectorMaskLine, zVectorMaskRect])),
maskObjects: z.array(zMaskObject),
positivePrompt: zParameterPositivePrompt.nullable(),
negativePrompt: zParameterNegativePrompt.nullable(),
ipAdapters: z.array(zIPAdapterConfigV2),