mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests(ui): finish test cases for validateConnection
This commit is contained in:
parent
3fcb2720d7
commit
04a596179b
@ -298,7 +298,149 @@ export const main_model_loader: InvocationTemplate = {
|
||||
useCache: true,
|
||||
nodePack: 'invokeai',
|
||||
classification: 'stable',
|
||||
}
|
||||
};
|
||||
|
||||
export const img_resize: InvocationTemplate = {
|
||||
title: 'Resize Image',
|
||||
type: 'img_resize',
|
||||
version: '1.2.2',
|
||||
tags: ['image', 'resize'],
|
||||
description: 'Resizes an image to specific dimensions',
|
||||
outputType: 'image_output',
|
||||
inputs: {
|
||||
board: {
|
||||
name: 'board',
|
||||
title: 'Board',
|
||||
required: false,
|
||||
description: 'The board to save the image to',
|
||||
fieldKind: 'input',
|
||||
input: 'direct',
|
||||
ui_hidden: false,
|
||||
type: {
|
||||
name: 'BoardField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
},
|
||||
metadata: {
|
||||
name: 'metadata',
|
||||
title: 'Metadata',
|
||||
required: false,
|
||||
description: 'Optional metadata to be saved with the image',
|
||||
fieldKind: 'input',
|
||||
input: 'connection',
|
||||
ui_hidden: false,
|
||||
type: {
|
||||
name: 'MetadataField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
},
|
||||
image: {
|
||||
name: 'image',
|
||||
title: 'Image',
|
||||
required: true,
|
||||
description: 'The image to resize',
|
||||
fieldKind: 'input',
|
||||
input: 'any',
|
||||
ui_hidden: false,
|
||||
type: {
|
||||
name: 'ImageField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
},
|
||||
width: {
|
||||
name: 'width',
|
||||
title: 'Width',
|
||||
required: false,
|
||||
description: 'The width to resize to (px)',
|
||||
fieldKind: 'input',
|
||||
input: 'any',
|
||||
ui_hidden: false,
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
default: 512,
|
||||
exclusiveMinimum: 0,
|
||||
},
|
||||
height: {
|
||||
name: 'height',
|
||||
title: 'Height',
|
||||
required: false,
|
||||
description: 'The height to resize to (px)',
|
||||
fieldKind: 'input',
|
||||
input: 'any',
|
||||
ui_hidden: false,
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
default: 512,
|
||||
exclusiveMinimum: 0,
|
||||
},
|
||||
resample_mode: {
|
||||
name: 'resample_mode',
|
||||
title: 'Resample Mode',
|
||||
required: false,
|
||||
description: 'The resampling mode',
|
||||
fieldKind: 'input',
|
||||
input: 'any',
|
||||
ui_hidden: false,
|
||||
type: {
|
||||
name: 'EnumField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'],
|
||||
default: 'bicubic',
|
||||
},
|
||||
},
|
||||
outputs: {
|
||||
image: {
|
||||
fieldKind: 'output',
|
||||
name: 'image',
|
||||
title: 'Image',
|
||||
description: 'The output image',
|
||||
type: {
|
||||
name: 'ImageField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
width: {
|
||||
fieldKind: 'output',
|
||||
name: 'width',
|
||||
title: 'Width',
|
||||
description: 'The width of the image in pixels',
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
height: {
|
||||
fieldKind: 'output',
|
||||
name: 'height',
|
||||
title: 'Height',
|
||||
description: 'The height of the image in pixels',
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
},
|
||||
useCache: true,
|
||||
nodePack: 'invokeai',
|
||||
classification: 'stable',
|
||||
};
|
||||
|
||||
export const templates: Templates = {
|
||||
add,
|
||||
@ -306,6 +448,7 @@ export const templates: Templates = {
|
||||
collect,
|
||||
scheduler,
|
||||
main_model_loader,
|
||||
img_resize,
|
||||
};
|
||||
|
||||
export const schema = {
|
||||
@ -1068,6 +1211,205 @@ export const schema = {
|
||||
},
|
||||
class: 'invocation',
|
||||
},
|
||||
ImageResizeInvocation: {
|
||||
properties: {
|
||||
board: {
|
||||
anyOf: [
|
||||
{
|
||||
$ref: '#/components/schemas/BoardField',
|
||||
},
|
||||
{
|
||||
type: 'null',
|
||||
},
|
||||
],
|
||||
description: 'The board to save the image to',
|
||||
field_kind: 'internal',
|
||||
input: 'direct',
|
||||
orig_required: false,
|
||||
ui_hidden: false,
|
||||
},
|
||||
metadata: {
|
||||
anyOf: [
|
||||
{
|
||||
$ref: '#/components/schemas/MetadataField',
|
||||
},
|
||||
{
|
||||
type: 'null',
|
||||
},
|
||||
],
|
||||
description: 'Optional metadata to be saved with the image',
|
||||
field_kind: 'internal',
|
||||
input: 'connection',
|
||||
orig_required: false,
|
||||
ui_hidden: false,
|
||||
},
|
||||
id: {
|
||||
type: 'string',
|
||||
title: 'Id',
|
||||
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
|
||||
field_kind: 'node_attribute',
|
||||
},
|
||||
is_intermediate: {
|
||||
type: 'boolean',
|
||||
title: 'Is Intermediate',
|
||||
description: 'Whether or not this is an intermediate invocation.',
|
||||
default: false,
|
||||
field_kind: 'node_attribute',
|
||||
ui_type: 'IsIntermediate',
|
||||
},
|
||||
use_cache: {
|
||||
type: 'boolean',
|
||||
title: 'Use Cache',
|
||||
description: 'Whether or not to use the cache',
|
||||
default: true,
|
||||
field_kind: 'node_attribute',
|
||||
},
|
||||
image: {
|
||||
allOf: [
|
||||
{
|
||||
$ref: '#/components/schemas/ImageField',
|
||||
},
|
||||
],
|
||||
description: 'The image to resize',
|
||||
field_kind: 'input',
|
||||
input: 'any',
|
||||
orig_required: true,
|
||||
ui_hidden: false,
|
||||
},
|
||||
width: {
|
||||
type: 'integer',
|
||||
exclusiveMinimum: 0,
|
||||
title: 'Width',
|
||||
description: 'The width to resize to (px)',
|
||||
default: 512,
|
||||
field_kind: 'input',
|
||||
input: 'any',
|
||||
orig_default: 512,
|
||||
orig_required: false,
|
||||
ui_hidden: false,
|
||||
},
|
||||
height: {
|
||||
type: 'integer',
|
||||
exclusiveMinimum: 0,
|
||||
title: 'Height',
|
||||
description: 'The height to resize to (px)',
|
||||
default: 512,
|
||||
field_kind: 'input',
|
||||
input: 'any',
|
||||
orig_default: 512,
|
||||
orig_required: false,
|
||||
ui_hidden: false,
|
||||
},
|
||||
resample_mode: {
|
||||
type: 'string',
|
||||
enum: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'],
|
||||
title: 'Resample Mode',
|
||||
description: 'The resampling mode',
|
||||
default: 'bicubic',
|
||||
field_kind: 'input',
|
||||
input: 'any',
|
||||
orig_default: 'bicubic',
|
||||
orig_required: false,
|
||||
ui_hidden: false,
|
||||
},
|
||||
type: {
|
||||
type: 'string',
|
||||
enum: ['img_resize'],
|
||||
const: 'img_resize',
|
||||
title: 'type',
|
||||
default: 'img_resize',
|
||||
field_kind: 'node_attribute',
|
||||
},
|
||||
},
|
||||
type: 'object',
|
||||
required: ['type', 'id'],
|
||||
title: 'Resize Image',
|
||||
description: 'Resizes an image to specific dimensions',
|
||||
category: 'image',
|
||||
classification: 'stable',
|
||||
node_pack: 'invokeai',
|
||||
tags: ['image', 'resize'],
|
||||
version: '1.2.2',
|
||||
output: {
|
||||
$ref: '#/components/schemas/ImageOutput',
|
||||
},
|
||||
class: 'invocation',
|
||||
},
|
||||
ImageField: {
|
||||
description: 'An image primitive field',
|
||||
properties: {
|
||||
image_name: {
|
||||
description: 'The name of the image',
|
||||
title: 'Image Name',
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
required: ['image_name'],
|
||||
title: 'ImageField',
|
||||
type: 'object',
|
||||
class: 'output',
|
||||
},
|
||||
ImageOutput: {
|
||||
description: 'Base class for nodes that output a single image',
|
||||
properties: {
|
||||
image: {
|
||||
allOf: [
|
||||
{
|
||||
$ref: '#/components/schemas/ImageField',
|
||||
},
|
||||
],
|
||||
description: 'The output image',
|
||||
field_kind: 'output',
|
||||
ui_hidden: false,
|
||||
},
|
||||
width: {
|
||||
description: 'The width of the image in pixels',
|
||||
field_kind: 'output',
|
||||
title: 'Width',
|
||||
type: 'integer',
|
||||
ui_hidden: false,
|
||||
},
|
||||
height: {
|
||||
description: 'The height of the image in pixels',
|
||||
field_kind: 'output',
|
||||
title: 'Height',
|
||||
type: 'integer',
|
||||
ui_hidden: false,
|
||||
},
|
||||
type: {
|
||||
const: 'image_output',
|
||||
default: 'image_output',
|
||||
enum: ['image_output'],
|
||||
field_kind: 'node_attribute',
|
||||
title: 'type',
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
required: ['image', 'width', 'height', 'type', 'type'],
|
||||
title: 'ImageOutput',
|
||||
type: 'object',
|
||||
class: 'output',
|
||||
},
|
||||
MetadataField: {
|
||||
description:
|
||||
'Pydantic model for metadata with custom root of type dict[str, Any].\nMetadata is stored without a strict schema.',
|
||||
title: 'MetadataField',
|
||||
type: 'object',
|
||||
class: 'output',
|
||||
},
|
||||
BoardField: {
|
||||
properties: {
|
||||
board_id: {
|
||||
type: 'string',
|
||||
title: 'Board Id',
|
||||
description: 'The id of the board',
|
||||
},
|
||||
},
|
||||
type: 'object',
|
||||
required: ['board_id'],
|
||||
title: 'BoardField',
|
||||
description: 'A board primitive field',
|
||||
},
|
||||
},
|
||||
},
|
||||
} as OpenAPIV3_1.Document;
|
||||
|
@ -3,7 +3,7 @@ import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNod
|
||||
import { set } from 'lodash-es';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils';
|
||||
import { add, buildEdge, collect, img_resize, main_model_loader, position, sub, templates } from './testUtils';
|
||||
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
|
||||
|
||||
describe(validateConnection.name, () => {
|
||||
@ -146,4 +146,24 @@ describe(validateConnection.name, () => {
|
||||
const r = validateConnection(c, nodes, edges, templates, e1);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
});
|
||||
|
||||
it('should reject connections between invalid types', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||
const r = validateConnection(c, nodes, [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch'));
|
||||
});
|
||||
|
||||
it('should reject connections that would create cycles', () => {
|
||||
const n1 = buildInvocationNode(position, add);
|
||||
const n2 = buildInvocationNode(position, sub);
|
||||
const nodes = [n1, n2];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
});
|
||||
});
|
||||
|
@ -1,6 +1,8 @@
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
@ -36,6 +38,12 @@ const getEqualityPredicate =
|
||||
);
|
||||
};
|
||||
|
||||
const getTargetEqualityPredicate =
|
||||
(c: Connection) =>
|
||||
(e: Edge): boolean => {
|
||||
return e.target === c.target && e.targetHandle === c.targetHandle;
|
||||
};
|
||||
|
||||
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
|
||||
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
|
||||
|
||||
@ -44,6 +52,12 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
|
||||
return buildRejectResult('nodes.cannotConnectToSelf');
|
||||
}
|
||||
|
||||
/**
|
||||
* We may need to ignore an edge when validating a connection.
|
||||
*
|
||||
* For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection,
|
||||
* the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it.
|
||||
*/
|
||||
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
|
||||
|
||||
if (filteredEdges.some(getEqualityPredicate(c))) {
|
||||
@ -96,14 +110,20 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((e) => {
|
||||
return e.target === c.target && e.targetHandle === c.targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
filteredEdges.find(getTargetEqualityPredicate(c)) &&
|
||||
// except CollectionItem inputs can have multiple input connections
|
||||
targetFieldTemplate.type.name !== 'CollectionItemField'
|
||||
) {
|
||||
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
|
||||
}
|
||||
|
||||
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return buildRejectResult('nodes.fieldTypesMustMatch');
|
||||
}
|
||||
|
||||
if (getHasCycles(c.source, c.target, nodes, edges)) {
|
||||
return buildRejectResult('nodes.connectionWouldCreateCycle');
|
||||
}
|
||||
|
||||
return buildAcceptResult();
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user