diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 501513919a..18bbac0b44 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -9,8 +9,8 @@ import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { $cursorPos, $didUpdateEdge, + $edgePendingUpdate, $isAddNodePopoverOpen, - $isUpdatingEdge, $lastEdgeUpdateMouseEvent, $pendingConnection, $viewport, @@ -160,8 +160,8 @@ export const Flow = memo(() => { * where the edge is deleted if you click it accidentally). */ - const onEdgeUpdateStart: NonNullable = useCallback((e, _edge, _handleType) => { - $isUpdatingEdge.set(true); + const onEdgeUpdateStart: NonNullable = useCallback((e, edge, _handleType) => { + $edgePendingUpdate.set(edge); $didUpdateEdge.set(false); $lastEdgeUpdateMouseEvent.set(e); }, []); @@ -196,7 +196,7 @@ export const Flow = memo(() => { dispatch(edgeDeleted(edge.id)); } - $isUpdatingEdge.set(false); + $edgePendingUpdate.set(null); $didUpdateEdge.set(false); $pendingConnection.set(null); $lastEdgeUpdateMouseEvent.set(null); @@ -259,7 +259,7 @@ export const Flow = memo(() => { useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey); const onEscapeHotkey = useCallback(() => { - if (!$isUpdatingEdge.get()) { + if (!$edgePendingUpdate.get()) { $pendingConnection.set(null); $isAddNodePopoverOpen.set(false); cancelConnection(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index 0190a0b29e..d81a9e5807 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -2,8 +2,8 @@ import { useStore } from '@nanostores/react'; import { useAppStore } from 'app/store/storeHooks'; import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { + $edgePendingUpdate, $isAddNodePopoverOpen, - $isUpdatingEdge, $pendingConnection, $templates, connectionMade, @@ -52,12 +52,12 @@ export const useConnection = () => { const onConnectEnd = useCallback(() => { const { dispatch } = store; const pendingConnection = $pendingConnection.get(); - const isUpdatingEdge = $isUpdatingEdge.get(); + const edgePendingUpdate = $edgePendingUpdate.get(); const mouseOverNodeId = $mouseOverNode.get(); // If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge // update logic can finish up - if (isUpdatingEdge && !mouseOverNodeId) { + if (edgePendingUpdate && !mouseOverNodeId) { $pendingConnection.set(null); return; } @@ -80,7 +80,8 @@ export const useConnection = () => { edges, pendingConnection, candidateNode, - candidateTemplate + candidateTemplate, + edgePendingUpdate ); if (connection) { dispatch(connectionMade(connection)); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 5dcb7a28b5..7649209863 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -1,7 +1,7 @@ import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { $edgePendingUpdate, $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector'; import { useMemo } from 'react'; @@ -14,6 +14,7 @@ type UseConnectionStateProps = { export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { const pendingConnection = useStore($pendingConnection); const templates = useStore($templates); + const edgePendingUpdate = useStore($edgePendingUpdate); const selectIsConnected = useMemo( () => @@ -47,7 +48,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] ); }, [fieldName, kind, nodeId, pendingConnection]); - const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection)); + const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection, edgePendingUpdate)); const shouldDim = useMemo( () => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField), diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 0f8609d2ff..9a978b09a8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -1,7 +1,7 @@ // TODO: enable this at some point import { useStore } from '@nanostores/react'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; -import { $templates } from 'features/nodes/store/nodesSlice'; +import { $edgePendingUpdate, $templates } from 'features/nodes/store/nodesSlice'; import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { useCallback } from 'react'; import type { Connection } from 'reactflow'; @@ -21,7 +21,7 @@ export const useIsValidConnection = () => { if (!(source && sourceHandle && target && targetHandle)) { return false; } - + const edgePendingUpdate = $edgePendingUpdate.get(); const { nodes, edges } = store.getState().nodes.present; const validationResult = validateConnection( @@ -29,7 +29,7 @@ export const useIsValidConnection = () => { nodes, edges, templates, - null, + edgePendingUpdate, shouldValidateGraph ); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 83632c16e1..7915d3608c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -503,7 +503,7 @@ export const $copiedNodes = atom([]); export const $copiedEdges = atom([]); export const $edgesToCopiedNodes = atom([]); export const $pendingConnection = atom(null); -export const $isUpdatingEdge = atom(false); +export const $edgePendingUpdate = atom(null); export const $didUpdateEdge = atom(false); export const $lastEdgeUpdateMouseEvent = atom(null); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 98155f0c20..00899c065d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -2,7 +2,7 @@ import type { PendingConnection, Templates } from 'features/nodes/store/types'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; import { differenceWith, map } from 'lodash-es'; -import type { Connection } from 'reactflow'; +import type { Connection, Edge } from 'reactflow'; import { assert } from 'tsafe'; import { areTypesEqual } from './areTypesEqual'; @@ -26,7 +26,8 @@ export const getFirstValidConnection = ( edges: InvocationNodeEdge[], pendingConnection: PendingConnection, candidateNode: InvocationNode, - candidateTemplate: InvocationTemplate + candidateTemplate: InvocationTemplate, + edgePendingUpdate: Edge | null ): Connection | null => { if (pendingConnection.node.id === candidateNode.id) { // Cannot connect to self @@ -52,7 +53,7 @@ export const getFirstValidConnection = ( // Only one connection per target field is allowed - look for an unconnected target field const candidateFields = map(candidateTemplate.inputs); const candidateConnectedFields = edges - .filter((edge) => edge.target === candidateNode.id) + .filter((edge) => edge.target === candidateNode.id || edge.id === edgePendingUpdate?.id) .map((edge) => { // Edges must always have a targetHandle, safe to assert here assert(edge.targetHandle); @@ -63,7 +64,8 @@ export const getFirstValidConnection = ( candidateConnectedFields, (field, connectedFieldName) => field.name === connectedFieldName ); - const candidateField = candidateUnconnectedFields.find((field) => validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) + const candidateField = candidateUnconnectedFields.find((field) => + validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) ); if (candidateField) { return { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts index 3cefb6815f..fb7ed49d41 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -4,7 +4,7 @@ import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; import { validateConnection } from 'features/nodes/store/util/validateConnection'; import i18n from 'i18next'; -import type { HandleType } from 'reactflow'; +import type { Edge, HandleType } from 'reactflow'; /** * Creates a selector that validates a pending connection. @@ -27,7 +27,9 @@ export const makeConnectionErrorSelector = ( return createMemoizedSelector( selectNodesSlice, (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, - (nodesSlice: NodesState, pendingConnection: PendingConnection | null) => { + (state: RootState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => + edgePendingUpdate, + (nodesSlice: NodesState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => { const { nodes, edges } = nodesSlice; if (!pendingConnection) { @@ -61,7 +63,7 @@ export const makeConnectionErrorSelector = ( nodes, edges, templates, - null + edgePendingUpdate ); if (!validationResult.isValid) {