mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add strict mode to validateConnection
This commit is contained in:
parent
78f9f3ee95
commit
6ad01d824d
@ -166,4 +166,30 @@ describe(validateConnection.name, () => {
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
});
|
||||
|
||||
describe('non-strict mode', () => {
|
||||
it('should reject connections from self to self in non-strict mode', () => {
|
||||
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
|
||||
const r = validateConnection(c, [], [], templates, null, false);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
|
||||
});
|
||||
it('should reject connections that create cycles in non-strict mode', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, sub);
|
||||
const nodes = [n1, n2];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null, false);
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
});
|
||||
it('should otherwise allow invalid connections in non-strict mode', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||
const r = validateConnection(c, nodes, [], templates, null, false);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
@ -25,7 +25,8 @@ export type ValidateConnectionFunc = (
|
||||
nodes: AnyNode[],
|
||||
edges: Edge[],
|
||||
templates: Templates,
|
||||
ignoreEdge: Edge | null
|
||||
ignoreEdge: Edge | null,
|
||||
strict?: boolean
|
||||
) => ValidateConnectionResult;
|
||||
|
||||
export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => {
|
||||
@ -63,11 +64,12 @@ const getTargetEqualityPredicate =
|
||||
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
|
||||
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
|
||||
|
||||
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => {
|
||||
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
|
||||
if (c.source === c.target) {
|
||||
return buildRejectResult('nodes.cannotConnectToSelf');
|
||||
}
|
||||
|
||||
if (strict) {
|
||||
/**
|
||||
* We may need to ignore an edge when validating a connection.
|
||||
*
|
||||
@ -134,6 +136,7 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
|
||||
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return buildRejectResult('nodes.fieldTypesMustMatch');
|
||||
}
|
||||
}
|
||||
|
||||
if (getHasCycles(c.source, c.target, nodes, edges)) {
|
||||
return buildRejectResult('nodes.connectionWouldCreateCycle');
|
||||
|
Loading…
Reference in New Issue
Block a user