diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/events.ts b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts index 95a5d8b3e8..4793faf7b7 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/events.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/events.ts @@ -6,6 +6,7 @@ import type { LayerEntity, Position, RegionEntity, + Tool, } from 'features/controlLayers/store/types'; import { isDrawableEntity, isDrawableEntityAdapter } from 'features/controlLayers/store/types'; import type Konva from 'konva'; @@ -64,6 +65,49 @@ const getNextPoint = ( return currentPos; }; +const getLastPointOfLine = (points: number[]): Position | null => { + if (points.length < 2) { + return null; + } + const x = points[points.length - 2]; + const y = points[points.length - 1]; + if (x === undefined || y === undefined) { + return null; + } + return { x, y }; +}; + +const getLastPointOfLastLineOfEntity = ( + entity: LayerEntity | RegionEntity | InpaintMaskEntity, + tool: Tool +): Position | null => { + const lastObject = entity.objects[entity.objects.length - 1]; + + if (!lastObject) { + return null; + } + + if ( + !( + (lastObject.type === 'brush_line' && tool === 'brush') || + (lastObject.type === 'eraser_line' && tool === 'eraser') + ) + ) { + // If the last object type and current tool do not match, we cannot continue the line + return null; + } + + if (lastObject.points.length < 2) { + return null; + } + const x = lastObject.points[lastObject.points.length - 2]; + const y = lastObject.points[lastObject.points.length - 1]; + if (x === undefined || y === undefined) { + return null; + } + return { x, y }; +}; + export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { const { stage, stateApi, getSelectedEntityAdapter } = manager; const { @@ -75,7 +119,7 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { setLastMouseDownPos, getLastCursorPos, setLastCursorPos, - getLastAddedPoint, + // getLastAddedPoint, setLastAddedPoint, setStageAttrs, getSelectedEntity, @@ -137,27 +181,26 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { setLastMouseDownPos(pos); if (toolState.selected === 'brush') { - if (e.evt.shiftKey) { - const lastAddedPoint = getLastAddedPoint(); - // Create a straight line if holding shift - if (lastAddedPoint) { - if (selectedEntityAdapter.getDrawingBuffer()) { - selectedEntityAdapter.finalizeDrawingBuffer(); - } - await selectedEntityAdapter.setDrawingBuffer({ - id: getBrushLineId(selectedEntityAdapter.id, uuidv4()), - type: 'brush_line', - points: [ - lastAddedPoint.x - selectedEntity.x, - lastAddedPoint.y - selectedEntity.y, - pos.x - selectedEntity.x, - pos.y - selectedEntity.y, - ], - strokeWidth: toolState.brush.width, - color: getCurrentFill(), - clip: getClip(selectedEntity), - }); + const lastLinePoint = getLastPointOfLastLineOfEntity(selectedEntity, toolState.selected); + if (e.evt.shiftKey && lastLinePoint) { + // Create a straight line from the last line point + if (selectedEntityAdapter.getDrawingBuffer()) { + selectedEntityAdapter.finalizeDrawingBuffer(); } + await selectedEntityAdapter.setDrawingBuffer({ + id: getBrushLineId(selectedEntityAdapter.id, uuidv4()), + type: 'brush_line', + points: [ + // The last point of the last line is already normalized to the entity's coordinates + lastLinePoint.x, + lastLinePoint.y, + pos.x - selectedEntity.x, + pos.y - selectedEntity.y, + ], + strokeWidth: toolState.brush.width, + color: getCurrentFill(), + clip: getClip(selectedEntity), + }); } else { if (selectedEntityAdapter.getDrawingBuffer()) { selectedEntityAdapter.finalizeDrawingBuffer(); @@ -180,26 +223,25 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { } if (toolState.selected === 'eraser') { - if (e.evt.shiftKey) { - // Create a straight line if holding shift - const lastAddedPoint = getLastAddedPoint(); - if (lastAddedPoint) { - if (selectedEntityAdapter.getDrawingBuffer()) { - selectedEntityAdapter.finalizeDrawingBuffer(); - } - await selectedEntityAdapter.setDrawingBuffer({ - id: getBrushLineId(selectedEntityAdapter.id, uuidv4()), - type: 'eraser_line', - points: [ - lastAddedPoint.x - selectedEntity.x, - lastAddedPoint.y - selectedEntity.y, - pos.x - selectedEntity.x, - pos.y - selectedEntity.y, - ], - strokeWidth: toolState.eraser.width, - clip: getClip(selectedEntity), - }); + const lastLinePoint = getLastPointOfLastLineOfEntity(selectedEntity, toolState.selected); + if (e.evt.shiftKey && lastLinePoint) { + // Create a straight line from the last line point + if (selectedEntityAdapter.getDrawingBuffer()) { + selectedEntityAdapter.finalizeDrawingBuffer(); } + await selectedEntityAdapter.setDrawingBuffer({ + id: getBrushLineId(selectedEntityAdapter.id, uuidv4()), + type: 'eraser_line', + points: [ + // The last point of the last line is already normalized to the entity's coordinates + lastLinePoint.x, + lastLinePoint.y, + pos.x - selectedEntity.x, + pos.y - selectedEntity.y, + ], + strokeWidth: toolState.eraser.width, + clip: getClip(selectedEntity), + }); } else { if (selectedEntityAdapter.getDrawingBuffer()) { selectedEntityAdapter.finalizeDrawingBuffer(); @@ -308,8 +350,7 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { const drawingBuffer = selectedEntityAdapter.getDrawingBuffer(); if (drawingBuffer) { if (drawingBuffer?.type === 'brush_line') { - const lastAddedPoint = getLastAddedPoint(); - const nextPoint = getNextPoint(pos, toolState, lastAddedPoint); + const nextPoint = getNextPoint(pos, toolState, getLastPointOfLine(drawingBuffer.points)); if (nextPoint) { drawingBuffer.points.push(nextPoint.x - selectedEntity.x, nextPoint.y - selectedEntity.y); await selectedEntityAdapter.setDrawingBuffer(drawingBuffer); @@ -343,8 +384,7 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => { const drawingBuffer = selectedEntityAdapter.getDrawingBuffer(); if (drawingBuffer) { if (drawingBuffer.type === 'eraser_line') { - const lastAddedPoint = getLastAddedPoint(); - const nextPoint = getNextPoint(pos, toolState, lastAddedPoint); + const nextPoint = getNextPoint(pos, toolState, getLastPointOfLine(drawingBuffer.points)); if (nextPoint) { drawingBuffer.points.push(nextPoint.x - selectedEntity.x, nextPoint.y - selectedEntity.y); await selectedEntityAdapter.setDrawingBuffer(drawingBuffer);