mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add strict mode to validateConnection
This commit is contained in:
parent
78f9f3ee95
commit
6ad01d824d
@ -166,4 +166,30 @@ describe(validateConnection.name, () => {
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
});
|
||||
|
||||
describe('non-strict mode', () => {
|
||||
it('should reject connections from self to self in non-strict mode', () => {
|
||||
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
|
||||
const r = validateConnection(c, [], [], templates, null, false);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
|
||||
});
|
||||
it('should reject connections that create cycles in non-strict mode', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, sub);
|
||||
const nodes = [n1, n2];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null, false);
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
});
|
||||
it('should otherwise allow invalid connections in non-strict mode', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||
const r = validateConnection(c, nodes, [], templates, null, false);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
@ -25,7 +25,8 @@ export type ValidateConnectionFunc = (
|
||||
nodes: AnyNode[],
|
||||
edges: Edge[],
|
||||
templates: Templates,
|
||||
ignoreEdge: Edge | null
|
||||
ignoreEdge: Edge | null,
|
||||
strict?: boolean
|
||||
) => ValidateConnectionResult;
|
||||
|
||||
export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => {
|
||||
@ -63,76 +64,78 @@ const getTargetEqualityPredicate =
|
||||
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
|
||||
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
|
||||
|
||||
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => {
|
||||
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
|
||||
if (c.source === c.target) {
|
||||
return buildRejectResult('nodes.cannotConnectToSelf');
|
||||
}
|
||||
|
||||
/**
|
||||
* We may need to ignore an edge when validating a connection.
|
||||
*
|
||||
* For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection,
|
||||
* the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else
|
||||
* the validation will fail unexpectedly.
|
||||
*/
|
||||
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
|
||||
if (strict) {
|
||||
/**
|
||||
* We may need to ignore an edge when validating a connection.
|
||||
*
|
||||
* For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection,
|
||||
* the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else
|
||||
* the validation will fail unexpectedly.
|
||||
*/
|
||||
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
|
||||
|
||||
if (filteredEdges.some(getEqualityPredicate(c))) {
|
||||
// We already have a connection from this source to this target
|
||||
return buildRejectResult('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
|
||||
const sourceNode = nodes.find((n) => n.id === c.source);
|
||||
if (!sourceNode) {
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((n) => n.id === c.target);
|
||||
if (!targetNode) {
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!targetTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
if (targetFieldTemplate.input === 'direct') {
|
||||
return buildRejectResult('nodes.cannotConnectToDirectInput');
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types.
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
|
||||
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
if (filteredEdges.some(getEqualityPredicate(c))) {
|
||||
// We already have a connection from this source to this target
|
||||
return buildRejectResult('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
}
|
||||
|
||||
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
|
||||
// CollectionItemField inputs can have multiple input connections
|
||||
if (targetFieldTemplate.type.name !== 'CollectionItemField') {
|
||||
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
|
||||
const sourceNode = nodes.find((n) => n.id === c.source);
|
||||
if (!sourceNode) {
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
}
|
||||
}
|
||||
|
||||
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return buildRejectResult('nodes.fieldTypesMustMatch');
|
||||
const targetNode = nodes.find((n) => n.id === c.target);
|
||||
if (!targetNode) {
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!targetTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
if (targetFieldTemplate.input === 'direct') {
|
||||
return buildRejectResult('nodes.cannotConnectToDirectInput');
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types.
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
|
||||
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
}
|
||||
}
|
||||
|
||||
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
|
||||
// CollectionItemField inputs can have multiple input connections
|
||||
if (targetFieldTemplate.type.name !== 'CollectionItemField') {
|
||||
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
|
||||
}
|
||||
}
|
||||
|
||||
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return buildRejectResult('nodes.fieldTypesMustMatch');
|
||||
}
|
||||
}
|
||||
|
||||
if (getHasCycles(c.source, c.target, nodes, edges)) {
|
||||
|
Loading…
Reference in New Issue
Block a user