feat(ui): rework getFirstValidConnection with new helpers

This commit is contained in:
psychedelicious 2024-05-19 09:59:29 +10:00
parent c98205d0d7
commit 83000a4190

View File

@ -1,117 +1,76 @@
import type { PendingConnection, Templates } from 'features/nodes/store/types'; import type { Templates } from 'features/nodes/store/types';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import { validateConnection } from 'features/nodes/store/util/validateConnection';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import { differenceWith, map } from 'lodash-es'; import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { map } from 'lodash-es';
import type { Connection, Edge } from 'reactflow'; import type { Connection, Edge } from 'reactflow';
import { assert } from 'tsafe';
import { areTypesEqual } from './areTypesEqual';
import { getCollectItemType } from './getCollectItemType';
import { getHasCycles } from './getHasCycles';
/** /**
* Finds the first valid field for a pending connection between two nodes. *
* @param templates The invocation templates * @param source The source (node id)
* @param sourceHandle The source handle (field name), if any
* @param target The target (node id)
* @param targetHandle The target handle (field name), if any
* @param nodes The current nodes * @param nodes The current nodes
* @param edges The current edges * @param edges The current edges
* @param pendingConnection The pending connection * @param templates The current templates
* @param candidateNode The candidate node to which the connection is being made * @param edgePendingUpdate The edge pending update, if any
* @param candidateTemplate The candidate template for the candidate node * @returns
* @returns The first valid connection, or null if no valid connection is found
*/ */
export const getFirstValidConnection = ( export const getFirstValidConnection = (
templates: Templates, source: string,
sourceHandle: string | null,
target: string,
targetHandle: string | null,
nodes: AnyNode[], nodes: AnyNode[],
edges: InvocationNodeEdge[], edges: InvocationNodeEdge[],
pendingConnection: PendingConnection, templates: Templates,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate,
edgePendingUpdate: Edge | null edgePendingUpdate: Edge | null
): Connection | null => { ): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) { if (source === target) {
// Cannot connect to self
return null; return null;
} }
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; if (sourceHandle && targetHandle) {
return { source, sourceHandle, target, targetHandle };
}
if (pendingFieldKind === 'source') { if (sourceHandle && !targetHandle) {
// Connecting from a source to a target const candidates = getTargetCandidateFields(
if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) { source,
return null; sourceHandle,
} target,
if (candidateNode.data.type === 'collect') { nodes,
// Special handling for collect node - the `item` field takes any number of connections edges,
return { templates,
source: pendingConnection.node.id, edgePendingUpdate
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs);
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id || edge.id === edgePendingUpdate?.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
); );
const candidateField = candidateUnconnectedFields.find((field) =>
validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) const firstCandidate = candidates[0];
); if (!firstCandidate) {
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null; return null;
} }
if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) { return { source, sourceHandle, target, targetHandle: firstCandidate.name };
}
if (!sourceHandle && targetHandle) {
const candidates = getSourceCandidateFields(
target,
targetHandle,
source,
nodes,
edges,
templates,
edgePendingUpdate
);
const firstCandidate = candidates[0];
if (!firstCandidate) {
return null; return null;
} }
// Sources/outputs can have any number of edges, we can take the first matching output field return { source, sourceHandle: firstCandidate.name, target, targetHandle };
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
} }
return null; return null;