tests(ui): finish test cases for validateConnection

This commit is contained in:
psychedelicious 2024-05-19 00:25:58 +10:00
parent 3fcb2720d7
commit 04a596179b
3 changed files with 388 additions and 6 deletions

View File

@ -298,7 +298,149 @@ export const main_model_loader: InvocationTemplate = {
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
}
};
export const img_resize: InvocationTemplate = {
title: 'Resize Image',
type: 'img_resize',
version: '1.2.2',
tags: ['image', 'resize'],
description: 'Resizes an image to specific dimensions',
outputType: 'image_output',
inputs: {
board: {
name: 'board',
title: 'Board',
required: false,
description: 'The board to save the image to',
fieldKind: 'input',
input: 'direct',
ui_hidden: false,
type: {
name: 'BoardField',
isCollection: false,
isCollectionOrScalar: false,
},
},
metadata: {
name: 'metadata',
title: 'Metadata',
required: false,
description: 'Optional metadata to be saved with the image',
fieldKind: 'input',
input: 'connection',
ui_hidden: false,
type: {
name: 'MetadataField',
isCollection: false,
isCollectionOrScalar: false,
},
},
image: {
name: 'image',
title: 'Image',
required: true,
description: 'The image to resize',
fieldKind: 'input',
input: 'any',
ui_hidden: false,
type: {
name: 'ImageField',
isCollection: false,
isCollectionOrScalar: false,
},
},
width: {
name: 'width',
title: 'Width',
required: false,
description: 'The width to resize to (px)',
fieldKind: 'input',
input: 'any',
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
default: 512,
exclusiveMinimum: 0,
},
height: {
name: 'height',
title: 'Height',
required: false,
description: 'The height to resize to (px)',
fieldKind: 'input',
input: 'any',
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
default: 512,
exclusiveMinimum: 0,
},
resample_mode: {
name: 'resample_mode',
title: 'Resample Mode',
required: false,
description: 'The resampling mode',
fieldKind: 'input',
input: 'any',
ui_hidden: false,
type: {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
},
options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'],
default: 'bicubic',
},
},
outputs: {
image: {
fieldKind: 'output',
name: 'image',
title: 'Image',
description: 'The output image',
type: {
name: 'ImageField',
isCollection: false,
isCollectionOrScalar: false,
},
ui_hidden: false,
},
width: {
fieldKind: 'output',
name: 'width',
title: 'Width',
description: 'The width of the image in pixels',
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
ui_hidden: false,
},
height: {
fieldKind: 'output',
name: 'height',
title: 'Height',
description: 'The height of the image in pixels',
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
ui_hidden: false,
},
},
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
};
export const templates: Templates = {
add,
@ -306,6 +448,7 @@ export const templates: Templates = {
collect,
scheduler,
main_model_loader,
img_resize,
};
export const schema = {
@ -1068,6 +1211,205 @@ export const schema = {
},
class: 'invocation',
},
ImageResizeInvocation: {
properties: {
board: {
anyOf: [
{
$ref: '#/components/schemas/BoardField',
},
{
type: 'null',
},
],
description: 'The board to save the image to',
field_kind: 'internal',
input: 'direct',
orig_required: false,
ui_hidden: false,
},
metadata: {
anyOf: [
{
$ref: '#/components/schemas/MetadataField',
},
{
type: 'null',
},
],
description: 'Optional metadata to be saved with the image',
field_kind: 'internal',
input: 'connection',
orig_required: false,
ui_hidden: false,
},
id: {
type: 'string',
title: 'Id',
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
field_kind: 'node_attribute',
},
is_intermediate: {
type: 'boolean',
title: 'Is Intermediate',
description: 'Whether or not this is an intermediate invocation.',
default: false,
field_kind: 'node_attribute',
ui_type: 'IsIntermediate',
},
use_cache: {
type: 'boolean',
title: 'Use Cache',
description: 'Whether or not to use the cache',
default: true,
field_kind: 'node_attribute',
},
image: {
allOf: [
{
$ref: '#/components/schemas/ImageField',
},
],
description: 'The image to resize',
field_kind: 'input',
input: 'any',
orig_required: true,
ui_hidden: false,
},
width: {
type: 'integer',
exclusiveMinimum: 0,
title: 'Width',
description: 'The width to resize to (px)',
default: 512,
field_kind: 'input',
input: 'any',
orig_default: 512,
orig_required: false,
ui_hidden: false,
},
height: {
type: 'integer',
exclusiveMinimum: 0,
title: 'Height',
description: 'The height to resize to (px)',
default: 512,
field_kind: 'input',
input: 'any',
orig_default: 512,
orig_required: false,
ui_hidden: false,
},
resample_mode: {
type: 'string',
enum: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'],
title: 'Resample Mode',
description: 'The resampling mode',
default: 'bicubic',
field_kind: 'input',
input: 'any',
orig_default: 'bicubic',
orig_required: false,
ui_hidden: false,
},
type: {
type: 'string',
enum: ['img_resize'],
const: 'img_resize',
title: 'type',
default: 'img_resize',
field_kind: 'node_attribute',
},
},
type: 'object',
required: ['type', 'id'],
title: 'Resize Image',
description: 'Resizes an image to specific dimensions',
category: 'image',
classification: 'stable',
node_pack: 'invokeai',
tags: ['image', 'resize'],
version: '1.2.2',
output: {
$ref: '#/components/schemas/ImageOutput',
},
class: 'invocation',
},
ImageField: {
description: 'An image primitive field',
properties: {
image_name: {
description: 'The name of the image',
title: 'Image Name',
type: 'string',
},
},
required: ['image_name'],
title: 'ImageField',
type: 'object',
class: 'output',
},
ImageOutput: {
description: 'Base class for nodes that output a single image',
properties: {
image: {
allOf: [
{
$ref: '#/components/schemas/ImageField',
},
],
description: 'The output image',
field_kind: 'output',
ui_hidden: false,
},
width: {
description: 'The width of the image in pixels',
field_kind: 'output',
title: 'Width',
type: 'integer',
ui_hidden: false,
},
height: {
description: 'The height of the image in pixels',
field_kind: 'output',
title: 'Height',
type: 'integer',
ui_hidden: false,
},
type: {
const: 'image_output',
default: 'image_output',
enum: ['image_output'],
field_kind: 'node_attribute',
title: 'type',
type: 'string',
},
},
required: ['image', 'width', 'height', 'type', 'type'],
title: 'ImageOutput',
type: 'object',
class: 'output',
},
MetadataField: {
description:
'Pydantic model for metadata with custom root of type dict[str, Any].\nMetadata is stored without a strict schema.',
title: 'MetadataField',
type: 'object',
class: 'output',
},
BoardField: {
properties: {
board_id: {
type: 'string',
title: 'Board Id',
description: 'The id of the board',
},
},
type: 'object',
required: ['board_id'],
title: 'BoardField',
description: 'A board primitive field',
},
},
},
} as OpenAPIV3_1.Document;

