feat(ui): add strict mode to validateConnection

This commit is contained in:
psychedelicious 2024-05-19 01:26:43 +10:00
parent 78f9f3ee95
commit 6ad01d824d
2 changed files with 91 additions and 62 deletions

View File

@ -166,4 +166,30 @@ describe(validateConnection.name, () => {
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
});
describe('non-strict mode', () => {
it('should reject connections from self to self in non-strict mode', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
const r = validateConnection(c, [], [], templates, null, false);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
});
it('should reject connections that create cycles in non-strict mode', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const nodes = [n1, n2];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null, false);
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
});
it('should otherwise allow invalid connections in non-strict mode', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, img_resize);
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null, false);
expect(r).toEqual(buildAcceptResult());
});
});
});

View File

@ -25,7 +25,8 @@ export type ValidateConnectionFunc = (
nodes: AnyNode[],
edges: Edge[],
templates: Templates,
ignoreEdge: Edge | null
ignoreEdge: Edge | null,
strict?: boolean
) => ValidateConnectionResult;
export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => {
@ -63,76 +64,78 @@ const getTargetEqualityPredicate =
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => {
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
if (c.source === c.target) {
return buildRejectResult('nodes.cannotConnectToSelf');
}
/**
* We may need to ignore an edge when validating a connection.
*
* For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection,
* the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else
* the validation will fail unexpectedly.
*/
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
if (strict) {
/**
* We may need to ignore an edge when validating a connection.
*
* For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection,
* the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else
* the validation will fail unexpectedly.
*/
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
if (filteredEdges.some(getEqualityPredicate(c))) {
// We already have a connection from this source to this target
return buildRejectResult('nodes.cannotDuplicateConnection');
}
const sourceNode = nodes.find((n) => n.id === c.source);
if (!sourceNode) {
return buildRejectResult('nodes.missingNode');
}
const targetNode = nodes.find((n) => n.id === c.target);
if (!targetNode) {
return buildRejectResult('nodes.missingNode');
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
if (targetFieldTemplate.input === 'direct') {
return buildRejectResult('nodes.cannotConnectToDirectInput');
}
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
// Collect nodes shouldn't mix and match field types.
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
if (filteredEdges.some(getEqualityPredicate(c))) {
// We already have a connection from this source to this target
return buildRejectResult('nodes.cannotDuplicateConnection');
}
}
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
// CollectionItemField inputs can have multiple input connections
if (targetFieldTemplate.type.name !== 'CollectionItemField') {
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
const sourceNode = nodes.find((n) => n.id === c.source);
if (!sourceNode) {
return buildRejectResult('nodes.missingNode');
}
}
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
return buildRejectResult('nodes.fieldTypesMustMatch');
const targetNode = nodes.find((n) => n.id === c.target);
if (!targetNode) {
return buildRejectResult('nodes.missingNode');
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
if (targetFieldTemplate.input === 'direct') {
return buildRejectResult('nodes.cannotConnectToDirectInput');
}
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
// Collect nodes shouldn't mix and match field types.
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
}
}
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
// CollectionItemField inputs can have multiple input connections
if (targetFieldTemplate.type.name !== 'CollectionItemField') {
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
}
}
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
return buildRejectResult('nodes.fieldTypesMustMatch');
}
}
if (getHasCycles(c.source, c.target, nodes, edges)) {