diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 14d69b4720..561890245e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -9,6 +9,7 @@ import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; import { $cursorPos, + $edgePendingUpdate, $isAddNodePopoverOpen, $pendingConnection, $templates, @@ -28,7 +29,6 @@ import { useHotkeys } from 'react-hotkeys-hook'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import { useTranslation } from 'react-i18next'; import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters'; -import { assert } from 'tsafe'; const createRegex = memoize( (inputValue: string) => @@ -68,16 +68,18 @@ const AddNodePopover = () => { const filteredTemplates = useMemo(() => { // If we have a connection in progress, we need to filter the node choices + const templatesArray = map(templates); if (!pendingConnection) { - return map(templates); + return templatesArray; } return filter(templates, (template) => { - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind; - const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs; - return some(fields, (field) => { - const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type; - const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type; + const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs; + return some(candidateFields, (field) => { + const sourceType = + pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type; + const targetType = + pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type; return validateConnectionTypes(sourceType, targetType); }); }); @@ -144,10 +146,25 @@ const AddNodePopover = () => { // Auto-connect an edge if we just added a node and have a pending connection if (pendingConnection && isInvocationNode(node)) { - const template = templates[node.data.type]; - assert(template, 'Template not found'); + const edgePendingUpdate = $edgePendingUpdate.get(); + const { handleType } = pendingConnection; + + const source = handleType === 'source' ? pendingConnection.nodeId : node.id; + const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null; + const target = handleType === 'target' ? pendingConnection.nodeId : node.id; + const targetHandle = handleType === 'target' ? pendingConnection.handleId : null; + const { nodes, edges } = store.getState().nodes.present; - const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template); + const connection = getFirstValidConnection( + source, + sourceHandle, + target, + targetHandle, + nodes, + edges, + templates, + edgePendingUpdate + ); if (connection) { dispatch(connectionMade(connection)); } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index d81a9e5807..f7bf1b8740 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -9,10 +9,10 @@ import { connectionMade, } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; -import { isInvocationNode } from 'features/nodes/types/invocation'; import { isString } from 'lodash-es'; import { useCallback, useMemo } from 'react'; -import { type OnConnect, type OnConnectEnd, type OnConnectStart, useUpdateNodeInternals } from 'reactflow'; +import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; +import { useUpdateNodeInternals } from 'reactflow'; import { assert } from 'tsafe'; export const useConnection = () => { @@ -21,21 +21,27 @@ export const useConnection = () => { const updateNodeInternals = useUpdateNodeInternals(); const onConnectStart = useCallback( - (event, params) => { + (event, { nodeId, handleId, handleType }) => { + assert(nodeId && handleId && handleType, 'Invalid connection start event'); const nodes = store.getState().nodes.present.nodes; - const { nodeId, handleId, handleType } = params; - assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`); + const node = nodes.find((n) => n.id === nodeId); - assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`); + if (!node) { + return; + } + const template = templates[node.data.type]; - assert(template, `Template not found for node type: ${node.data.type}`); - const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId]; - assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`); - $pendingConnection.set({ - node, - template, - fieldTemplate, - }); + if (!template) { + return; + } + + const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs']; + const fieldTemplate = fieldTemplates[handleId]; + if (!fieldTemplate) { + return; + } + + $pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate }); }, [store, templates] ); @@ -67,20 +73,20 @@ export const useConnection = () => { } const { nodes, edges } = store.getState().nodes.present; if (mouseOverNodeId) { - const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId); - if (!candidateNode) { - // The mouse is over a non-invocation node - bail - return; - } - const candidateTemplate = templates[candidateNode.data.type]; - assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`); + const { handleType } = pendingConnection; + const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId; + const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null; + const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId; + const targetHandle = handleType === 'target' ? pendingConnection.handleId : null; + const connection = getFirstValidConnection( - templates, + source, + sourceHandle, + target, + targetHandle, nodes, edges, - pendingConnection, - candidateNode, - candidateTemplate, + templates, edgePendingUpdate ); if (connection) { diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 7649209863..d218734fff 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -43,8 +43,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta return false; } return ( - pendingConnection.node.id === nodeId && - pendingConnection.fieldTemplate.name === fieldName && + pendingConnection.nodeId === nodeId && + pendingConnection.handleId === fieldName && pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] ); }, [fieldName, kind, nodeId, pendingConnection]); diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 2f514bdb5b..6dcf70cfad 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -6,19 +6,20 @@ import type { } from 'features/nodes/types/field'; import type { AnyNode, - InvocationNode, InvocationNodeEdge, InvocationTemplate, NodeExecutionState, } from 'features/nodes/types/invocation'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import type { HandleType } from 'reactflow'; export type Templates = Record; export type NodeExecutionStates = Record; export type PendingConnection = { - node: InvocationNode; - template: InvocationTemplate; + nodeId: string; + handleId: string; + handleType: HandleType; fieldTemplate: FieldInputTemplate | FieldOutputTemplate; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts index e1a443a60e..c6d05d2c7c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -36,9 +36,7 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.noConnectionInProgress'); } - const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - - if (handleType === connectionHandleType) { + if (handleType === pendingConnection.handleType) { if (handleType === 'source') { return i18n.t('nodes.cannotConnectOutputToOutput'); } @@ -46,10 +44,10 @@ export const makeConnectionErrorSelector = ( } // we have to figure out which is the target and which is the source - const source = handleType === 'source' ? nodeId : pendingConnection.node.id; - const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.fieldTemplate.name; - const target = handleType === 'target' ? nodeId : pendingConnection.node.id; - const targetHandle = handleType === 'target' ? fieldName : pendingConnection.fieldTemplate.name; + const source = handleType === 'source' ? nodeId : pendingConnection.nodeId; + const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.handleId; + const target = handleType === 'target' ? nodeId : pendingConnection.nodeId; + const targetHandle = handleType === 'target' ? fieldName : pendingConnection.handleId; const validationResult = validateConnection( { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index f351083bc5..5155bb14ea 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -2,7 +2,7 @@ import type { Templates } from 'features/nodes/store/types'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import type { OpenAPIV3_1 } from 'openapi-types'; -import type { Edge, XYPosition } from 'reactflow'; +import type { Edge } from 'reactflow'; export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({ source, @@ -13,8 +13,6 @@ export const buildEdge = (source: string, sourceHandle: string, target: string, id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`, }); -export const position: XYPosition = { x: 0, y: 0 }; - export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template); export const add: InvocationTemplate = { @@ -176,7 +174,7 @@ export const collect: InvocationTemplate = { classification: 'stable', }; -export const scheduler: InvocationTemplate = { +const scheduler: InvocationTemplate = { title: 'Scheduler', type: 'scheduler', version: '1.0.0', diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index edb8ac5ecb..56e45c0d80 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -6,11 +6,10 @@ import { validateConnectionTypes } from 'features/nodes/store/util/validateConne import type { AnyNode } from 'features/nodes/types/invocation'; import type { Connection as NullableConnection, Edge } from 'reactflow'; import type { O } from 'ts-toolbelt'; -import { assert } from 'tsafe'; type Connection = O.NonNullable; -export type ValidateConnectionResult = +type ValidateConnectionResult = | { isValid: true; messageTKey?: string; @@ -20,7 +19,7 @@ export type ValidateConnectionResult = messageTKey: string; }; -export type ValidateConnectionFunc = ( +type ValidateConnectionFunc = ( connection: Connection, nodes: AnyNode[], edges: Edge[], @@ -29,21 +28,6 @@ export type ValidateConnectionFunc = ( strict?: boolean ) => ValidateConnectionResult; -export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => { - if (isValid) { - return { - isValid, - messageTKey, - }; - } else { - assert(messageTKey !== undefined); - return { - isValid, - messageTKey, - }; - } -}; - const getEqualityPredicate = (c: Connection) => (e: Edge): boolean => {