View File

@ -3,7 +3,7 @@ import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNod
import { set } from 'lodash-es';
import { describe, expect, it } from 'vitest';
import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils';
import { add, buildEdge, collect, img_resize, main_model_loader, position, sub, templates } from './testUtils';
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
describe(validateConnection.name, () => {
@ -146,4 +146,24 @@ describe(validateConnection.name, () => {
const r = validateConnection(c, nodes, edges, templates, e1);
expect(r).toEqual(buildAcceptResult());
});
it('should reject connections between invalid types', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, img_resize);
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch'));
});
it('should reject connections that would create cycles', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const nodes = [n1, n2];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
});
});

View File

@ -1,6 +1,8 @@
import type { Templates } from 'features/nodes/store/types';
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import type { Connection as NullableConnection, Edge } from 'reactflow';
import type { O } from 'ts-toolbelt';
@ -36,6 +38,12 @@ const getEqualityPredicate =
);
};
const getTargetEqualityPredicate =
(c: Connection) =>
(e: Edge): boolean => {
return e.target === c.target && e.targetHandle === c.targetHandle;
};
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
@ -44,6 +52,12 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
return buildRejectResult('nodes.cannotConnectToSelf');
}
/**
* We may need to ignore an edge when validating a connection.
*
* For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection,
* the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it.
*/
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
if (filteredEdges.some(getEqualityPredicate(c))) {
@ -96,14 +110,20 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
}
if (
edges.find((e) => {
return e.target === c.target && e.targetHandle === c.targetHandle;
}) &&
// except CollectionItem inputs can have multiples
filteredEdges.find(getTargetEqualityPredicate(c)) &&
// except CollectionItem inputs can have multiple input connections
targetFieldTemplate.type.name !== 'CollectionItemField'
) {
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
}
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
return buildRejectResult('nodes.fieldTypesMustMatch');
}
if (getHasCycles(c.source, c.target, nodes, edges)) {
return buildRejectResult('nodes.connectionWouldCreateCycle');
}
return buildAcceptResult();
};