diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 656de737c7..501513919a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -8,12 +8,13 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection' import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { $cursorPos, + $didUpdateEdge, $isAddNodePopoverOpen, $isUpdatingEdge, + $lastEdgeUpdateMouseEvent, $pendingConnection, $viewport, connectionMade, - edgeAdded, edgeDeleted, edgesChanged, edgesDeleted, @@ -24,6 +25,7 @@ import { undo, } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; +import { isString } from 'lodash-es'; import type { CSSProperties, MouseEvent } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; @@ -39,7 +41,7 @@ import type { ReactFlowProps, ReactFlowState, } from 'reactflow'; -import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow'; +import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from 'reactflow'; import CustomConnectionLine from './connectionLines/CustomConnectionLine'; import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge'; @@ -81,6 +83,7 @@ export const Flow = memo(() => { const flowWrapper = useRef(null); const isValidConnection = useIsValidConnection(); const cancelConnection = useReactFlowStore(selectCancelConnection); + const updateNodeInternals = useUpdateNodeInternals(); useWorkflowWatcher(); useSyncExecutionState(); const [borderRadius] = useToken('radii', ['base']); @@ -157,45 +160,46 @@ export const Flow = memo(() => { * where the edge is deleted if you click it accidentally). */ - // We have a ref for cursor position, but it is the *projected* cursor position. - // Easiest to just keep track of the last mouse event for this particular feature - const edgeUpdateMouseEvent = useRef(); - - const onEdgeUpdateStart: NonNullable = useCallback( - (e, edge, _handleType) => { - $isUpdatingEdge.set(true); - // update mouse event - edgeUpdateMouseEvent.current = e; - // always delete the edge when starting an updated - dispatch(edgeDeleted(edge.id)); - }, - [dispatch] - ); + const onEdgeUpdateStart: NonNullable = useCallback((e, _edge, _handleType) => { + $isUpdatingEdge.set(true); + $didUpdateEdge.set(false); + $lastEdgeUpdateMouseEvent.set(e); + }, []); const onEdgeUpdate: OnEdgeUpdateFunc = useCallback( - (_oldEdge, newConnection) => { - // Because we deleted the edge when the update started, we must create a new edge from the connection + (edge, newConnection) => { + // This event is fired when an edge update is successful + $didUpdateEdge.set(true); + // When an edge update is successful, we need to delete the old edge and create a new one + dispatch(edgeDeleted(edge.id)); dispatch(connectionMade(newConnection)); + // Because we shift the position of handles depending on whether a field is connected or not, we must use + // updateNodeInternals to tell reactflow to recalculate the positions of the handles + const nodesToUpdate = [edge.source, edge.target, newConnection.source, newConnection.target].filter(isString); + updateNodeInternals(nodesToUpdate); }, - [dispatch] + [dispatch, updateNodeInternals] ); const onEdgeUpdateEnd: NonNullable = useCallback( (e, edge, _handleType) => { - $isUpdatingEdge.set(false); - $pendingConnection.set(null); - // Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting - // the edge update - we need to add it back - if ( - // ignore touch events - !('touches' in e) && - edgeUpdateMouseEvent.current?.clientX === e.clientX && - edgeUpdateMouseEvent.current?.clientY === e.clientY - ) { - dispatch(edgeAdded(edge)); + const didUpdateEdge = $didUpdateEdge.get(); + // Fall back to a reasonable default event + const lastEvent = $lastEdgeUpdateMouseEvent.get() ?? { clientX: 0, clientY: 0 }; + // We have to narrow this event down to MouseEvents - could be TouchEvent + const didMouseMove = + !('touches' in e) && Math.hypot(e.clientX - lastEvent.clientX, e.clientY - lastEvent.clientY) > 5; + + // If we got this far and did not successfully update an edge, and the mouse moved away from the handle, + // the user probably intended to delete the edge + if (!didUpdateEdge && didMouseMove) { + dispatch(edgeDeleted(edge.id)); } - // reset mouse event - edgeUpdateMouseEvent.current = undefined; + + $isUpdatingEdge.set(false); + $didUpdateEdge.set(false); + $pendingConnection.set(null); + $lastEdgeUpdateMouseEvent.set(null); }, [dispatch] ); @@ -255,9 +259,11 @@ export const Flow = memo(() => { useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey); const onEscapeHotkey = useCallback(() => { - $pendingConnection.set(null); - $isAddNodePopoverOpen.set(false); - cancelConnection(); + if (!$isUpdatingEdge.get()) { + $pendingConnection.set(null); + $isAddNodePopoverOpen.set(false); + cancelConnection(); + } }, [cancelConnection]); useHotkeys('esc', onEscapeHotkey); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index cec13e8df4..83632c16e1 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -47,6 +47,7 @@ import { import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import { atom } from 'nanostores'; +import type { MouseEvent } from 'react'; import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; import type { UndoableOptions } from 'redux-undo'; @@ -125,9 +126,6 @@ export const nodesSlice = createSlice({ edgesChanged: (state, action: PayloadAction) => { state.edges = applyEdgeChanges(action.payload, state.edges); }, - edgeAdded: (state, action: PayloadAction) => { - state.edges = addEdge(action.payload, state.edges); - }, connectionMade: (state, action: PayloadAction) => { state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges); }, @@ -495,7 +493,6 @@ export const { notesNodeValueChanged, selectedAll, selectionPasted, - edgeAdded, undo, redo, } = nodesSlice.actions; @@ -507,6 +504,9 @@ export const $copiedEdges = atom([]); export const $edgesToCopiedNodes = atom([]); export const $pendingConnection = atom(null); export const $isUpdatingEdge = atom(false); +export const $didUpdateEdge = atom(false); +export const $lastEdgeUpdateMouseEvent = atom(null); + export const $viewport = atom({ x: 0, y: 0, zoom: 1 }); export const $isAddNodePopoverOpen = atom(false); export const closeAddNodePopover = () => { @@ -609,6 +609,5 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( nodesDeleted, nodeUseCacheChanged, notesNodeValueChanged, - selectionPasted, - edgeAdded + selectionPasted );