diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 1449a3298a..adc51341d7 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -1,117 +1,76 @@ -import type { PendingConnection, Templates } from 'features/nodes/store/types'; -import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; -import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; -import { differenceWith, map } from 'lodash-es'; +import type { Templates } from 'features/nodes/store/types'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; +import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; +import { map } from 'lodash-es'; import type { Connection, Edge } from 'reactflow'; -import { assert } from 'tsafe'; - -import { areTypesEqual } from './areTypesEqual'; -import { getCollectItemType } from './getCollectItemType'; -import { getHasCycles } from './getHasCycles'; /** - * Finds the first valid field for a pending connection between two nodes. - * @param templates The invocation templates + * + * @param source The source (node id) + * @param sourceHandle The source handle (field name), if any + * @param target The target (node id) + * @param targetHandle The target handle (field name), if any * @param nodes The current nodes * @param edges The current edges - * @param pendingConnection The pending connection - * @param candidateNode The candidate node to which the connection is being made - * @param candidateTemplate The candidate template for the candidate node - * @returns The first valid connection, or null if no valid connection is found + * @param templates The current templates + * @param edgePendingUpdate The edge pending update, if any + * @returns */ - export const getFirstValidConnection = ( - templates: Templates, + source: string, + sourceHandle: string | null, + target: string, + targetHandle: string | null, nodes: AnyNode[], edges: InvocationNodeEdge[], - pendingConnection: PendingConnection, - candidateNode: InvocationNode, - candidateTemplate: InvocationTemplate, + templates: Templates, edgePendingUpdate: Edge | null ): Connection | null => { - if (pendingConnection.node.id === candidateNode.id) { - // Cannot connect to self + if (source === target) { return null; } - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + if (sourceHandle && targetHandle) { + return { source, sourceHandle, target, targetHandle }; + } - if (pendingFieldKind === 'source') { - // Connecting from a source to a target - if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) { - return null; - } - if (candidateNode.data.type === 'collect') { - // Special handling for collect node - the `item` field takes any number of connections - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: 'item', - }; - } - // Only one connection per target field is allowed - look for an unconnected target field - const candidateFields = map(candidateTemplate.inputs); - const candidateConnectedFields = edges - .filter((edge) => edge.target === candidateNode.id || edge.id === edgePendingUpdate?.id) - .map((edge) => { - // Edges must always have a targetHandle, safe to assert here - assert(edge.targetHandle); - return edge.targetHandle; - }); - const candidateUnconnectedFields = differenceWith( - candidateFields, - candidateConnectedFields, - (field, connectedFieldName) => field.name === connectedFieldName + if (sourceHandle && !targetHandle) { + const candidates = getTargetCandidateFields( + source, + sourceHandle, + target, + nodes, + edges, + templates, + edgePendingUpdate ); - const candidateField = candidateUnconnectedFields.find((field) => - validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) - ); - if (candidateField) { - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: candidateField.name, - }; - } - } else { - // Connecting from a target to a source - // Ensure we there is not already an edge to the target, except for collect nodes - const isCollect = pendingConnection.node.data.type === 'collect'; - const isTargetAlreadyConnected = edges.some( - (e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name - ); - if (!isCollect && isTargetAlreadyConnected) { + + const firstCandidate = candidates[0]; + if (!firstCandidate) { return null; } - if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) { + return { source, sourceHandle, target, targetHandle: firstCandidate.name }; + } + + if (!sourceHandle && targetHandle) { + const candidates = getSourceCandidateFields( + target, + targetHandle, + source, + nodes, + edges, + templates, + edgePendingUpdate + ); + + const firstCandidate = candidates[0]; + if (!firstCandidate) { return null; } - // Sources/outputs can have any number of edges, we can take the first matching output field - let candidateFields = map(candidateTemplate.outputs); - if (isCollect) { - // Narrow candidates to same field type as already is connected to the collect node - const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); - if (collectItemType) { - candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType)); - } - } - const candidateField = candidateFields.find((field) => { - const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type); - const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name); - return isValid && !isAlreadyConnected; - }); - if (candidateField) { - return { - source: candidateNode.id, - sourceHandle: candidateField.name, - target: pendingConnection.node.id, - targetHandle: pendingConnection.fieldTemplate.name, - }; - } + return { source, sourceHandle: firstCandidate.name, target, targetHandle }; } return null;