refactor(ui): revise types for line and rect objects

- Create separate object types for brush and eraser lines, instead of a single type that has a `tool` field.
- Create new object type for rect shapes.
- Add logic to schemas to migrate old object types to new.
- Update renderers & reducers.
This commit is contained in:
psychedelicious 2024-06-05 16:39:22 +10:00
parent 87261bdbc9
commit 7c5dea6d12
3 changed files with 155 additions and 41 deletions

View File

@ -34,13 +34,14 @@ import {
isRenderableLayer, isRenderableLayer,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import type { import type {
BrushLine,
ControlAdapterLayer, ControlAdapterLayer,
EraserLine,
InitialImageLayer, InitialImageLayer,
Layer, Layer,
RectShape,
RegionalGuidanceLayer, RegionalGuidanceLayer,
Tool, Tool,
VectorMaskLine,
VectorMaskRect,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { t } from 'i18next'; import { t } from 'i18next';
import Konva from 'konva'; import Konva from 'konva';
@ -274,33 +275,64 @@ const createRGLayer = (
}; };
/** /**
* Creates a konva line from a vector mask line. * Creates a konva vector mask brush line from a vector mask line.
* @param vectorMaskLine The vector mask line state * @param brushLine The vector mask line state
* @param layerObjectGroup The konva layer's object group to add the line to * @param layerObjectGroup The konva layer's object group to add the line to
*/ */
const createVectorMaskLine = (vectorMaskLine: VectorMaskLine, layerObjectGroup: Konva.Group): Konva.Line => { const createVectorMaskBrushLine = (brushLine: BrushLine, layerObjectGroup: Konva.Group): Konva.Line => {
const konvaLine = new Konva.Line({ const konvaLine = new Konva.Line({
id: vectorMaskLine.id, id: brushLine.id,
key: vectorMaskLine.id, key: brushLine.id,
name: RG_LAYER_LINE_NAME, name: RG_LAYER_LINE_NAME,
strokeWidth: vectorMaskLine.strokeWidth, strokeWidth: brushLine.strokeWidth,
tension: 0, tension: 0,
lineCap: 'round', lineCap: 'round',
lineJoin: 'round', lineJoin: 'round',
shadowForStrokeEnabled: false, shadowForStrokeEnabled: false,
globalCompositeOperation: vectorMaskLine.tool === 'brush' ? 'source-over' : 'destination-out', globalCompositeOperation: 'source-over',
listening: false, listening: false,
}); });
layerObjectGroup.add(konvaLine); layerObjectGroup.add(konvaLine);
return konvaLine; return konvaLine;
}; };
/**
* Creates a konva vector mask eraser line from a vector mask line.
* @param eraserLine The vector mask line state
* @param layerObjectGroup The konva layer's object group to add the line to
*/
const createVectorMaskEraserLine = (eraserLine: EraserLine, layerObjectGroup: Konva.Group): Konva.Line => {
const konvaLine = new Konva.Line({
id: eraserLine.id,
key: eraserLine.id,
name: RG_LAYER_LINE_NAME,
strokeWidth: eraserLine.strokeWidth,
tension: 0,
lineCap: 'round',
lineJoin: 'round',
shadowForStrokeEnabled: false,
globalCompositeOperation: 'destination-out',
listening: false,
});
layerObjectGroup.add(konvaLine);
return konvaLine;
};
const createVectorMaskLine = (maskObject: BrushLine | EraserLine, layerObjectGroup: Konva.Group): Konva.Line => {
if (maskObject.type === 'brush_line') {
return createVectorMaskBrushLine(maskObject, layerObjectGroup);
} else {
// maskObject.type === 'eraser_line'
return createVectorMaskEraserLine(maskObject, layerObjectGroup);
}
};
/** /**
* Creates a konva rect from a vector mask rect. * Creates a konva rect from a vector mask rect.
* @param vectorMaskRect The vector mask rect state * @param vectorMaskRect The vector mask rect state
* @param layerObjectGroup The konva layer's object group to add the line to * @param layerObjectGroup The konva layer's object group to add the line to
*/ */
const createVectorMaskRect = (vectorMaskRect: VectorMaskRect, layerObjectGroup: Konva.Group): Konva.Rect => { const createVectorMaskRect = (vectorMaskRect: RectShape, layerObjectGroup: Konva.Group): Konva.Rect => {
const konvaRect = new Konva.Rect({ const konvaRect = new Konva.Rect({
id: vectorMaskRect.id, id: vectorMaskRect.id,
key: vectorMaskRect.id, key: vectorMaskRect.id,
@ -369,7 +401,7 @@ const renderRGLayer = (
} }
for (const maskObject of layerState.maskObjects) { for (const maskObject of layerState.maskObjects) {
if (maskObject.type === 'vector_mask_line') { if (maskObject.type === 'brush_line' || maskObject.type === 'eraser_line') {
const vectorMaskLine = const vectorMaskLine =
stage.findOne<Konva.Line>(`#${maskObject.id}`) ?? createVectorMaskLine(maskObject, konvaObjectGroup); stage.findOne<Konva.Line>(`#${maskObject.id}`) ?? createVectorMaskLine(maskObject, konvaObjectGroup);
@ -384,7 +416,7 @@ const renderRGLayer = (
vectorMaskLine.stroke(rgbColor); vectorMaskLine.stroke(rgbColor);
groupNeedsCache = true; groupNeedsCache = true;
} }
} else if (maskObject.type === 'vector_mask_rect') { } else if (maskObject.type === 'rect_shape') {
const konvaObject = const konvaObject =
stage.findOne<Konva.Rect>(`#${maskObject.id}`) ?? createVectorMaskRect(maskObject, konvaObjectGroup); stage.findOne<Konva.Rect>(`#${maskObject.id}`) ?? createVectorMaskRect(maskObject, konvaObjectGroup);

View File

@ -47,17 +47,19 @@ import type {
AddLineArg, AddLineArg,
AddPointToLineArg, AddPointToLineArg,
AddRectArg, AddRectArg,
BrushLine,
ControlAdapterLayer, ControlAdapterLayer,
ControlLayersState, ControlLayersState,
DrawingTool, DrawingTool,
EraserLine,
InitialImageLayer, InitialImageLayer,
IPAdapterLayer, IPAdapterLayer,
Layer, Layer,
RectShape,
RegionalGuidanceLayer, RegionalGuidanceLayer,
Tool, Tool,
VectorMaskLine,
VectorMaskRect,
} from './types'; } from './types';
import { DEFAULT_RGBA_COLOR } from './types';
export const initialControlLayersState: ControlLayersState = { export const initialControlLayersState: ControlLayersState = {
_version: 3, _version: 3,
@ -77,7 +79,8 @@ export const initialControlLayersState: ControlLayersState = {
}, },
}; };
const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line'; const isLine = (obj: BrushLine | EraserLine | RectShape): obj is BrushLine | EraserLine =>
obj.type === 'brush_line' || obj.type === 'eraser_line';
export const isRegionalGuidanceLayer = (layer?: Layer): layer is RegionalGuidanceLayer => export const isRegionalGuidanceLayer = (layer?: Layer): layer is RegionalGuidanceLayer =>
layer?.type === 'regional_guidance_layer'; layer?.type === 'regional_guidance_layer';
export const isControlAdapterLayer = (layer?: Layer): layer is ControlAdapterLayer => export const isControlAdapterLayer = (layer?: Layer): layer is ControlAdapterLayer =>
@ -491,15 +494,26 @@ export const controlLayersSlice = createSlice({
const { layerId, points, tool, lineUuid } = action.payload; const { layerId, points, tool, lineUuid } = action.payload;
const layer = selectRGLayerOrThrow(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
const lineId = getRGLayerLineId(layer.id, lineUuid); const lineId = getRGLayerLineId(layer.id, lineUuid);
layer.maskObjects.push({ if (tool === 'brush') {
type: 'vector_mask_line', layer.maskObjects.push({
tool: tool, id: lineId,
id: lineId, type: 'brush_line',
// Points must be offset by the layer's x and y coordinates // Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener? // TODO: Handle this in the event listener?
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y], points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
strokeWidth: state.brushSize, strokeWidth: state.brushSize,
}); color: DEFAULT_RGBA_COLOR,
});
} else {
layer.maskObjects.push({
id: lineId,
type: 'eraser_line',
// Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener?
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
strokeWidth: state.brushSize,
});
}
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null; layer.uploadedMaskImage = null;
}, },
@ -530,12 +544,13 @@ export const controlLayersSlice = createSlice({
const layer = selectRGLayerOrThrow(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
const id = getRGLayerRectId(layer.id, rectUuid); const id = getRGLayerRectId(layer.id, rectUuid);
layer.maskObjects.push({ layer.maskObjects.push({
type: 'vector_mask_rect', type: 'rect_shape',
id, id,
x: rect.x - layer.x, x: rect.x - layer.x,
y: rect.y - layer.y, y: rect.y - layer.y,
width: rect.width, width: rect.width,
height: rect.height, height: rect.height,
color: DEFAULT_RGBA_COLOR,
}); });
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null; layer.uploadedMaskImage = null;

View File

@ -5,13 +5,15 @@ import {
zT2IAdapterConfigV2, zT2IAdapterConfigV2,
} from 'features/controlLayers/util/controlAdapters'; } from 'features/controlLayers/util/controlAdapters';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import type {
ParameterHeight,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import { import {
type ParameterHeight,
type ParameterNegativePrompt,
type ParameterNegativeStylePromptSDXL,
type ParameterPositivePrompt,
type ParameterPositiveStylePromptSDXL,
type ParameterWidth,
zAutoNegative, zAutoNegative,
zParameterNegativePrompt, zParameterNegativePrompt,
zParameterPositivePrompt, zParameterPositivePrompt,
@ -28,16 +30,15 @@ export type DrawingTool = z.infer<typeof zDrawingTool>;
const zPoints = z.array(z.number()).refine((points) => points.length % 2 === 0, { const zPoints = z.array(z.number()).refine((points) => points.length % 2 === 0, {
message: 'Must have an even number of points', message: 'Must have an even number of points',
}); });
const zVectorMaskLine = z.object({ const zOLD_VectorMaskLine = z.object({
id: z.string(), id: z.string(),
type: z.literal('vector_mask_line'), type: z.literal('vector_mask_line'),
tool: zDrawingTool, tool: zDrawingTool,
strokeWidth: z.number().min(1), strokeWidth: z.number().min(1),
points: zPoints, points: zPoints,
}); });
export type VectorMaskLine = z.infer<typeof zVectorMaskLine>;
const zVectorMaskRect = z.object({ const zOLD_VectorMaskRect = z.object({
id: z.string(), id: z.string(),
type: z.literal('vector_mask_rect'), type: z.literal('vector_mask_rect'),
x: z.number(), x: z.number(),
@ -45,7 +46,45 @@ const zVectorMaskRect = z.object({
width: z.number().min(1), width: z.number().min(1),
height: z.number().min(1), height: z.number().min(1),
}); });
export type VectorMaskRect = z.infer<typeof zVectorMaskRect>;
const zRgbColor = z.object({
r: z.number().int().min(0).max(255),
g: z.number().int().min(0).max(255),
b: z.number().int().min(0).max(255),
});
const zRgbaColor = zRgbColor.extend({
a: z.number().min(0).max(1),
});
type RgbaColor = z.infer<typeof zRgbaColor>;
export const DEFAULT_RGBA_COLOR: RgbaColor = { r: 255, g: 255, b: 255, a: 1 };
const zBrushLine = z.object({
id: z.string(),
type: z.literal('brush_line'),
strokeWidth: z.number().min(1),
points: zPoints,
color: zRgbaColor,
});
export type BrushLine = z.infer<typeof zBrushLine>;
const zEraserline = z.object({
id: z.string(),
type: z.literal('eraser_line'),
strokeWidth: z.number().min(1),
points: zPoints,
});
export type EraserLine = z.infer<typeof zEraserline>;
const zRectShape = z.object({
id: z.string(),
type: z.literal('rect_shape'),
x: z.number(),
y: z.number(),
width: z.number().min(1),
height: z.number().min(1),
color: zRgbaColor,
});
export type RectShape = z.infer<typeof zRectShape>;
const zLayerBase = z.object({ const zLayerBase = z.object({
id: z.string(), id: z.string(),
@ -80,14 +119,42 @@ const zIPAdapterLayer = zLayerBase.extend({
}); });
export type IPAdapterLayer = z.infer<typeof zIPAdapterLayer>; export type IPAdapterLayer = z.infer<typeof zIPAdapterLayer>;
const zRgbColor = z.object({ const zMaskObject = z
r: z.number().int().min(0).max(255), .discriminatedUnion('type', [zOLD_VectorMaskLine, zOLD_VectorMaskRect, zBrushLine, zEraserline, zRectShape])
g: z.number().int().min(0).max(255), .transform((val) => {
b: z.number().int().min(0).max(255), // Migrate old vector mask objects to new format
}); if (val.type === 'vector_mask_line') {
const { tool, ...rest } = val;
if (tool === 'brush') {
const asBrushline: BrushLine = {
...rest,
type: 'brush_line',
color: { r: 255, g: 255, b: 255, a: 1 },
};
return asBrushline;
} else if (tool === 'eraser') {
const asEraserLine: EraserLine = {
...rest,
type: 'eraser_line',
};
return asEraserLine;
}
} else if (val.type === 'vector_mask_rect') {
const asRectShape: RectShape = {
...val,
type: 'rect_shape',
color: { r: 255, g: 255, b: 255, a: 1 },
};
return asRectShape;
} else {
return val;
}
})
.pipe(z.discriminatedUnion('type', [zBrushLine, zEraserline, zRectShape]));
const zRegionalGuidanceLayer = zRenderableLayerBase.extend({ const zRegionalGuidanceLayer = zRenderableLayerBase.extend({
type: z.literal('regional_guidance_layer'), type: z.literal('regional_guidance_layer'),
maskObjects: z.array(z.discriminatedUnion('type', [zVectorMaskLine, zVectorMaskRect])), maskObjects: z.array(zMaskObject),
positivePrompt: zParameterPositivePrompt.nullable(), positivePrompt: zParameterPositivePrompt.nullable(),
negativePrompt: zParameterNegativePrompt.nullable(), negativePrompt: zParameterNegativePrompt.nullable(),
ipAdapters: z.array(zIPAdapterConfigV2), ipAdapters: z.array(zIPAdapterConfigV2),