diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index e9261692a2..6e783b0567 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -717,6 +717,7 @@ "cannotConnectInputToInput": "Cannot connect input to input", "cannotConnectOutputToOutput": "Cannot connect output to output", "cannotConnectToSelf": "Cannot connect to self", + "cannotDuplicateConnection": "Cannot create duplicate connections", "clipField": "Clip", "clipFieldDescription": "Tokenizer and text_encoder submodels.", "collection": "Collection", diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 57e5825fb9..e2ff7c5bb0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -12,6 +12,7 @@ import { OnConnect, OnConnectEnd, OnConnectStart, + OnEdgeUpdateFunc, OnEdgesChange, OnEdgesDelete, OnInit, @@ -21,6 +22,7 @@ import { OnSelectionChangeFunc, ProOptions, ReactFlow, + ReactFlowProps, XYPosition, } from 'reactflow'; import { useIsValidConnection } from '../../hooks/useIsValidConnection'; @@ -28,6 +30,8 @@ import { connectionEnded, connectionMade, connectionStarted, + edgeAdded, + edgeDeleted, edgesChanged, edgesDeleted, nodesChanged, @@ -167,6 +171,63 @@ export const Flow = () => { } }, []); + // #region Updatable Edges + + /** + * Adapted from https://reactflow.dev/docs/examples/edges/updatable-edge/ + * and https://reactflow.dev/docs/examples/edges/delete-edge-on-drop/ + * + * - Edges can be dragged from one handle to another. + * - If the user drags the edge away from the node and drops it, delete the edge. + * - Do not delete the edge if the cursor didn't move (resolves annoying behaviour + * where the edge is deleted if you click it accidentally). + */ + + // We have a ref for cursor position, but it is the *projected* cursor position. + // Easiest to just keep track of the last mouse event for this particular feature + const edgeUpdateMouseEvent = useRef(); + + const onEdgeUpdateStart: NonNullable = + useCallback( + (e, edge, _handleType) => { + // update mouse event + edgeUpdateMouseEvent.current = e; + // always delete the edge when starting an updated + dispatch(edgeDeleted(edge.id)); + }, + [dispatch] + ); + + const onEdgeUpdate: OnEdgeUpdateFunc = useCallback( + (_oldEdge, newConnection) => { + // instead of updating the edge (we deleted it earlier), we instead create + // a new one. + dispatch(connectionMade(newConnection)); + }, + [dispatch] + ); + + const onEdgeUpdateEnd: NonNullable = + useCallback( + (e, edge, _handleType) => { + // Handle the case where user begins a drag but didn't move the cursor - + // bc we deleted the edge, we need to add it back + if ( + // ignore touch events + !('touches' in e) && + edgeUpdateMouseEvent.current?.clientX === e.clientX && + edgeUpdateMouseEvent.current?.clientY === e.clientY + ) { + dispatch(edgeAdded(edge)); + } + // reset mouse event + edgeUpdateMouseEvent.current = undefined; + }, + [dispatch] + ); + + // #endregion + useHotkeys(['Ctrl+c', 'Meta+c'], (e) => { e.preventDefault(); dispatch(selectionCopied()); @@ -196,6 +257,9 @@ export const Flow = () => { onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} onEdgesDelete={onEdgesDelete} + onEdgeUpdate={onEdgeUpdate} + onEdgeUpdateStart={onEdgeUpdateStart} + onEdgeUpdateEnd={onEdgeUpdateEnd} onNodesDelete={onNodesDelete} onConnectStart={onConnectStart} onConnect={onConnect} diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 0439445c24..a57787556c 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -53,13 +53,12 @@ export const useIsValidConnection = () => { } if ( - edges - .filter((edge) => { - return edge.target === target && edge.targetHandle === targetHandle; - }) - .find((edge) => { - edge.source === source && edge.sourceHandle === sourceHandle; - }) + edges.find((edge) => { + edge.target === target && + edge.targetHandle === targetHandle && + edge.source === source && + edge.sourceHandle === sourceHandle; + }) ) { // We already have a connection from this source to this target return false; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 01de3de883..1b3a5ca929 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -15,6 +15,7 @@ import { NodeChange, OnConnectStartParams, SelectionMode, + updateEdge, Viewport, XYPosition, } from 'reactflow'; @@ -182,6 +183,16 @@ const nodesSlice = createSlice({ edgesChanged: (state, action: PayloadAction) => { state.edges = applyEdgeChanges(action.payload, state.edges); }, + edgeAdded: (state, action: PayloadAction) => { + state.edges = addEdge(action.payload, state.edges); + }, + edgeUpdated: ( + state, + action: PayloadAction<{ oldEdge: Edge; newConnection: Connection }> + ) => { + const { oldEdge, newConnection } = action.payload; + state.edges = updateEdge(oldEdge, newConnection, state.edges); + }, connectionStarted: (state, action: PayloadAction) => { state.connectionStartParams = action.payload; const { nodeId, handleId, handleType } = action.payload; @@ -366,6 +377,7 @@ const nodesSlice = createSlice({ target: edge.target, type: 'collapsed', data: { count: 1 }, + updatable: false, }); } } @@ -388,6 +400,7 @@ const nodesSlice = createSlice({ target: edge.target, type: 'collapsed', data: { count: 1 }, + updatable: false, }); } } @@ -400,6 +413,9 @@ const nodesSlice = createSlice({ } } }, + edgeDeleted: (state, action: PayloadAction) => { + state.edges = state.edges.filter((e) => e.id !== action.payload); + }, edgesDeleted: (state, action: PayloadAction) => { const edges = action.payload; const collapsedEdges = edges.filter((e) => e.type === 'collapsed'); @@ -890,69 +906,72 @@ const nodesSlice = createSlice({ }); export const { - nodesChanged, - edgesChanged, - nodeAdded, - nodesDeleted, + addNodePopoverClosed, + addNodePopoverOpened, + addNodePopoverToggled, + connectionEnded, connectionMade, connectionStarted, - connectionEnded, - shouldShowFieldTypeLegendChanged, - shouldShowMinimapPanelChanged, - nodeTemplatesBuilt, - nodeEditorReset, - imageCollectionFieldValueChanged, - fieldStringValueChanged, - fieldNumberValueChanged, + edgeDeleted, + edgesChanged, + edgesDeleted, + edgeUpdated, fieldBoardValueChanged, fieldBooleanValueChanged, - fieldImageValueChanged, fieldColorValueChanged, - fieldMainModelValueChanged, - fieldVaeModelValueChanged, - fieldLoRAModelValueChanged, - fieldEnumModelValueChanged, fieldControlNetModelValueChanged, + fieldEnumModelValueChanged, + fieldImageValueChanged, fieldIPAdapterModelValueChanged, + fieldLabelChanged, + fieldLoRAModelValueChanged, + fieldMainModelValueChanged, + fieldNumberValueChanged, fieldRefinerModelValueChanged, fieldSchedulerValueChanged, + fieldStringValueChanged, + fieldVaeModelValueChanged, + imageCollectionFieldValueChanged, + mouseOverFieldChanged, + mouseOverNodeChanged, + nodeAdded, + nodeEditorReset, + nodeEmbedWorkflowChanged, + nodeExclusivelySelected, + nodeIsIntermediateChanged, nodeIsOpenChanged, nodeLabelChanged, nodeNotesChanged, - edgesDeleted, - shouldValidateGraphChanged, - shouldAnimateEdgesChanged, nodeOpacityChanged, - shouldSnapToGridChanged, - shouldColorEdgesChanged, - selectedNodesChanged, - selectedEdgesChanged, - workflowNameChanged, - workflowDescriptionChanged, - workflowTagsChanged, - workflowAuthorChanged, - workflowNotesChanged, - workflowVersionChanged, - workflowContactChanged, - workflowLoaded, + nodesChanged, + nodesDeleted, + nodeTemplatesBuilt, + nodeUseCacheChanged, notesNodeValueChanged, + selectedAll, + selectedEdgesChanged, + selectedNodesChanged, + selectionCopied, + selectionModeChanged, + selectionPasted, + shouldAnimateEdgesChanged, + shouldColorEdgesChanged, + shouldShowFieldTypeLegendChanged, + shouldShowMinimapPanelChanged, + shouldSnapToGridChanged, + shouldValidateGraphChanged, + viewportChanged, + workflowAuthorChanged, + workflowContactChanged, + workflowDescriptionChanged, workflowExposedFieldAdded, workflowExposedFieldRemoved, - fieldLabelChanged, - viewportChanged, - mouseOverFieldChanged, - selectionCopied, - selectionPasted, - selectedAll, - addNodePopoverOpened, - addNodePopoverClosed, - addNodePopoverToggled, - selectionModeChanged, - nodeEmbedWorkflowChanged, - nodeIsIntermediateChanged, - mouseOverNodeChanged, - nodeExclusivelySelected, - nodeUseCacheChanged, + workflowLoaded, + workflowNameChanged, + workflowNotesChanged, + workflowTagsChanged, + workflowVersionChanged, + edgeAdded, } = nodesSlice.actions; export default nodesSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 1be2d579d8..6343240a88 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -55,9 +55,29 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.cannotConnectInputToInput'); } + // we have to figure out which is the target and which is the source + const target = handleType === 'target' ? nodeId : connectionNodeId; + const targetHandle = + handleType === 'target' ? fieldName : connectionFieldName; + const source = handleType === 'source' ? nodeId : connectionNodeId; + const sourceHandle = + handleType === 'source' ? fieldName : connectionFieldName; + if ( edges.find((edge) => { - return edge.target === nodeId && edge.targetHandle === fieldName; + edge.target === target && + edge.targetHandle === targetHandle && + edge.source === source && + edge.sourceHandle === sourceHandle; + }) + ) { + // We already have a connection from this source to this target + return i18n.t('nodes.cannotDuplicateConnection'); + } + + if ( + edges.find((edge) => { + return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples targetType !== 'CollectionItem'