feat(ui): rework getFirstValidConnection with new helpers

This commit is contained in:
psychedelicious 2024-05-19 09:59:29 +10:00
parent c98205d0d7
commit 83000a4190

View File

@ -1,117 +1,76 @@
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { differenceWith, map } from 'lodash-es';
import type { Templates } from 'features/nodes/store/types';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { map } from 'lodash-es';
import type { Connection, Edge } from 'reactflow';
import { assert } from 'tsafe';
import { areTypesEqual } from './areTypesEqual';
import { getCollectItemType } from './getCollectItemType';
import { getHasCycles } from './getHasCycles';
/**
* Finds the first valid field for a pending connection between two nodes.
* @param templates The invocation templates
*
* @param source The source (node id)
* @param sourceHandle The source handle (field name), if any
* @param target The target (node id)
* @param targetHandle The target handle (field name), if any
* @param nodes The current nodes
* @param edges The current edges
* @param pendingConnection The pending connection
* @param candidateNode The candidate node to which the connection is being made
* @param candidateTemplate The candidate template for the candidate node
* @returns The first valid connection, or null if no valid connection is found
* @param templates The current templates
* @param edgePendingUpdate The edge pending update, if any
* @returns
*/
export const getFirstValidConnection = (
templates: Templates,
source: string,
sourceHandle: string | null,
target: string,
targetHandle: string | null,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate,
templates: Templates,
edgePendingUpdate: Edge | null
): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
if (source === target) {
return null;
}
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
if (sourceHandle && targetHandle) {
return { source, sourceHandle, target, targetHandle };
}
if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs);
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id || edge.id === edgePendingUpdate?.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
if (sourceHandle && !targetHandle) {
const candidates = getTargetCandidateFields(
source,
sourceHandle,
target,
nodes,
edges,
templates,
edgePendingUpdate
);
const candidateField = candidateUnconnectedFields.find((field) =>
validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
const firstCandidate = candidates[0];
if (!firstCandidate) {
return null;
}
if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return { source, sourceHandle, target, targetHandle: firstCandidate.name };
}
if (!sourceHandle && targetHandle) {
const candidates = getSourceCandidateFields(
target,
targetHandle,
source,
nodes,
edges,
templates,
edgePendingUpdate
);
const firstCandidate = candidates[0];
if (!firstCandidate) {
return null;
}
// Sources/outputs can have any number of edges, we can take the first matching output field
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
return { source, sourceHandle: firstCandidate.name, target, targetHandle };
}
return null;