mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests(ui): add buildNode convenience wrapper for buildInvocationNode
This commit is contained in:
parent
ea97ae5ae8
commit
fe3980a369
@ -1,20 +1,19 @@
|
||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||
import { add, buildEdge, collect, position, templates } from 'features/nodes/store/util/testUtils';
|
||||
import { add, buildEdge, buildNode, collect, templates } from 'features/nodes/store/util/testUtils';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
describe(getCollectItemType.name, () => {
|
||||
it('should return the type of the items the collect node collects', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, collect);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(collect);
|
||||
const nodes = [n1, n2];
|
||||
const edges = [buildEdge(n1.id, 'value', n2.id, 'item')];
|
||||
const result = getCollectItemType(templates, nodes, edges, n2.id);
|
||||
expect(result).toEqual<FieldType>({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false });
|
||||
});
|
||||
it('should return null if the collect node does not have any connections', () => {
|
||||
const n1 = buildInvocationNode(position, collect);
|
||||
const n1 = buildNode(collect);
|
||||
const nodes = [n1];
|
||||
const result = getCollectItemType(templates, nodes, [], n1.id);
|
||||
expect(result).toBeNull();
|
||||
|
@ -1,12 +1,11 @@
|
||||
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||
import { add, buildEdge, position } from 'features/nodes/store/util/testUtils';
|
||||
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||
import { add, buildEdge, buildNode } from 'features/nodes/store/util/testUtils';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
describe(getHasCycles.name, () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, add);
|
||||
const n3 = buildInvocationNode(position, add);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const n3 = buildNode(add);
|
||||
const nodes = [n1, n2, n3];
|
||||
|
||||
it('should return true if the graph WOULD have cycles after adding the edge', () => {
|
||||
|
@ -1,5 +1,6 @@
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||
import type { OpenAPIV3_1 } from 'openapi-types';
|
||||
import type { Edge, XYPosition } from 'reactflow';
|
||||
|
||||
@ -14,6 +15,8 @@ export const buildEdge = (source: string, sourceHandle: string, target: string,
|
||||
|
||||
export const position: XYPosition = { x: 0, y: 0 };
|
||||
|
||||
export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template);
|
||||
|
||||
export const add: InvocationTemplate = {
|
||||
title: 'Add Integers',
|
||||
type: 'add',
|
||||
|
@ -1,9 +1,8 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||
import { set } from 'lodash-es';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { add, buildEdge, collect, img_resize, main_model_loader, position, sub, templates } from './testUtils';
|
||||
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
|
||||
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
|
||||
|
||||
describe(validateConnection.name, () => {
|
||||
@ -14,8 +13,8 @@ describe(validateConnection.name, () => {
|
||||
});
|
||||
|
||||
describe('missing nodes', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, sub);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
|
||||
it('should reject missing source node', () => {
|
||||
@ -30,8 +29,8 @@ describe(validateConnection.name, () => {
|
||||
});
|
||||
|
||||
describe('missing invocation templates', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, sub);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const nodes = [n1, n2];
|
||||
|
||||
@ -47,8 +46,8 @@ describe(validateConnection.name, () => {
|
||||
});
|
||||
|
||||
describe('missing field templates', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, sub);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const nodes = [n1, n2];
|
||||
|
||||
it('should reject missing source field template', () => {
|
||||
@ -65,8 +64,8 @@ describe(validateConnection.name, () => {
|
||||
});
|
||||
|
||||
describe('duplicate connections', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, sub);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
it('should accept non-duplicate connections', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, [n1, n2], [], templates, null);
|
||||
@ -92,17 +91,17 @@ describe(validateConnection.name, () => {
|
||||
set(addWithDirectAField, 'inputs.a.input', 'direct');
|
||||
set(addWithDirectAField, 'type', 'addWithDirectAField');
|
||||
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, addWithDirectAField);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(addWithDirectAField);
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
|
||||
});
|
||||
|
||||
it('should reject connection to a collect node with mismatched item types', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, collect);
|
||||
const n3 = buildInvocationNode(position, main_model_loader);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(collect);
|
||||
const n3 = buildNode(main_model_loader);
|
||||
const nodes = [n1, n2, n3];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
|
||||
const edges = [e1];
|
||||
@ -112,9 +111,9 @@ describe(validateConnection.name, () => {
|
||||
});
|
||||
|
||||
it('should accept connection to a collect node with matching item types', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, collect);
|
||||
const n3 = buildInvocationNode(position, sub);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(collect);
|
||||
const n3 = buildNode(sub);
|
||||
const nodes = [n1, n2, n3];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
|
||||
const edges = [e1];
|
||||
@ -124,9 +123,9 @@ describe(validateConnection.name, () => {
|
||||
});
|
||||
|
||||
it('should reject connections to target field that is already connected', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, add);
|
||||
const n3 = buildInvocationNode(position, add);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const n3 = buildNode(add);
|
||||
const nodes = [n1, n2, n3];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
@ -136,9 +135,9 @@ describe(validateConnection.name, () => {
|
||||
});
|
||||
|
||||
it('should accept connections to target field that is already connected (ignored edge)', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, add);
|
||||
const n3 = buildInvocationNode(position, add);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const n3 = buildNode(add);
|
||||
const nodes = [n1, n2, n3];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
@ -148,8 +147,8 @@ describe(validateConnection.name, () => {
|
||||
});
|
||||
|
||||
it('should reject connections between invalid types', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, img_resize);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||
const r = validateConnection(c, nodes, [], templates, null);
|
||||
@ -157,8 +156,8 @@ describe(validateConnection.name, () => {
|
||||
});
|
||||
|
||||
it('should reject connections that would create cycles', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, sub);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const nodes = [n1, n2];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
@ -174,8 +173,8 @@ describe(validateConnection.name, () => {
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
|
||||
});
|
||||
it('should reject connections that create cycles in non-strict mode', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, sub);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const nodes = [n1, n2];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
@ -184,8 +183,8 @@ describe(validateConnection.name, () => {
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
});
|
||||
it('should otherwise allow invalid connections in non-strict mode', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, img_resize);
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||
const r = validateConnection(c, nodes, [], templates, null, false);
|
||||
|
Loading…
Reference in New Issue
Block a user