diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index efde3336e2..b68ff8bef6 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -298,7 +298,149 @@ export const main_model_loader: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', -} +}; + +export const img_resize: InvocationTemplate = { + title: 'Resize Image', + type: 'img_resize', + version: '1.2.2', + tags: ['image', 'resize'], + description: 'Resizes an image to specific dimensions', + outputType: 'image_output', + inputs: { + board: { + name: 'board', + title: 'Board', + required: false, + description: 'The board to save the image to', + fieldKind: 'input', + input: 'direct', + ui_hidden: false, + type: { + name: 'BoardField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + metadata: { + name: 'metadata', + title: 'Metadata', + required: false, + description: 'Optional metadata to be saved with the image', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + type: { + name: 'MetadataField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + image: { + name: 'image', + title: 'Image', + required: true, + description: 'The image to resize', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'ImageField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + width: { + name: 'width', + title: 'Width', + required: false, + description: 'The width to resize to (px)', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 512, + exclusiveMinimum: 0, + }, + height: { + name: 'height', + title: 'Height', + required: false, + description: 'The height to resize to (px)', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 512, + exclusiveMinimum: 0, + }, + resample_mode: { + name: 'resample_mode', + title: 'Resample Mode', + required: false, + description: 'The resampling mode', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'EnumField', + isCollection: false, + isCollectionOrScalar: false, + }, + options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'], + default: 'bicubic', + }, + }, + outputs: { + image: { + fieldKind: 'output', + name: 'image', + title: 'Image', + description: 'The output image', + type: { + name: 'ImageField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + width: { + fieldKind: 'output', + name: 'width', + title: 'Width', + description: 'The width of the image in pixels', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + height: { + fieldKind: 'output', + name: 'height', + title: 'Height', + description: 'The height of the image in pixels', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; export const templates: Templates = { add, @@ -306,6 +448,7 @@ export const templates: Templates = { collect, scheduler, main_model_loader, + img_resize, }; export const schema = { @@ -1068,6 +1211,205 @@ export const schema = { }, class: 'invocation', }, + ImageResizeInvocation: { + properties: { + board: { + anyOf: [ + { + $ref: '#/components/schemas/BoardField', + }, + { + type: 'null', + }, + ], + description: 'The board to save the image to', + field_kind: 'internal', + input: 'direct', + orig_required: false, + ui_hidden: false, + }, + metadata: { + anyOf: [ + { + $ref: '#/components/schemas/MetadataField', + }, + { + type: 'null', + }, + ], + description: 'Optional metadata to be saved with the image', + field_kind: 'internal', + input: 'connection', + orig_required: false, + ui_hidden: false, + }, + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + image: { + allOf: [ + { + $ref: '#/components/schemas/ImageField', + }, + ], + description: 'The image to resize', + field_kind: 'input', + input: 'any', + orig_required: true, + ui_hidden: false, + }, + width: { + type: 'integer', + exclusiveMinimum: 0, + title: 'Width', + description: 'The width to resize to (px)', + default: 512, + field_kind: 'input', + input: 'any', + orig_default: 512, + orig_required: false, + ui_hidden: false, + }, + height: { + type: 'integer', + exclusiveMinimum: 0, + title: 'Height', + description: 'The height to resize to (px)', + default: 512, + field_kind: 'input', + input: 'any', + orig_default: 512, + orig_required: false, + ui_hidden: false, + }, + resample_mode: { + type: 'string', + enum: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'], + title: 'Resample Mode', + description: 'The resampling mode', + default: 'bicubic', + field_kind: 'input', + input: 'any', + orig_default: 'bicubic', + orig_required: false, + ui_hidden: false, + }, + type: { + type: 'string', + enum: ['img_resize'], + const: 'img_resize', + title: 'type', + default: 'img_resize', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Resize Image', + description: 'Resizes an image to specific dimensions', + category: 'image', + classification: 'stable', + node_pack: 'invokeai', + tags: ['image', 'resize'], + version: '1.2.2', + output: { + $ref: '#/components/schemas/ImageOutput', + }, + class: 'invocation', + }, + ImageField: { + description: 'An image primitive field', + properties: { + image_name: { + description: 'The name of the image', + title: 'Image Name', + type: 'string', + }, + }, + required: ['image_name'], + title: 'ImageField', + type: 'object', + class: 'output', + }, + ImageOutput: { + description: 'Base class for nodes that output a single image', + properties: { + image: { + allOf: [ + { + $ref: '#/components/schemas/ImageField', + }, + ], + description: 'The output image', + field_kind: 'output', + ui_hidden: false, + }, + width: { + description: 'The width of the image in pixels', + field_kind: 'output', + title: 'Width', + type: 'integer', + ui_hidden: false, + }, + height: { + description: 'The height of the image in pixels', + field_kind: 'output', + title: 'Height', + type: 'integer', + ui_hidden: false, + }, + type: { + const: 'image_output', + default: 'image_output', + enum: ['image_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['image', 'width', 'height', 'type', 'type'], + title: 'ImageOutput', + type: 'object', + class: 'output', + }, + MetadataField: { + description: + 'Pydantic model for metadata with custom root of type dict[str, Any].\nMetadata is stored without a strict schema.', + title: 'MetadataField', + type: 'object', + class: 'output', + }, + BoardField: { + properties: { + board_id: { + type: 'string', + title: 'Board Id', + description: 'The id of the board', + }, + }, + type: 'object', + required: ['board_id'], + title: 'BoardField', + description: 'A board primitive field', + }, }, }, } as OpenAPIV3_1.Document; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 5d10ef368b..cf05b4deb6 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -3,7 +3,7 @@ import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNod import { set } from 'lodash-es'; import { describe, expect, it } from 'vitest'; -import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils'; +import { add, buildEdge, collect, img_resize, main_model_loader, position, sub, templates } from './testUtils'; import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection'; describe(validateConnection.name, () => { @@ -146,4 +146,24 @@ describe(validateConnection.name, () => { const r = validateConnection(c, nodes, edges, templates, e1); expect(r).toEqual(buildAcceptResult()); }); + + it('should reject connections between invalid types', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, img_resize); + const nodes = [n1, n2]; + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch')); + }); + + it('should reject connections that would create cycles', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const nodes = [n1, n2]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index d45a75ab9f..db8b7b737e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -1,6 +1,8 @@ import type { Templates } from 'features/nodes/store/types'; import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode } from 'features/nodes/types/invocation'; import type { Connection as NullableConnection, Edge } from 'reactflow'; import type { O } from 'ts-toolbelt'; @@ -36,6 +38,12 @@ const getEqualityPredicate = ); }; +const getTargetEqualityPredicate = + (c: Connection) => + (e: Edge): boolean => { + return e.target === c.target && e.targetHandle === c.targetHandle; + }; + export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true }); export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey }); @@ -44,6 +52,12 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp return buildRejectResult('nodes.cannotConnectToSelf'); } + /** + * We may need to ignore an edge when validating a connection. + * + * For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection, + * the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it. + */ const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); if (filteredEdges.some(getEqualityPredicate(c))) { @@ -96,14 +110,20 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp } if ( - edges.find((e) => { - return e.target === c.target && e.targetHandle === c.targetHandle; - }) && - // except CollectionItem inputs can have multiples + filteredEdges.find(getTargetEqualityPredicate(c)) && + // except CollectionItem inputs can have multiple input connections targetFieldTemplate.type.name !== 'CollectionItemField' ) { return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); } + if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { + return buildRejectResult('nodes.fieldTypesMustMatch'); + } + + if (getHasCycles(c.source, c.target, nodes, edges)) { + return buildRejectResult('nodes.connectionWouldCreateCycle'); + } + return buildAcceptResult(); };