From 6ad01d824d689319232bd6654bf722962766cfa0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 01:26:43 +1000 Subject: [PATCH] feat(ui): add strict mode to validateConnection --- .../store/util/validateConnection.test.ts | 26 ++++ .../nodes/store/util/validateConnection.ts | 127 +++++++++--------- 2 files changed, 91 insertions(+), 62 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index cf05b4deb6..108839a499 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -166,4 +166,30 @@ describe(validateConnection.name, () => { const r = validateConnection(c, nodes, edges, templates, null); expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); }); + + describe('non-strict mode', () => { + it('should reject connections from self to self in non-strict mode', () => { + const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; + const r = validateConnection(c, [], [], templates, null, false); + expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf')); + }); + it('should reject connections that create cycles in non-strict mode', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const nodes = [n1, n2]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, null, false); + expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); + }); + it('should otherwise allow invalid connections in non-strict mode', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, img_resize); + const nodes = [n1, n2]; + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; + const r = validateConnection(c, nodes, [], templates, null, false); + expect(r).toEqual(buildAcceptResult()); + }); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index b6b5a43d37..edb8ac5ecb 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -25,7 +25,8 @@ export type ValidateConnectionFunc = ( nodes: AnyNode[], edges: Edge[], templates: Templates, - ignoreEdge: Edge | null + ignoreEdge: Edge | null, + strict?: boolean ) => ValidateConnectionResult; export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => { @@ -63,76 +64,78 @@ const getTargetEqualityPredicate = export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true }); export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey }); -export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => { +export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => { if (c.source === c.target) { return buildRejectResult('nodes.cannotConnectToSelf'); } - /** - * We may need to ignore an edge when validating a connection. - * - * For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection, - * the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else - * the validation will fail unexpectedly. - */ - const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); + if (strict) { + /** + * We may need to ignore an edge when validating a connection. + * + * For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection, + * the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else + * the validation will fail unexpectedly. + */ + const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); - if (filteredEdges.some(getEqualityPredicate(c))) { - // We already have a connection from this source to this target - return buildRejectResult('nodes.cannotDuplicateConnection'); - } - - const sourceNode = nodes.find((n) => n.id === c.source); - if (!sourceNode) { - return buildRejectResult('nodes.missingNode'); - } - - const targetNode = nodes.find((n) => n.id === c.target); - if (!targetNode) { - return buildRejectResult('nodes.missingNode'); - } - - const sourceTemplate = templates[sourceNode.data.type]; - if (!sourceTemplate) { - return buildRejectResult('nodes.missingInvocationTemplate'); - } - - const targetTemplate = templates[targetNode.data.type]; - if (!targetTemplate) { - return buildRejectResult('nodes.missingInvocationTemplate'); - } - - const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; - if (!sourceFieldTemplate) { - return buildRejectResult('nodes.missingFieldTemplate'); - } - - const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; - if (!targetFieldTemplate) { - return buildRejectResult('nodes.missingFieldTemplate'); - } - - if (targetFieldTemplate.input === 'direct') { - return buildRejectResult('nodes.cannotConnectToDirectInput'); - } - - if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { - // Collect nodes shouldn't mix and match field types. - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { - return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); + if (filteredEdges.some(getEqualityPredicate(c))) { + // We already have a connection from this source to this target + return buildRejectResult('nodes.cannotDuplicateConnection'); } - } - if (filteredEdges.find(getTargetEqualityPredicate(c))) { - // CollectionItemField inputs can have multiple input connections - if (targetFieldTemplate.type.name !== 'CollectionItemField') { - return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + const sourceNode = nodes.find((n) => n.id === c.source); + if (!sourceNode) { + return buildRejectResult('nodes.missingNode'); } - } - if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { - return buildRejectResult('nodes.fieldTypesMustMatch'); + const targetNode = nodes.find((n) => n.id === c.target); + if (!targetNode) { + return buildRejectResult('nodes.missingNode'); + } + + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const targetTemplate = templates[targetNode.data.type]; + if (!targetTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; + if (!sourceFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; + if (!targetFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + if (targetFieldTemplate.input === 'direct') { + return buildRejectResult('nodes.cannotConnectToDirectInput'); + } + + if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { + // Collect nodes shouldn't mix and match field types. + const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { + return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); + } + } + + if (filteredEdges.find(getTargetEqualityPredicate(c))) { + // CollectionItemField inputs can have multiple input connections + if (targetFieldTemplate.type.name !== 'CollectionItemField') { + return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + } + } + + if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { + return buildRejectResult('nodes.fieldTypesMustMatch'); + } } if (getHasCycles(c.source, c.target, nodes, edges)) {