feat(ui): prevent connections to direct-only inputs

This commit is contained in:
psychedelicious
2024-05-17 20:08:37 +10:00
parent ad8778df6c
commit 575ecb4028
3 changed files with 25 additions and 12 deletions

View File

@ -45,6 +45,10 @@ export const useIsValidConnection = () => {
return false; return false;
} }
if (targetFieldTemplate.input === 'direct') {
return false;
}
if (!shouldValidateGraph) { if (!shouldValidateGraph) {
// manual override! // manual override!
return true; return true;

View File

@ -38,7 +38,7 @@ export const getFirstValidConnection = (
}; };
} }
// Only one connection per target field is allowed - look for an unconnected target field // Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs); const candidateFields = map(candidateTemplate.inputs).filter((i) => i.input !== 'direct');
const candidateConnectedFields = edges const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id) .filter((edge) => edge.target === candidateNode.id)
.map((edge) => { .map((edge) => {

View File

@ -6,6 +6,7 @@ import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocatio
import i18n from 'i18next'; import i18n from 'i18next';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import type { HandleType } from 'reactflow'; import type { HandleType } from 'reactflow';
import { assert } from 'tsafe';
import { getIsGraphAcyclic } from './getIsGraphAcyclic'; import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
@ -80,25 +81,33 @@ export const makeConnectionErrorSelector = (
} }
// we have to figure out which is the target and which is the source // we have to figure out which is the target and which is the source
const target = handleType === 'target' ? nodeId : connectionNodeId; const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId;
const targetHandle = handleType === 'target' ? fieldName : connectionFieldName; const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName;
const source = handleType === 'source' ? nodeId : connectionNodeId; const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId;
const sourceHandle = handleType === 'source' ? fieldName : connectionFieldName; const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName;
if ( if (
edges.find((edge) => { edges.find((edge) => {
edge.target === target && edge.target === targetNodeId &&
edge.targetHandle === targetHandle && edge.targetHandle === targetFieldName &&
edge.source === source && edge.source === sourceNodeId &&
edge.sourceHandle === sourceHandle; edge.sourceHandle === sourceFieldName;
}) })
) { ) {
// We already have a connection from this source to this target // We already have a connection from this source to this target
return i18n.t('nodes.cannotDuplicateConnection'); return i18n.t('nodes.cannotDuplicateConnection');
} }
const targetNode = nodes.find((node) => node.id === target); const targetNode = nodes.find((node) => node.id === targetNodeId);
if (targetNode?.data.type === 'collect' && targetHandle === 'item') { assert(targetNode, `Target node not found: ${targetNodeId}`);
const targetTemplate = templates[targetNode.data.type];
assert(targetTemplate, `Target template not found: ${targetNode.data.type}`);
if (targetTemplate.inputs[targetFieldName]?.input === 'direct') {
return i18n.t('nodes.cannotConnectToDirectInput');
}
if (targetNode.data.type === 'collect' && targetFieldName === 'item') {
// Collect nodes shouldn't mix and match field types // Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) { if (collectItemType) {
@ -110,7 +119,7 @@ export const makeConnectionErrorSelector = (
if ( if (
edges.find((edge) => { edges.find((edge) => {
return edge.target === target && edge.targetHandle === targetHandle; return edge.target === targetNodeId && edge.targetHandle === targetFieldName;
}) && }) &&
// except CollectionItem inputs can have multiples // except CollectionItem inputs can have multiples
targetType.name !== 'CollectionItemField' targetType.name !== 'CollectionItemField'