mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests(ui): finish test cases for validateConnection
This commit is contained in:
parent
3fcb2720d7
commit
04a596179b
@ -298,7 +298,149 @@ export const main_model_loader: InvocationTemplate = {
|
|||||||
useCache: true,
|
useCache: true,
|
||||||
nodePack: 'invokeai',
|
nodePack: 'invokeai',
|
||||||
classification: 'stable',
|
classification: 'stable',
|
||||||
}
|
};
|
||||||
|
|
||||||
|
export const img_resize: InvocationTemplate = {
|
||||||
|
title: 'Resize Image',
|
||||||
|
type: 'img_resize',
|
||||||
|
version: '1.2.2',
|
||||||
|
tags: ['image', 'resize'],
|
||||||
|
description: 'Resizes an image to specific dimensions',
|
||||||
|
outputType: 'image_output',
|
||||||
|
inputs: {
|
||||||
|
board: {
|
||||||
|
name: 'board',
|
||||||
|
title: 'Board',
|
||||||
|
required: false,
|
||||||
|
description: 'The board to save the image to',
|
||||||
|
fieldKind: 'input',
|
||||||
|
input: 'direct',
|
||||||
|
ui_hidden: false,
|
||||||
|
type: {
|
||||||
|
name: 'BoardField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
metadata: {
|
||||||
|
name: 'metadata',
|
||||||
|
title: 'Metadata',
|
||||||
|
required: false,
|
||||||
|
description: 'Optional metadata to be saved with the image',
|
||||||
|
fieldKind: 'input',
|
||||||
|
input: 'connection',
|
||||||
|
ui_hidden: false,
|
||||||
|
type: {
|
||||||
|
name: 'MetadataField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
image: {
|
||||||
|
name: 'image',
|
||||||
|
title: 'Image',
|
||||||
|
required: true,
|
||||||
|
description: 'The image to resize',
|
||||||
|
fieldKind: 'input',
|
||||||
|
input: 'any',
|
||||||
|
ui_hidden: false,
|
||||||
|
type: {
|
||||||
|
name: 'ImageField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
width: {
|
||||||
|
name: 'width',
|
||||||
|
title: 'Width',
|
||||||
|
required: false,
|
||||||
|
description: 'The width to resize to (px)',
|
||||||
|
fieldKind: 'input',
|
||||||
|
input: 'any',
|
||||||
|
ui_hidden: false,
|
||||||
|
type: {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
default: 512,
|
||||||
|
exclusiveMinimum: 0,
|
||||||
|
},
|
||||||
|
height: {
|
||||||
|
name: 'height',
|
||||||
|
title: 'Height',
|
||||||
|
required: false,
|
||||||
|
description: 'The height to resize to (px)',
|
||||||
|
fieldKind: 'input',
|
||||||
|
input: 'any',
|
||||||
|
ui_hidden: false,
|
||||||
|
type: {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
default: 512,
|
||||||
|
exclusiveMinimum: 0,
|
||||||
|
},
|
||||||
|
resample_mode: {
|
||||||
|
name: 'resample_mode',
|
||||||
|
title: 'Resample Mode',
|
||||||
|
required: false,
|
||||||
|
description: 'The resampling mode',
|
||||||
|
fieldKind: 'input',
|
||||||
|
input: 'any',
|
||||||
|
ui_hidden: false,
|
||||||
|
type: {
|
||||||
|
name: 'EnumField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'],
|
||||||
|
default: 'bicubic',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
outputs: {
|
||||||
|
image: {
|
||||||
|
fieldKind: 'output',
|
||||||
|
name: 'image',
|
||||||
|
title: 'Image',
|
||||||
|
description: 'The output image',
|
||||||
|
type: {
|
||||||
|
name: 'ImageField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
width: {
|
||||||
|
fieldKind: 'output',
|
||||||
|
name: 'width',
|
||||||
|
title: 'Width',
|
||||||
|
description: 'The width of the image in pixels',
|
||||||
|
type: {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
height: {
|
||||||
|
fieldKind: 'output',
|
||||||
|
name: 'height',
|
||||||
|
title: 'Height',
|
||||||
|
description: 'The height of the image in pixels',
|
||||||
|
type: {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
useCache: true,
|
||||||
|
nodePack: 'invokeai',
|
||||||
|
classification: 'stable',
|
||||||
|
};
|
||||||
|
|
||||||
export const templates: Templates = {
|
export const templates: Templates = {
|
||||||
add,
|
add,
|
||||||
@ -306,6 +448,7 @@ export const templates: Templates = {
|
|||||||
collect,
|
collect,
|
||||||
scheduler,
|
scheduler,
|
||||||
main_model_loader,
|
main_model_loader,
|
||||||
|
img_resize,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const schema = {
|
export const schema = {
|
||||||
@ -1068,6 +1211,205 @@ export const schema = {
|
|||||||
},
|
},
|
||||||
class: 'invocation',
|
class: 'invocation',
|
||||||
},
|
},
|
||||||
|
ImageResizeInvocation: {
|
||||||
|
properties: {
|
||||||
|
board: {
|
||||||
|
anyOf: [
|
||||||
|
{
|
||||||
|
$ref: '#/components/schemas/BoardField',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'null',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
description: 'The board to save the image to',
|
||||||
|
field_kind: 'internal',
|
||||||
|
input: 'direct',
|
||||||
|
orig_required: false,
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
metadata: {
|
||||||
|
anyOf: [
|
||||||
|
{
|
||||||
|
$ref: '#/components/schemas/MetadataField',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'null',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
description: 'Optional metadata to be saved with the image',
|
||||||
|
field_kind: 'internal',
|
||||||
|
input: 'connection',
|
||||||
|
orig_required: false,
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
id: {
|
||||||
|
type: 'string',
|
||||||
|
title: 'Id',
|
||||||
|
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
|
||||||
|
field_kind: 'node_attribute',
|
||||||
|
},
|
||||||
|
is_intermediate: {
|
||||||
|
type: 'boolean',
|
||||||
|
title: 'Is Intermediate',
|
||||||
|
description: 'Whether or not this is an intermediate invocation.',
|
||||||
|
default: false,
|
||||||
|
field_kind: 'node_attribute',
|
||||||
|
ui_type: 'IsIntermediate',
|
||||||
|
},
|
||||||
|
use_cache: {
|
||||||
|
type: 'boolean',
|
||||||
|
title: 'Use Cache',
|
||||||
|
description: 'Whether or not to use the cache',
|
||||||
|
default: true,
|
||||||
|
field_kind: 'node_attribute',
|
||||||
|
},
|
||||||
|
image: {
|
||||||
|
allOf: [
|
||||||
|
{
|
||||||
|
$ref: '#/components/schemas/ImageField',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
description: 'The image to resize',
|
||||||
|
field_kind: 'input',
|
||||||
|
input: 'any',
|
||||||
|
orig_required: true,
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
width: {
|
||||||
|
type: 'integer',
|
||||||
|
exclusiveMinimum: 0,
|
||||||
|
title: 'Width',
|
||||||
|
description: 'The width to resize to (px)',
|
||||||
|
default: 512,
|
||||||
|
field_kind: 'input',
|
||||||
|
input: 'any',
|
||||||
|
orig_default: 512,
|
||||||
|
orig_required: false,
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
height: {
|
||||||
|
type: 'integer',
|
||||||
|
exclusiveMinimum: 0,
|
||||||
|
title: 'Height',
|
||||||
|
description: 'The height to resize to (px)',
|
||||||
|
default: 512,
|
||||||
|
field_kind: 'input',
|
||||||
|
input: 'any',
|
||||||
|
orig_default: 512,
|
||||||
|
orig_required: false,
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
resample_mode: {
|
||||||
|
type: 'string',
|
||||||
|
enum: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'],
|
||||||
|
title: 'Resample Mode',
|
||||||
|
description: 'The resampling mode',
|
||||||
|
default: 'bicubic',
|
||||||
|
field_kind: 'input',
|
||||||
|
input: 'any',
|
||||||
|
orig_default: 'bicubic',
|
||||||
|
orig_required: false,
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
type: {
|
||||||
|
type: 'string',
|
||||||
|
enum: ['img_resize'],
|
||||||
|
const: 'img_resize',
|
||||||
|
title: 'type',
|
||||||
|
default: 'img_resize',
|
||||||
|
field_kind: 'node_attribute',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
type: 'object',
|
||||||
|
required: ['type', 'id'],
|
||||||
|
title: 'Resize Image',
|
||||||
|
description: 'Resizes an image to specific dimensions',
|
||||||
|
category: 'image',
|
||||||
|
classification: 'stable',
|
||||||
|
node_pack: 'invokeai',
|
||||||
|
tags: ['image', 'resize'],
|
||||||
|
version: '1.2.2',
|
||||||
|
output: {
|
||||||
|
$ref: '#/components/schemas/ImageOutput',
|
||||||
|
},
|
||||||
|
class: 'invocation',
|
||||||
|
},
|
||||||
|
ImageField: {
|
||||||
|
description: 'An image primitive field',
|
||||||
|
properties: {
|
||||||
|
image_name: {
|
||||||
|
description: 'The name of the image',
|
||||||
|
title: 'Image Name',
|
||||||
|
type: 'string',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required: ['image_name'],
|
||||||
|
title: 'ImageField',
|
||||||
|
type: 'object',
|
||||||
|
class: 'output',
|
||||||
|
},
|
||||||
|
ImageOutput: {
|
||||||
|
description: 'Base class for nodes that output a single image',
|
||||||
|
properties: {
|
||||||
|
image: {
|
||||||
|
allOf: [
|
||||||
|
{
|
||||||
|
$ref: '#/components/schemas/ImageField',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
description: 'The output image',
|
||||||
|
field_kind: 'output',
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
width: {
|
||||||
|
description: 'The width of the image in pixels',
|
||||||
|
field_kind: 'output',
|
||||||
|
title: 'Width',
|
||||||
|
type: 'integer',
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
height: {
|
||||||
|
description: 'The height of the image in pixels',
|
||||||
|
field_kind: 'output',
|
||||||
|
title: 'Height',
|
||||||
|
type: 'integer',
|
||||||
|
ui_hidden: false,
|
||||||
|
},
|
||||||
|
type: {
|
||||||
|
const: 'image_output',
|
||||||
|
default: 'image_output',
|
||||||
|
enum: ['image_output'],
|
||||||
|
field_kind: 'node_attribute',
|
||||||
|
title: 'type',
|
||||||
|
type: 'string',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required: ['image', 'width', 'height', 'type', 'type'],
|
||||||
|
title: 'ImageOutput',
|
||||||
|
type: 'object',
|
||||||
|
class: 'output',
|
||||||
|
},
|
||||||
|
MetadataField: {
|
||||||
|
description:
|
||||||
|
'Pydantic model for metadata with custom root of type dict[str, Any].\nMetadata is stored without a strict schema.',
|
||||||
|
title: 'MetadataField',
|
||||||
|
type: 'object',
|
||||||
|
class: 'output',
|
||||||
|
},
|
||||||
|
BoardField: {
|
||||||
|
properties: {
|
||||||
|
board_id: {
|
||||||
|
type: 'string',
|
||||||
|
title: 'Board Id',
|
||||||
|
description: 'The id of the board',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
type: 'object',
|
||||||
|
required: ['board_id'],
|
||||||
|
title: 'BoardField',
|
||||||
|
description: 'A board primitive field',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
} as OpenAPIV3_1.Document;
|
} as OpenAPIV3_1.Document;
|
||||||
|
@ -3,7 +3,7 @@ import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNod
|
|||||||
import { set } from 'lodash-es';
|
import { set } from 'lodash-es';
|
||||||
import { describe, expect, it } from 'vitest';
|
import { describe, expect, it } from 'vitest';
|
||||||
|
|
||||||
import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils';
|
import { add, buildEdge, collect, img_resize, main_model_loader, position, sub, templates } from './testUtils';
|
||||||
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
|
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
|
||||||
|
|
||||||
describe(validateConnection.name, () => {
|
describe(validateConnection.name, () => {
|
||||||
@ -146,4 +146,24 @@ describe(validateConnection.name, () => {
|
|||||||
const r = validateConnection(c, nodes, edges, templates, e1);
|
const r = validateConnection(c, nodes, edges, templates, e1);
|
||||||
expect(r).toEqual(buildAcceptResult());
|
expect(r).toEqual(buildAcceptResult());
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should reject connections between invalid types', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, img_resize);
|
||||||
|
const nodes = [n1, n2];
|
||||||
|
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||||
|
const r = validateConnection(c, nodes, [], templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch'));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject connections that would create cycles', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, sub);
|
||||||
|
const nodes = [n1, n2];
|
||||||
|
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||||
|
const edges = [e1];
|
||||||
|
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
|
||||||
|
const r = validateConnection(c, nodes, edges, templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import type { Templates } from 'features/nodes/store/types';
|
import type { Templates } from 'features/nodes/store/types';
|
||||||
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||||
|
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||||
|
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||||
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
||||||
import type { O } from 'ts-toolbelt';
|
import type { O } from 'ts-toolbelt';
|
||||||
@ -36,6 +38,12 @@ const getEqualityPredicate =
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const getTargetEqualityPredicate =
|
||||||
|
(c: Connection) =>
|
||||||
|
(e: Edge): boolean => {
|
||||||
|
return e.target === c.target && e.targetHandle === c.targetHandle;
|
||||||
|
};
|
||||||
|
|
||||||
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
|
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
|
||||||
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
|
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
|
||||||
|
|
||||||
@ -44,6 +52,12 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
|
|||||||
return buildRejectResult('nodes.cannotConnectToSelf');
|
return buildRejectResult('nodes.cannotConnectToSelf');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* We may need to ignore an edge when validating a connection.
|
||||||
|
*
|
||||||
|
* For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection,
|
||||||
|
* the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it.
|
||||||
|
*/
|
||||||
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
|
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
|
||||||
|
|
||||||
if (filteredEdges.some(getEqualityPredicate(c))) {
|
if (filteredEdges.some(getEqualityPredicate(c))) {
|
||||||
@ -96,14 +110,20 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
edges.find((e) => {
|
filteredEdges.find(getTargetEqualityPredicate(c)) &&
|
||||||
return e.target === c.target && e.targetHandle === c.targetHandle;
|
// except CollectionItem inputs can have multiple input connections
|
||||||
}) &&
|
|
||||||
// except CollectionItem inputs can have multiples
|
|
||||||
targetFieldTemplate.type.name !== 'CollectionItemField'
|
targetFieldTemplate.type.name !== 'CollectionItemField'
|
||||||
) {
|
) {
|
||||||
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
|
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||||
|
return buildRejectResult('nodes.fieldTypesMustMatch');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getHasCycles(c.source, c.target, nodes, edges)) {
|
||||||
|
return buildRejectResult('nodes.connectionWouldCreateCycle');
|
||||||
|
}
|
||||||
|
|
||||||
return buildAcceptResult();
|
return buildAcceptResult();
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user