tests(ui): add buildNode convenience wrapper for buildInvocationNode

This commit is contained in:
psychedelicious 2024-05-19 07:49:03 +10:00
parent ea97ae5ae8
commit fe3980a369
4 changed files with 42 additions and 42 deletions

View File

@ -1,20 +1,19 @@
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { add, buildEdge, collect, position, templates } from 'features/nodes/store/util/testUtils';
import { add, buildEdge, buildNode, collect, templates } from 'features/nodes/store/util/testUtils';
import type { FieldType } from 'features/nodes/types/field';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { describe, expect, it } from 'vitest';
describe(getCollectItemType.name, () => {
it('should return the type of the items the collect node collects', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, collect);
const n1 = buildNode(add);
const n2 = buildNode(collect);
const nodes = [n1, n2];
const edges = [buildEdge(n1.id, 'value', n2.id, 'item')];
const result = getCollectItemType(templates, nodes, edges, n2.id);
expect(result).toEqual<FieldType>({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false });
});
it('should return null if the collect node does not have any connections', () => {
const n1 = buildInvocationNode(position, collect);
const n1 = buildNode(collect);
const nodes = [n1];
const result = getCollectItemType(templates, nodes, [], n1.id);
expect(result).toBeNull();

View File

@ -1,12 +1,11 @@
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { add, buildEdge, position } from 'features/nodes/store/util/testUtils';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { add, buildEdge, buildNode } from 'features/nodes/store/util/testUtils';
import { describe, expect, it } from 'vitest';
describe(getHasCycles.name, () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, add);
const n3 = buildInvocationNode(position, add);
const n1 = buildNode(add);
const n2 = buildNode(add);
const n3 = buildNode(add);
const nodes = [n1, n2, n3];
it('should return true if the graph WOULD have cycles after adding the edge', () => {

View File

@ -1,5 +1,6 @@
import type { Templates } from 'features/nodes/store/types';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import type { OpenAPIV3_1 } from 'openapi-types';
import type { Edge, XYPosition } from 'reactflow';
@ -14,6 +15,8 @@ export const buildEdge = (source: string, sourceHandle: string, target: string,
export const position: XYPosition = { x: 0, y: 0 };
export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template);
export const add: InvocationTemplate = {
title: 'Add Integers',
type: 'add',

View File

@ -1,9 +1,8 @@
import { deepClone } from 'common/util/deepClone';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { set } from 'lodash-es';
import { describe, expect, it } from 'vitest';
import { add, buildEdge, collect, img_resize, main_model_loader, position, sub, templates } from './testUtils';
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
describe(validateConnection.name, () => {
@ -14,8 +13,8 @@ describe(validateConnection.name, () => {
});
describe('missing nodes', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const n1 = buildNode(add);
const n2 = buildNode(sub);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
it('should reject missing source node', () => {
@ -30,8 +29,8 @@ describe(validateConnection.name, () => {
});
describe('missing invocation templates', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const n1 = buildNode(add);
const n2 = buildNode(sub);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const nodes = [n1, n2];
@ -47,8 +46,8 @@ describe(validateConnection.name, () => {
});
describe('missing field templates', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const n1 = buildNode(add);
const n2 = buildNode(sub);
const nodes = [n1, n2];
it('should reject missing source field template', () => {
@ -65,8 +64,8 @@ describe(validateConnection.name, () => {
});
describe('duplicate connections', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const n1 = buildNode(add);
const n2 = buildNode(sub);
it('should accept non-duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], templates, null);
@ -92,17 +91,17 @@ describe(validateConnection.name, () => {
set(addWithDirectAField, 'inputs.a.input', 'direct');
set(addWithDirectAField, 'type', 'addWithDirectAField');
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, addWithDirectAField);
const n1 = buildNode(add);
const n2 = buildNode(addWithDirectAField);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
});
it('should reject connection to a collect node with mismatched item types', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, collect);
const n3 = buildInvocationNode(position, main_model_loader);
const n1 = buildNode(add);
const n2 = buildNode(collect);
const n3 = buildNode(main_model_loader);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const edges = [e1];
@ -112,9 +111,9 @@ describe(validateConnection.name, () => {
});
it('should accept connection to a collect node with matching item types', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, collect);
const n3 = buildInvocationNode(position, sub);
const n1 = buildNode(add);
const n2 = buildNode(collect);
const n3 = buildNode(sub);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const edges = [e1];
@ -124,9 +123,9 @@ describe(validateConnection.name, () => {
});
it('should reject connections to target field that is already connected', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, add);
const n3 = buildInvocationNode(position, add);
const n1 = buildNode(add);
const n2 = buildNode(add);
const n3 = buildNode(add);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
@ -136,9 +135,9 @@ describe(validateConnection.name, () => {
});
it('should accept connections to target field that is already connected (ignored edge)', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, add);
const n3 = buildInvocationNode(position, add);
const n1 = buildNode(add);
const n2 = buildNode(add);
const n3 = buildNode(add);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
@ -148,8 +147,8 @@ describe(validateConnection.name, () => {
});
it('should reject connections between invalid types', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, img_resize);
const n1 = buildNode(add);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null);
@ -157,8 +156,8 @@ describe(validateConnection.name, () => {
});
it('should reject connections that would create cycles', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const n1 = buildNode(add);
const n2 = buildNode(sub);
const nodes = [n1, n2];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
@ -174,8 +173,8 @@ describe(validateConnection.name, () => {
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
});
it('should reject connections that create cycles in non-strict mode', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const n1 = buildNode(add);
const n2 = buildNode(sub);
const nodes = [n1, n2];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
@ -184,8 +183,8 @@ describe(validateConnection.name, () => {
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
});
it('should otherwise allow invalid connections in non-strict mode', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, img_resize);
const n1 = buildNode(add);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null, false);