diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 7de7a8e01c..1f44e641fc 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -775,6 +775,9 @@ "cannotConnectToSelf": "Cannot connect to self", "cannotDuplicateConnection": "Cannot create duplicate connections", "cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types", + "missingNode": "Missing invocation node", + "missingInvocationTemplate": "Missing invocation template", + "missingFieldTemplate": "Missing field template", "nodePack": "Node pack", "collection": "Collection", "collectionFieldType": "{{name}} Collection", diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 214fc069f9..14d69b4720 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -17,7 +17,8 @@ import { nodeAdded, openAddNodePopover, } from 'features/nodes/store/nodesSlice'; -import { getFirstValidConnection, validateSourceAndTargetTypes } from 'features/nodes/store/util/connectionValidation'; +import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { filter, map, memoize, some } from 'lodash-es'; @@ -77,7 +78,7 @@ const AddNodePopover = () => { return some(fields, (field) => { const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type; const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type; - return validateSourceAndTargetTypes(sourceType, targetType); + return validateConnectionTypes(sourceType, targetType); }); }); }, [templates, pendingConnection]); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index f0dba67bf5..0190a0b29e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -8,7 +8,7 @@ import { $templates, connectionMade, } from 'features/nodes/store/nodesSlice'; -import { getFirstValidConnection } from 'features/nodes/store/util/connectionValidation'; +import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { isString } from 'lodash-es'; import { useCallback, useMemo } from 'react'; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index b92114bab2..77c4e3c75b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -2,12 +2,10 @@ import { useStore } from '@nanostores/react'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { $templates } from 'features/nodes/store/nodesSlice'; -import { - areTypesEqual, - getCollectItemType, - getHasCycles, - validateSourceAndTargetTypes, -} from 'features/nodes/store/util/connectionValidation'; +import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useCallback } from 'react'; import type { Connection, Node } from 'reactflow'; @@ -88,7 +86,7 @@ export const useIsValidConnection = () => { } // Must use the originalType here if it exists - if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { + if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts new file mode 100644 index 0000000000..7be307d07e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts @@ -0,0 +1,101 @@ +import { describe, expect, it } from 'vitest'; + +import { areTypesEqual } from './areTypesEqual'; + +describe(areTypesEqual.name, () => { + it('should handle equal source and target type', () => { + const sourceType = { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'Foo', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + const targetType = { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'Bar', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); + + it('should handle equal source type and original target type', () => { + const sourceType = { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'Foo', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + const targetType = { + name: 'Bar', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); + + it('should handle equal original source type and target type', () => { + const sourceType = { + name: 'Foo', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + const targetType = { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'Bar', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); + + it('should handle equal original source type and original target type', () => { + const sourceType = { + name: 'Foo', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + const targetType = { + name: 'Bar', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts new file mode 100644 index 0000000000..e01b48b972 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts @@ -0,0 +1,30 @@ +import type { FieldType } from 'features/nodes/types/field'; +import { isEqual, omit } from 'lodash-es'; + +/** + * Checks if two types are equal. If the field types have original types, those are also compared. Any match is + * considered equal. For example, if the source type and original target type match, the types are considered equal. + * @param sourceType The type of the source field. + * @param targetType The type of the target field. + * @returns True if the types are equal, false otherwise. + */ + +export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => { + const _sourceType = 'originalType' in sourceType ? omit(sourceType, 'originalType') : sourceType; + const _targetType = 'originalType' in targetType ? omit(targetType, 'originalType') : targetType; + const _sourceTypeOriginal = 'originalType' in sourceType ? sourceType.originalType : null; + const _targetTypeOriginal = 'originalType' in targetType ? targetType.originalType : null; + if (isEqual(_sourceType, _targetType)) { + return true; + } + if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) { + return true; + } + if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) { + return true; + } + if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { + return true; + } + return false; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts index a2f723fcfe..7819221f8a 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts @@ -1,179 +1,16 @@ -import graphlib from '@dagrejs/graphlib'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import type { RootState } from 'app/store/store'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; -import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field'; -import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; +import type { FieldType } from 'features/nodes/types/field'; import i18n from 'i18next'; -import { differenceWith, isEqual, map, omit } from 'lodash-es'; -import type { Connection, Edge, HandleType, Node } from 'reactflow'; +import type { HandleType } from 'reactflow'; import { assert } from 'tsafe'; -/** - * Finds the first valid field for a pending connection between two nodes. - * @param templates The invocation templates - * @param nodes The current nodes - * @param edges The current edges - * @param pendingConnection The pending connection - * @param candidateNode The candidate node to which the connection is being made - * @param candidateTemplate The candidate template for the candidate node - * @returns The first valid connection, or null if no valid connection is found - */ -export const getFirstValidConnection = ( - templates: Templates, - nodes: AnyNode[], - edges: InvocationNodeEdge[], - pendingConnection: PendingConnection, - candidateNode: InvocationNode, - candidateTemplate: InvocationTemplate -): Connection | null => { - if (pendingConnection.node.id === candidateNode.id) { - // Cannot connect to self - return null; - } - - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - - if (pendingFieldKind === 'source') { - // Connecting from a source to a target - if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) { - return null; - } - if (candidateNode.data.type === 'collect') { - // Special handling for collect node - the `item` field takes any number of connections - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: 'item', - }; - } - // Only one connection per target field is allowed - look for an unconnected target field - const candidateFields = map(candidateTemplate.inputs); - const candidateConnectedFields = edges - .filter((edge) => edge.target === candidateNode.id) - .map((edge) => { - // Edges must always have a targetHandle, safe to assert here - assert(edge.targetHandle); - return edge.targetHandle; - }); - const candidateUnconnectedFields = differenceWith( - candidateFields, - candidateConnectedFields, - (field, connectedFieldName) => field.name === connectedFieldName - ); - const candidateField = candidateUnconnectedFields.find((field) => - validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type) - ); - if (candidateField) { - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: candidateField.name, - }; - } - } else { - // Connecting from a target to a source - // Ensure we there is not already an edge to the target, except for collect nodes - const isCollect = pendingConnection.node.data.type === 'collect'; - const isTargetAlreadyConnected = edges.some( - (e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name - ); - if (!isCollect && isTargetAlreadyConnected) { - return null; - } - - if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) { - return null; - } - - // Sources/outputs can have any number of edges, we can take the first matching output field - let candidateFields = map(candidateTemplate.outputs); - if (isCollect) { - // Narrow candidates to same field type as already is connected to the collect node - const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); - if (collectItemType) { - candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType)); - } - } - const candidateField = candidateFields.find((field) => { - const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type); - const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name); - return isValid && !isAlreadyConnected; - }); - if (candidateField) { - return { - source: candidateNode.id, - sourceHandle: candidateField.name, - target: pendingConnection.node.id, - targetHandle: pendingConnection.fieldTemplate.name, - }; - } - } - - return null; -}; - -/** - * Check if adding an edge between the source and target nodes would create a cycle in the graph. - * @param source The source node id - * @param target The target node id - * @param nodes The graph's current nodes - * @param edges The graph's current edges - * @returns True if the graph would be acyclic after adding the edge, false otherwise - */ -export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => { - // construct graphlib graph from editor state - const g = new graphlib.Graph(); - - nodes.forEach((n) => { - g.setNode(n.id); - }); - - edges.forEach((e) => { - g.setEdge(e.source, e.target); - }); - - // add the candidate edge - g.setEdge(source, target); - - // check if the graph is acyclic - return !graphlib.alg.isAcyclic(g); -}; - -/** - * Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and - * field connected to the collector's `item` input. The field type of that field is returned, else null if there is no - * input field. - * @param templates The current invocation templates - * @param nodes The current nodes - * @param edges The current edges - * @param nodeId The collect node's id - * @returns The type of the items the collect node collects, or null if there is no input field - */ -export const getCollectItemType = ( - templates: Templates, - nodes: AnyNode[], - edges: InvocationNodeEdge[], - nodeId: string -): FieldType | null => { - const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item'); - if (!firstEdgeToCollect?.sourceHandle) { - return null; - } - const node = nodes.find((n) => n.id === firstEdgeToCollect.source); - if (!node) { - return null; - } - const template = templates[node.data.type]; - if (!template) { - return null; - } - const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null; - return fieldType; -}; +import { areTypesEqual } from './areTypesEqual'; +import { getCollectItemType } from './getCollectItemType'; +import { getHasCycles } from './getHasCycles'; /** * Creates a selector that validates a pending connection. @@ -276,7 +113,7 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.inputMayOnlyHaveOneConnection'); } - if (!validateSourceAndTargetTypes(sourceType, targetType)) { + if (!validateConnectionTypes(sourceType, targetType)) { return i18n.t('nodes.fieldTypesMustMatch'); } @@ -295,97 +132,3 @@ export const makeConnectionErrorSelector = ( } ); }; - -/** - * Validates that the source and target types are compatible for a connection. - * @param sourceType The type of the source field. - * @param targetType The type of the target field. - * @returns True if the connection is valid, false otherwise. - */ -export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => { - // TODO: There's a bug with Collect -> Iterate nodes: - // https://github.com/invoke-ai/InvokeAI/issues/3956 - // Once this is resolved, we can remove this check. - if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') { - return false; - } - - if (areTypesEqual(sourceType, targetType)) { - return true; - } - - /** - * Connection types must be the same for a connection, with exceptions: - * - CollectionItem can connect to any non-Collection - * - Non-Collections can connect to CollectionItem - * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type - * - Generic Collection can connect to any other Collection or CollectionOrScalar - * - Any Collection can connect to a Generic Collection - */ - const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection; - - const isNonCollectionToCollectionItem = - targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar; - - const isAnythingToCollectionOrScalarOfSameBaseType = - targetType.isCollectionOrScalar && sourceType.name === targetType.name; - - const isGenericCollectionToAnyCollectionOrCollectionOrScalar = - sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar); - - const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection; - - const areBothTypesSingle = - !sourceType.isCollection && - !sourceType.isCollectionOrScalar && - !targetType.isCollection && - !targetType.isCollectionOrScalar; - - const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; - - const isIntOrFloatToString = - areBothTypesSingle && - (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && - targetType.name === 'StringField'; - - const isTargetAnyType = targetType.name === 'AnyField'; - - // One of these must be true for the connection to be valid - return ( - isCollectionItemToNonCollection || - isNonCollectionToCollectionItem || - isAnythingToCollectionOrScalarOfSameBaseType || - isGenericCollectionToAnyCollectionOrCollectionOrScalar || - isCollectionToGenericCollection || - isIntToFloat || - isIntOrFloatToString || - isTargetAnyType - ); -}; - -/** - * Checks if two types are equal. If the field types have original types, those are also compared. Any match is - * considered equal. For example, if the source type and original target type match, the types are considered equal. - * @param sourceType The type of the source field. - * @param targetType The type of the target field. - * @returns True if the types are equal, false otherwise. - */ -export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => { - const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType; - const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType; - const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType; - const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType; - if (isEqual(_sourceType, _targetType)) { - return true; - } - if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) { - return true; - } - if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) { - return true; - } - if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { - return true; - } - return false; -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts new file mode 100644 index 0000000000..93c63b6f41 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts @@ -0,0 +1,16 @@ +import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import { add, buildEdge, collect, position, templates } from 'features/nodes/store/util/testUtils'; +import type { FieldType } from 'features/nodes/types/field'; +import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; +import { describe, expect, it } from 'vitest'; + +describe(getCollectItemType.name, () => { + it('should return the type of the items the collect node collects', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, collect); + const nodes = [n1, n2]; + const edges = [buildEdge(n1.id, 'value', n2.id, 'item')]; + const result = getCollectItemType(templates, nodes, edges, n2.id); + expect(result).toEqual({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts new file mode 100644 index 0000000000..9e0ce0fbee --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts @@ -0,0 +1,35 @@ +import type { Templates } from 'features/nodes/store/types'; +import type { FieldType } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; + +/** + * Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and + * field connected to the collector's `item` input. The field type of that field is returned, else null if there is no + * input field. + * @param templates The current invocation templates + * @param nodes The current nodes + * @param edges The current edges + * @param nodeId The collect node's id + * @returns The type of the items the collect node collects, or null if there is no input field + */ +export const getCollectItemType = ( + templates: Templates, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + nodeId: string +): FieldType | null => { + const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item'); + if (!firstEdgeToCollect?.sourceHandle) { + return null; + } + const node = nodes.find((n) => n.id === firstEdgeToCollect.source); + if (!node) { + return null; + } + const template = templates[node.data.type]; + if (!template) { + return null; + } + const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null; + return fieldType; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts new file mode 100644 index 0000000000..98155f0c20 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -0,0 +1,116 @@ +import type { PendingConnection, Templates } from 'features/nodes/store/types'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; +import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; +import { differenceWith, map } from 'lodash-es'; +import type { Connection } from 'reactflow'; +import { assert } from 'tsafe'; + +import { areTypesEqual } from './areTypesEqual'; +import { getCollectItemType } from './getCollectItemType'; +import { getHasCycles } from './getHasCycles'; + +/** + * Finds the first valid field for a pending connection between two nodes. + * @param templates The invocation templates + * @param nodes The current nodes + * @param edges The current edges + * @param pendingConnection The pending connection + * @param candidateNode The candidate node to which the connection is being made + * @param candidateTemplate The candidate template for the candidate node + * @returns The first valid connection, or null if no valid connection is found + */ + +export const getFirstValidConnection = ( + templates: Templates, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + pendingConnection: PendingConnection, + candidateNode: InvocationNode, + candidateTemplate: InvocationTemplate +): Connection | null => { + if (pendingConnection.node.id === candidateNode.id) { + // Cannot connect to self + return null; + } + + const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + + if (pendingFieldKind === 'source') { + // Connecting from a source to a target + if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) { + return null; + } + if (candidateNode.data.type === 'collect') { + // Special handling for collect node - the `item` field takes any number of connections + return { + source: pendingConnection.node.id, + sourceHandle: pendingConnection.fieldTemplate.name, + target: candidateNode.id, + targetHandle: 'item', + }; + } + // Only one connection per target field is allowed - look for an unconnected target field + const candidateFields = map(candidateTemplate.inputs); + const candidateConnectedFields = edges + .filter((edge) => edge.target === candidateNode.id) + .map((edge) => { + // Edges must always have a targetHandle, safe to assert here + assert(edge.targetHandle); + return edge.targetHandle; + }); + const candidateUnconnectedFields = differenceWith( + candidateFields, + candidateConnectedFields, + (field, connectedFieldName) => field.name === connectedFieldName + ); + const candidateField = candidateUnconnectedFields.find((field) => validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) + ); + if (candidateField) { + return { + source: pendingConnection.node.id, + sourceHandle: pendingConnection.fieldTemplate.name, + target: candidateNode.id, + targetHandle: candidateField.name, + }; + } + } else { + // Connecting from a target to a source + // Ensure we there is not already an edge to the target, except for collect nodes + const isCollect = pendingConnection.node.data.type === 'collect'; + const isTargetAlreadyConnected = edges.some( + (e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name + ); + if (!isCollect && isTargetAlreadyConnected) { + return null; + } + + if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) { + return null; + } + + // Sources/outputs can have any number of edges, we can take the first matching output field + let candidateFields = map(candidateTemplate.outputs); + if (isCollect) { + // Narrow candidates to same field type as already is connected to the collect node + const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); + if (collectItemType) { + candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType)); + } + } + const candidateField = candidateFields.find((field) => { + const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type); + const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name); + return isValid && !isAlreadyConnected; + }); + if (candidateField) { + return { + source: candidateNode.id, + sourceHandle: candidateField.name, + target: pendingConnection.node.id, + targetHandle: pendingConnection.fieldTemplate.name, + }; + } + } + + return null; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts new file mode 100644 index 0000000000..872da36998 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts @@ -0,0 +1,23 @@ +import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; +import { add, buildEdge, position } from 'features/nodes/store/util/testUtils'; +import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; +import { describe, expect, it } from 'vitest'; + +describe(getHasCycles.name, () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, add); + const n3 = buildInvocationNode(position, add); + const nodes = [n1, n2, n3]; + + it('should return true if the graph WOULD have cycles after adding the edge', () => { + const edges = [buildEdge(n1.id, 'value', n2.id, 'a'), buildEdge(n2.id, 'value', n3.id, 'a')]; + const result = getHasCycles(n3.id, n1.id, nodes, edges); + expect(result).toBe(true); + }); + + it('should return false if the graph WOULD NOT have cycles after adding the edge', () => { + const edges = [buildEdge(n1.id, 'value', n2.id, 'a')]; + const result = getHasCycles(n2.id, n3.id, nodes, edges); + expect(result).toBe(false); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts new file mode 100644 index 0000000000..c1a4e51f0c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts @@ -0,0 +1,30 @@ +import graphlib from '@dagrejs/graphlib'; +import type { Edge, Node } from 'reactflow'; + +/** + * Check if adding an edge between the source and target nodes would create a cycle in the graph. + * @param source The source node id + * @param target The target node id + * @param nodes The graph's current nodes + * @param edges The graph's current edges + * @returns True if the graph would be acyclic after adding the edge, false otherwise + */ + +export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => { + // construct graphlib graph from editor state + const g = new graphlib.Graph(); + + nodes.forEach((n) => { + g.setNode(n.id); + }); + + edges.forEach((e) => { + g.setEdge(e.source, e.target); + }); + + // add the candidate edge + g.setEdge(source, target); + + // check if the graph is acyclic + return !graphlib.alg.isAcyclic(g); +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts new file mode 100644 index 0000000000..efde3336e2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -0,0 +1,1073 @@ +import type { Templates } from 'features/nodes/store/types'; +import type { InvocationTemplate } from 'features/nodes/types/invocation'; +import type { OpenAPIV3_1 } from 'openapi-types'; +import type { Edge, XYPosition } from 'reactflow'; + +export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({ + source, + sourceHandle, + target, + targetHandle, + type: 'default', + id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`, +}); + +export const position: XYPosition = { x: 0, y: 0 }; + +export const add: InvocationTemplate = { + title: 'Add Integers', + type: 'add', + version: '1.0.1', + tags: ['math', 'add'], + description: 'Adds two numbers', + outputType: 'integer_output', + inputs: { + a: { + name: 'a', + title: 'A', + required: false, + description: 'The first number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 0, + }, + b: { + name: 'b', + title: 'B', + required: false, + description: 'The second number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 0, + }, + }, + outputs: { + value: { + fieldKind: 'output', + name: 'value', + title: 'Value', + description: 'The output integer', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const sub: InvocationTemplate = { + title: 'Subtract Integers', + type: 'sub', + version: '1.0.1', + tags: ['math', 'subtract'], + description: 'Subtracts two numbers', + outputType: 'integer_output', + inputs: { + a: { + name: 'a', + title: 'A', + required: false, + description: 'The first number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 0, + }, + b: { + name: 'b', + title: 'B', + required: false, + description: 'The second number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 0, + }, + }, + outputs: { + value: { + fieldKind: 'output', + name: 'value', + title: 'Value', + description: 'The output integer', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const collect: InvocationTemplate = { + title: 'Collect', + type: 'collect', + version: '1.0.0', + tags: [], + description: 'Collects values into a collection', + outputType: 'collect_output', + inputs: { + item: { + name: 'item', + title: 'Collection Item', + required: false, + description: 'The item to collect (all inputs must be of the same type)', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + ui_type: 'CollectionItemField', + type: { + name: 'CollectionItemField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + }, + outputs: { + collection: { + fieldKind: 'output', + name: 'collection', + title: 'Collection', + description: 'The collection of input items', + type: { + name: 'CollectionField', + isCollection: true, + isCollectionOrScalar: false, + }, + ui_hidden: false, + ui_type: 'CollectionField', + }, + }, + useCache: true, + classification: 'stable', +}; + +export const scheduler: InvocationTemplate = { + title: 'Scheduler', + type: 'scheduler', + version: '1.0.0', + tags: ['scheduler'], + description: 'Selects a scheduler.', + outputType: 'scheduler_output', + inputs: { + scheduler: { + name: 'scheduler', + title: 'Scheduler', + required: false, + description: 'Scheduler to use during inference', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + ui_type: 'SchedulerField', + type: { + name: 'SchedulerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'EnumField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + default: 'euler', + }, + }, + outputs: { + scheduler: { + fieldKind: 'output', + name: 'scheduler', + title: 'Scheduler', + description: 'Scheduler to use during inference', + type: { + name: 'SchedulerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'EnumField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + ui_hidden: false, + ui_type: 'SchedulerField', + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const main_model_loader: InvocationTemplate = { + title: 'Main Model', + type: 'main_model_loader', + version: '1.0.2', + tags: ['model'], + description: 'Loads a main model, outputting its submodels.', + outputType: 'model_loader_output', + inputs: { + model: { + name: 'model', + title: 'Model', + required: true, + description: 'Main model (UNet, VAE, CLIP) to load', + fieldKind: 'input', + input: 'direct', + ui_hidden: false, + ui_type: 'MainModelField', + type: { + name: 'MainModelField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'ModelIdentifierField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + }, + }, + outputs: { + vae: { + fieldKind: 'output', + name: 'vae', + title: 'VAE', + description: 'VAE', + type: { + name: 'VAEField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + clip: { + fieldKind: 'output', + name: 'clip', + title: 'CLIP', + description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', + type: { + name: 'CLIPField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + unet: { + fieldKind: 'output', + name: 'unet', + title: 'UNet', + description: 'UNet (scheduler, LoRAs)', + type: { + name: 'UNetField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +} + +export const templates: Templates = { + add, + sub, + collect, + scheduler, + main_model_loader, +}; + +export const schema = { + openapi: '3.1.0', + info: { + title: 'Invoke - Community Edition', + description: 'An API for invoking AI image operations', + version: '1.0.0', + }, + components: { + schemas: { + AddInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + a: { + type: 'integer', + title: 'A', + description: 'The first number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + b: { + type: 'integer', + title: 'B', + description: 'The second number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + type: { + type: 'string', + enum: ['add'], + const: 'add', + title: 'type', + default: 'add', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Add Integers', + description: 'Adds two numbers', + category: 'math', + classification: 'stable', + node_pack: 'invokeai', + tags: ['math', 'add'], + version: '1.0.1', + output: { + $ref: '#/components/schemas/IntegerOutput', + }, + class: 'invocation', + }, + IntegerOutput: { + description: 'Base class for nodes that output a single integer', + properties: { + value: { + description: 'The output integer', + field_kind: 'output', + title: 'Value', + type: 'integer', + ui_hidden: false, + }, + type: { + const: 'integer_output', + default: 'integer_output', + enum: ['integer_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['value', 'type', 'type'], + title: 'IntegerOutput', + type: 'object', + class: 'output', + }, + SchedulerInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + scheduler: { + type: 'string', + enum: [ + 'ddim', + 'ddpm', + 'deis', + 'lms', + 'lms_k', + 'pndm', + 'heun', + 'heun_k', + 'euler', + 'euler_k', + 'euler_a', + 'kdpm_2', + 'kdpm_2_a', + 'dpmpp_2s', + 'dpmpp_2s_k', + 'dpmpp_2m', + 'dpmpp_2m_k', + 'dpmpp_2m_sde', + 'dpmpp_2m_sde_k', + 'dpmpp_sde', + 'dpmpp_sde_k', + 'unipc', + 'lcm', + 'tcd', + ], + title: 'Scheduler', + description: 'Scheduler to use during inference', + default: 'euler', + field_kind: 'input', + input: 'any', + orig_default: 'euler', + orig_required: false, + ui_hidden: false, + ui_type: 'SchedulerField', + }, + type: { + type: 'string', + enum: ['scheduler'], + const: 'scheduler', + title: 'type', + default: 'scheduler', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Scheduler', + description: 'Selects a scheduler.', + category: 'latents', + classification: 'stable', + node_pack: 'invokeai', + tags: ['scheduler'], + version: '1.0.0', + output: { + $ref: '#/components/schemas/SchedulerOutput', + }, + class: 'invocation', + }, + SchedulerOutput: { + properties: { + scheduler: { + description: 'Scheduler to use during inference', + enum: [ + 'ddim', + 'ddpm', + 'deis', + 'lms', + 'lms_k', + 'pndm', + 'heun', + 'heun_k', + 'euler', + 'euler_k', + 'euler_a', + 'kdpm_2', + 'kdpm_2_a', + 'dpmpp_2s', + 'dpmpp_2s_k', + 'dpmpp_2m', + 'dpmpp_2m_k', + 'dpmpp_2m_sde', + 'dpmpp_2m_sde_k', + 'dpmpp_sde', + 'dpmpp_sde_k', + 'unipc', + 'lcm', + 'tcd', + ], + field_kind: 'output', + title: 'Scheduler', + type: 'string', + ui_hidden: false, + ui_type: 'SchedulerField', + }, + type: { + const: 'scheduler_output', + default: 'scheduler_output', + enum: ['scheduler_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['scheduler', 'type', 'type'], + title: 'SchedulerOutput', + type: 'object', + class: 'output', + }, + MainModelLoaderInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + model: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Main model (UNet, VAE, CLIP) to load', + field_kind: 'input', + input: 'direct', + orig_required: true, + ui_hidden: false, + ui_type: 'MainModelField', + }, + type: { + type: 'string', + enum: ['main_model_loader'], + const: 'main_model_loader', + title: 'type', + default: 'main_model_loader', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['model', 'type', 'id'], + title: 'Main Model', + description: 'Loads a main model, outputting its submodels.', + category: 'model', + classification: 'stable', + node_pack: 'invokeai', + tags: ['model'], + version: '1.0.2', + output: { + $ref: '#/components/schemas/ModelLoaderOutput', + }, + class: 'invocation', + }, + ModelIdentifierField: { + properties: { + key: { + description: "The model's unique key", + title: 'Key', + type: 'string', + }, + hash: { + description: "The model's BLAKE3 hash", + title: 'Hash', + type: 'string', + }, + name: { + description: "The model's name", + title: 'Name', + type: 'string', + }, + base: { + allOf: [ + { + $ref: '#/components/schemas/BaseModelType', + }, + ], + description: "The model's base model type", + }, + type: { + allOf: [ + { + $ref: '#/components/schemas/ModelType', + }, + ], + description: "The model's type", + }, + submodel_type: { + anyOf: [ + { + $ref: '#/components/schemas/SubModelType', + }, + { + type: 'null', + }, + ], + default: null, + description: 'The submodel to load, if this is a main model', + }, + }, + required: ['key', 'hash', 'name', 'base', 'type'], + title: 'ModelIdentifierField', + type: 'object', + }, + BaseModelType: { + description: 'Base model type.', + enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + title: 'BaseModelType', + type: 'string', + }, + ModelType: { + description: 'Model type.', + enum: ['onnx', 'main', 'vae', 'lora', 'controlnet', 'embedding', 'ip_adapter', 'clip_vision', 't2i_adapter'], + title: 'ModelType', + type: 'string', + }, + SubModelType: { + description: 'Submodel type.', + enum: [ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', + ], + title: 'SubModelType', + type: 'string', + }, + ModelLoaderOutput: { + description: 'Model loader output', + properties: { + vae: { + allOf: [ + { + $ref: '#/components/schemas/VAEField', + }, + ], + description: 'VAE', + field_kind: 'output', + title: 'VAE', + ui_hidden: false, + }, + type: { + const: 'model_loader_output', + default: 'model_loader_output', + enum: ['model_loader_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + clip: { + allOf: [ + { + $ref: '#/components/schemas/CLIPField', + }, + ], + description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', + field_kind: 'output', + title: 'CLIP', + ui_hidden: false, + }, + unet: { + allOf: [ + { + $ref: '#/components/schemas/UNetField', + }, + ], + description: 'UNet (scheduler, LoRAs)', + field_kind: 'output', + title: 'UNet', + ui_hidden: false, + }, + }, + required: ['vae', 'type', 'clip', 'unet', 'type'], + title: 'ModelLoaderOutput', + type: 'object', + class: 'output', + }, + UNetField: { + properties: { + unet: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load unet submodel', + }, + scheduler: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load scheduler submodel', + }, + loras: { + description: 'LoRAs to apply on model loading', + items: { + $ref: '#/components/schemas/LoRAField', + }, + title: 'Loras', + type: 'array', + }, + seamless_axes: { + description: 'Axes("x" and "y") to which apply seamless', + items: { + type: 'string', + }, + title: 'Seamless Axes', + type: 'array', + }, + freeu_config: { + anyOf: [ + { + $ref: '#/components/schemas/FreeUConfig', + }, + { + type: 'null', + }, + ], + default: null, + description: 'FreeU configuration', + }, + }, + required: ['unet', 'scheduler', 'loras'], + title: 'UNetField', + type: 'object', + class: 'output', + }, + LoRAField: { + properties: { + lora: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load lora model', + }, + weight: { + description: 'Weight to apply to lora model', + title: 'Weight', + type: 'number', + }, + }, + required: ['lora', 'weight'], + title: 'LoRAField', + type: 'object', + class: 'output', + }, + FreeUConfig: { + description: + 'Configuration for the FreeU hyperparameters.\n- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu\n- https://github.com/ChenyangSi/FreeU', + properties: { + s1: { + description: + 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', + maximum: 3, + minimum: -1, + title: 'S1', + type: 'number', + }, + s2: { + description: + 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', + maximum: 3, + minimum: -1, + title: 'S2', + type: 'number', + }, + b1: { + description: 'Scaling factor for stage 1 to amplify the contributions of backbone features.', + maximum: 3, + minimum: -1, + title: 'B1', + type: 'number', + }, + b2: { + description: 'Scaling factor for stage 2 to amplify the contributions of backbone features.', + maximum: 3, + minimum: -1, + title: 'B2', + type: 'number', + }, + }, + required: ['s1', 's2', 'b1', 'b2'], + title: 'FreeUConfig', + type: 'object', + class: 'output', + }, + VAEField: { + properties: { + vae: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load vae submodel', + }, + seamless_axes: { + description: 'Axes("x" and "y") to which apply seamless', + items: { + type: 'string', + }, + title: 'Seamless Axes', + type: 'array', + }, + }, + required: ['vae'], + title: 'VAEField', + type: 'object', + class: 'output', + }, + CLIPField: { + properties: { + tokenizer: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load tokenizer submodel', + }, + text_encoder: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load text_encoder submodel', + }, + skipped_layers: { + description: 'Number of skipped layers in text_encoder', + title: 'Skipped Layers', + type: 'integer', + }, + loras: { + description: 'LoRAs to apply on model loading', + items: { + $ref: '#/components/schemas/LoRAField', + }, + title: 'Loras', + type: 'array', + }, + }, + required: ['tokenizer', 'text_encoder', 'skipped_layers', 'loras'], + title: 'CLIPField', + type: 'object', + class: 'output', + }, + CollectInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + item: { + anyOf: [ + {}, + { + type: 'null', + }, + ], + title: 'Collection Item', + description: 'The item to collect (all inputs must be of the same type)', + field_kind: 'input', + input: 'connection', + orig_required: false, + ui_hidden: false, + ui_type: 'CollectionItemField', + }, + collection: { + items: {}, + type: 'array', + title: 'Collection', + description: 'The collection, will be provided on execution', + default: [], + field_kind: 'input', + input: 'any', + orig_default: [], + orig_required: false, + ui_hidden: true, + }, + type: { + type: 'string', + enum: ['collect'], + const: 'collect', + title: 'type', + default: 'collect', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'CollectInvocation', + description: 'Collects values into a collection', + classification: 'stable', + version: '1.0.0', + output: { + $ref: '#/components/schemas/CollectInvocationOutput', + }, + class: 'invocation', + }, + CollectInvocationOutput: { + properties: { + collection: { + description: 'The collection of input items', + field_kind: 'output', + items: {}, + title: 'Collection', + type: 'array', + ui_hidden: false, + ui_type: 'CollectionField', + }, + type: { + const: 'collect_output', + default: 'collect_output', + enum: ['collect_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['collection', 'type', 'type'], + title: 'CollectInvocationOutput', + type: 'object', + class: 'output', + }, + SubtractInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + a: { + type: 'integer', + title: 'A', + description: 'The first number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + b: { + type: 'integer', + title: 'B', + description: 'The second number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + type: { + type: 'string', + enum: ['sub'], + const: 'sub', + title: 'type', + default: 'sub', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Subtract Integers', + description: 'Subtracts two numbers', + category: 'math', + classification: 'stable', + node_pack: 'invokeai', + tags: ['math', 'subtract'], + version: '1.0.1', + output: { + $ref: '#/components/schemas/IntegerOutput', + }, + class: 'invocation', + }, + }, + }, +} as OpenAPIV3_1.Document; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts new file mode 100644 index 0000000000..5d10ef368b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -0,0 +1,149 @@ +import { deepClone } from 'common/util/deepClone'; +import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; +import { set } from 'lodash-es'; +import { describe, expect, it } from 'vitest'; + +import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils'; +import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection'; + +describe(validateConnection.name, () => { + it('should reject invalid connection to self', () => { + const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; + const r = validateConnection(c, [], [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf')); + }); + + describe('missing nodes', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + + it('should reject missing source node', () => { + const r = validateConnection(c, [n2], [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingNode')); + }); + + it('should reject missing target node', () => { + const r = validateConnection(c, [n1], [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingNode')); + }); + }); + + describe('missing invocation templates', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const nodes = [n1, n2]; + + it('should reject missing source template', () => { + const r = validateConnection(c, nodes, [], { sub }, null); + expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate')); + }); + + it('should reject missing target template', () => { + const r = validateConnection(c, nodes, [], { add }, null); + expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate')); + }); + }); + + describe('missing field templates', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const nodes = [n1, n2]; + + it('should reject missing source field template', () => { + const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate')); + }); + + it('should reject missing target field template', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate')); + }); + }); + + describe('duplicate connections', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + it('should accept non-duplicate connections', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, [n1, n2], [], templates, null); + expect(r).toEqual(buildAcceptResult()); + }); + it('should reject duplicate connections', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const e = buildEdge(n1.id, 'value', n2.id, 'a'); + const r = validateConnection(c, [n1, n2], [e], templates, null); + expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection')); + }); + it('should accept duplicate connections if the duplicate is an ignored edge', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const e = buildEdge(n1.id, 'value', n2.id, 'a'); + const r = validateConnection(c, [n1, n2], [e], templates, e); + expect(r).toEqual(buildAcceptResult()); + }); + }); + + it('should reject connection to direct input', () => { + // Create cloned add template w/ a direct input + const addWithDirectAField = deepClone(add); + set(addWithDirectAField, 'inputs.a.input', 'direct'); + set(addWithDirectAField, 'type', 'addWithDirectAField'); + + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, addWithDirectAField); + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null); + expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput')); + }); + + it('should reject connection to a collect node with mismatched item types', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, collect); + const n3 = buildInvocationNode(position, main_model_loader); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes')); + }); + + it('should accept connection to a collect node with matching item types', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, collect); + const n3 = buildInvocationNode(position, sub); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildAcceptResult()); + }); + + it('should reject connections to target field that is already connected', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, add); + const n3 = buildInvocationNode(position, add); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection')); + }); + + it('should accept connections to target field that is already connected (ignored edge)', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, add); + const n3 = buildInvocationNode(position, add); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, e1); + expect(r).toEqual(buildAcceptResult()); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts new file mode 100644 index 0000000000..d45a75ab9f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -0,0 +1,109 @@ +import type { Templates } from 'features/nodes/store/types'; +import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import type { AnyNode } from 'features/nodes/types/invocation'; +import type { Connection as NullableConnection, Edge } from 'reactflow'; +import type { O } from 'ts-toolbelt'; + +type Connection = O.NonNullable; + +export type ValidateConnectionResult = { + isValid: boolean; + messageTKey?: string; +}; + +export type ValidateConnectionFunc = ( + connection: Connection, + nodes: AnyNode[], + edges: Edge[], + templates: Templates, + ignoreEdge: Edge | null +) => ValidateConnectionResult; + +export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => ({ + isValid, + messageTKey, +}); + +const getEqualityPredicate = + (c: Connection) => + (e: Edge): boolean => { + return ( + e.target === c.target && + e.targetHandle === c.targetHandle && + e.source === c.source && + e.sourceHandle === c.sourceHandle + ); + }; + +export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true }); +export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey }); + +export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => { + if (c.source === c.target) { + return buildRejectResult('nodes.cannotConnectToSelf'); + } + + const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); + + if (filteredEdges.some(getEqualityPredicate(c))) { + // We already have a connection from this source to this target + return buildRejectResult('nodes.cannotDuplicateConnection'); + } + + const sourceNode = nodes.find((n) => n.id === c.source); + if (!sourceNode) { + return buildRejectResult('nodes.missingNode'); + } + + const targetNode = nodes.find((n) => n.id === c.target); + if (!targetNode) { + return buildRejectResult('nodes.missingNode'); + } + + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const targetTemplate = templates[targetNode.data.type]; + if (!targetTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; + if (!sourceFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; + if (!targetFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + if (targetFieldTemplate.input === 'direct') { + return buildRejectResult('nodes.cannotConnectToDirectInput'); + } + + if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { + // Collect nodes shouldn't mix and match field types + const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + if (collectItemType) { + if (!areTypesEqual(sourceFieldTemplate.type, collectItemType)) { + return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); + } + } + } + + if ( + edges.find((e) => { + return e.target === c.target && e.targetHandle === c.targetHandle; + }) && + // except CollectionItem inputs can have multiples + targetFieldTemplate.type.name !== 'CollectionItemField' + ) { + return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + } + + return buildAcceptResult(); +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts new file mode 100644 index 0000000000..d953fd973f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts @@ -0,0 +1,222 @@ +import { describe, expect, it } from 'vitest'; + +import { validateConnectionTypes } from './validateConnectionTypes'; + +describe(validateConnectionTypes.name, () => { + describe('generic cases', () => { + it('should accept Scalar to Scalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept Collection to Collection of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: true, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: true, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept Scalar to CollectionOrScalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should accept Collection to CollectionOrScalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: true, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should reject Collection to Scalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: true, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + it('should reject CollectionOrScalar to Scalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: true }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + it('should reject mismatched types', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'BarField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + }); + + describe('special cases', () => { + it('should reject a collection input to a collection input', () => { + const r = validateConnectionTypes( + { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }, + { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + + it('should accept equal types', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + + describe('CollectionItemField', () => { + it('should accept CollectionItemField to any Scalar target', () => { + const r = validateConnectionTypes( + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept CollectionItemField to any CollectionOrScalar target', () => { + const r = validateConnectionTypes( + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should accept any non-Collection to CollectionItemField', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should reject any Collection to CollectionItemField', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + it('should reject any CollectionOrScalar to CollectionItemField', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + }); + + describe('CollectionOrScalar', () => { + it('should accept any Scalar of same type to CollectionOrScalar', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should accept any Collection of same type to CollectionOrScalar', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should accept any CollectionOrScalar of same type to CollectionOrScalar', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + }); + + describe('CollectionField', () => { + it('should accept any CollectionField to any Collection type', () => { + const r = validateConnectionTypes( + { name: 'CollectionField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept any CollectionField to any CollectionOrScalar type', () => { + const r = validateConnectionTypes( + { name: 'CollectionField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + }); + + describe('subtype handling', () => { + type TypePair = { t1: string; t2: string }; + const typePairs = [ + { t1: 'IntegerField', t2: 'FloatField' }, + { t1: 'IntegerField', t2: 'StringField' }, + { t1: 'FloatField', t2: 'StringField' }, + ]; + it.each(typePairs)('should accept Scalar $t1 to Scalar $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: false, isCollectionOrScalar: false }, + { name: t2, isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept Scalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: false, isCollectionOrScalar: false }, + { name: t2, isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept Collection $t1 to Collection $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: true, isCollectionOrScalar: false }, + { name: t2, isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept Collection $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: true, isCollectionOrScalar: false }, + { name: t2, isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept CollectionOrScalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: false, isCollectionOrScalar: true }, + { name: t2, isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + }); + + describe('AnyField', () => { + it('should accept any Scalar type to AnyField', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'AnyField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept any Collection type to AnyField', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'AnyField', isCollection: true, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept any CollectionOrScalar type to AnyField', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'AnyField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts new file mode 100644 index 0000000000..092279e315 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts @@ -0,0 +1,69 @@ +import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import type { FieldType } from 'features/nodes/types/field'; + +/** + * Validates that the source and target types are compatible for a connection. + * @param sourceType The type of the source field. + * @param targetType The type of the target field. + * @returns True if the connection is valid, false otherwise. + */ +export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => { + // TODO: There's a bug with Collect -> Iterate nodes: + // https://github.com/invoke-ai/InvokeAI/issues/3956 + // Once this is resolved, we can remove this check. + if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') { + return false; + } + + if (areTypesEqual(sourceType, targetType)) { + return true; + } + + /** + * Connection types must be the same for a connection, with exceptions: + * - CollectionItem can connect to any non-Collection + * - Non-Collections can connect to CollectionItem + * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type + * - Generic Collection can connect to any other Collection or CollectionOrScalar + * - Any Collection can connect to a Generic Collection + */ + const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection; + + const isNonCollectionToCollectionItem = + targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar; + + const isAnythingToCollectionOrScalarOfSameBaseType = + targetType.isCollectionOrScalar && sourceType.name === targetType.name; + + const isGenericCollectionToAnyCollectionOrCollectionOrScalar = + sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar); + + const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection; + + const areBothTypesSingle = + !sourceType.isCollection && + !sourceType.isCollectionOrScalar && + !targetType.isCollection && + !targetType.isCollectionOrScalar; + + const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; + + const isIntOrFloatToString = + areBothTypesSingle && + (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && + targetType.name === 'StringField'; + + const isTargetAnyType = targetType.name === 'AnyField'; + + // One of these must be true for the connection to be valid + return ( + isCollectionItemToNonCollection || + isNonCollectionToCollectionItem || + isAnythingToCollectionOrScalarOfSameBaseType || + isGenericCollectionToAnyCollectionOrCollectionOrScalar || + isCollectionToGenericCollection || + isIntToFloat || + isIntOrFloatToString || + isTargetAnyType + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index a98f773c7e..8a1a0b5039 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -188,7 +188,6 @@ const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zIntegerFieldType, - originalType: zFieldType.optional(), }); export type IntegerFieldValue = z.infer; export type IntegerFieldInputInstance = z.infer; @@ -217,7 +216,6 @@ const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zFloatFieldType, - originalType: zFieldType.optional(), }); export type FloatFieldValue = z.infer; export type FloatFieldInputInstance = z.infer; @@ -243,7 +241,6 @@ const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zStringFieldType, - originalType: zFieldType.optional(), }); export type StringFieldValue = z.infer; @@ -268,7 +265,6 @@ const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zBooleanFieldType, - originalType: zFieldType.optional(), }); export type BooleanFieldValue = z.infer; export type BooleanFieldInputInstance = z.infer; @@ -294,7 +290,6 @@ const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zEnumFieldType, - originalType: zFieldType.optional(), }); export type EnumFieldValue = z.infer; export type EnumFieldInputInstance = z.infer; @@ -318,7 +313,6 @@ const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zImageFieldType, - originalType: zFieldType.optional(), }); export type ImageFieldValue = z.infer; export type ImageFieldInputInstance = z.infer; @@ -342,7 +336,6 @@ const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zBoardFieldType, - originalType: zFieldType.optional(), }); export type BoardFieldValue = z.infer; export type BoardFieldInputInstance = z.infer; @@ -366,7 +359,6 @@ const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zColorFieldType, - originalType: zFieldType.optional(), }); export type ColorFieldValue = z.infer; export type ColorFieldInputInstance = z.infer; @@ -390,7 +382,6 @@ const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zMainModelFieldType, - originalType: zFieldType.optional(), }); export type MainModelFieldValue = z.infer; export type MainModelFieldInputInstance = z.infer; @@ -413,7 +404,6 @@ const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zModelIdentifierFieldType, - originalType: zFieldType.optional(), }); export type ModelIdentifierFieldValue = z.infer; export type ModelIdentifierFieldInputInstance = z.infer; @@ -437,7 +427,6 @@ const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zSDXLMainModelFieldType, - originalType: zFieldType.optional(), }); export type SDXLMainModelFieldInputInstance = z.infer; export type SDXLMainModelFieldInputTemplate = z.infer; @@ -461,7 +450,6 @@ const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, - originalType: zFieldType.optional(), }); export type SDXLRefinerModelFieldValue = z.infer; export type SDXLRefinerModelFieldInputInstance = z.infer; @@ -485,7 +473,6 @@ const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zVAEModelFieldType, - originalType: zFieldType.optional(), }); export type VAEModelFieldValue = z.infer; export type VAEModelFieldInputInstance = z.infer; @@ -509,7 +496,6 @@ const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zLoRAModelFieldType, - originalType: zFieldType.optional(), }); export type LoRAModelFieldValue = z.infer; export type LoRAModelFieldInputInstance = z.infer; @@ -533,7 +519,6 @@ const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zControlNetModelFieldType, - originalType: zFieldType.optional(), }); export type ControlNetModelFieldValue = z.infer; export type ControlNetModelFieldInputInstance = z.infer; @@ -557,7 +542,6 @@ const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zIPAdapterModelFieldType, - originalType: zFieldType.optional(), }); export type IPAdapterModelFieldValue = z.infer; export type IPAdapterModelFieldInputInstance = z.infer; @@ -581,7 +565,6 @@ const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zT2IAdapterModelFieldType, - originalType: zFieldType.optional(), }); export type T2IAdapterModelFieldValue = z.infer; export type T2IAdapterModelFieldInputInstance = z.infer; @@ -605,7 +588,6 @@ const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zSchedulerFieldType, - originalType: zFieldType.optional(), }); export type SchedulerFieldValue = z.infer; export type SchedulerFieldInputInstance = z.infer; @@ -641,7 +623,6 @@ const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zStatelessFieldType, - originalType: zFieldType.optional(), }); export type StatelessFieldInputTemplate = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts index 480387a8a4..656bdc9d64 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts @@ -1,942 +1,19 @@ +import { schema, templates } from 'features/nodes/store/util/testUtils'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { omit, pick } from 'lodash-es'; -import type { OpenAPIV3_1 } from 'openapi-types'; import { describe, expect, it } from 'vitest'; describe('parseSchema', () => { it('should parse the schema', () => { - const templates = parseSchema(schema); - expect(templates).toEqual(expected); + const parsed = parseSchema(schema); + expect(parsed).toEqual(templates); }); it('should omit denied nodes', () => { - const templates = parseSchema(schema, undefined, ['add']); - expect(templates).toEqual(omit(expected, 'add')); + const parsed = parseSchema(schema, undefined, ['add']); + expect(parsed).toEqual(omit(templates, 'add')); }); it('should include only allowed nodes', () => { - const templates = parseSchema(schema, ['add']); - expect(templates).toEqual(pick(expected, 'add')); + const parsed = parseSchema(schema, ['add']); + expect(parsed).toEqual(pick(templates, 'add')); }); }); - -const expected = { - add: { - title: 'Add Integers', - type: 'add', - version: '1.0.1', - tags: ['math', 'add'], - description: 'Adds two numbers', - outputType: 'integer_output', - inputs: { - a: { - name: 'a', - title: 'A', - required: false, - description: 'The first number', - fieldKind: 'input', - input: 'any', - ui_hidden: false, - type: { - name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, - }, - default: 0, - }, - b: { - name: 'b', - title: 'B', - required: false, - description: 'The second number', - fieldKind: 'input', - input: 'any', - ui_hidden: false, - type: { - name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, - }, - default: 0, - }, - }, - outputs: { - value: { - fieldKind: 'output', - name: 'value', - title: 'Value', - description: 'The output integer', - type: { - name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - }, - useCache: true, - nodePack: 'invokeai', - classification: 'stable', - }, - scheduler: { - title: 'Scheduler', - type: 'scheduler', - version: '1.0.0', - tags: ['scheduler'], - description: 'Selects a scheduler.', - outputType: 'scheduler_output', - inputs: { - scheduler: { - name: 'scheduler', - title: 'Scheduler', - required: false, - description: 'Scheduler to use during inference', - fieldKind: 'input', - input: 'any', - ui_hidden: false, - ui_type: 'SchedulerField', - type: { - name: 'SchedulerField', - isCollection: false, - isCollectionOrScalar: false, - originalType: { - name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, - }, - }, - default: 'euler', - }, - }, - outputs: { - scheduler: { - fieldKind: 'output', - name: 'scheduler', - title: 'Scheduler', - description: 'Scheduler to use during inference', - type: { - name: 'SchedulerField', - isCollection: false, - isCollectionOrScalar: false, - originalType: { - name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, - }, - }, - ui_hidden: false, - ui_type: 'SchedulerField', - }, - }, - useCache: true, - nodePack: 'invokeai', - classification: 'stable', - }, - main_model_loader: { - title: 'Main Model', - type: 'main_model_loader', - version: '1.0.2', - tags: ['model'], - description: 'Loads a main model, outputting its submodels.', - outputType: 'model_loader_output', - inputs: { - model: { - name: 'model', - title: 'Model', - required: true, - description: 'Main model (UNet, VAE, CLIP) to load', - fieldKind: 'input', - input: 'direct', - ui_hidden: false, - ui_type: 'MainModelField', - type: { - name: 'MainModelField', - isCollection: false, - isCollectionOrScalar: false, - originalType: { - name: 'ModelIdentifierField', - isCollection: false, - isCollectionOrScalar: false, - }, - }, - }, - }, - outputs: { - vae: { - fieldKind: 'output', - name: 'vae', - title: 'VAE', - description: 'VAE', - type: { - name: 'VAEField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - clip: { - fieldKind: 'output', - name: 'clip', - title: 'CLIP', - description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', - type: { - name: 'CLIPField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - unet: { - fieldKind: 'output', - name: 'unet', - title: 'UNet', - description: 'UNet (scheduler, LoRAs)', - type: { - name: 'UNetField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - }, - useCache: true, - nodePack: 'invokeai', - classification: 'stable', - }, - collect: { - title: 'Collect', - type: 'collect', - version: '1.0.0', - tags: [], - description: 'Collects values into a collection', - outputType: 'collect_output', - inputs: { - item: { - name: 'item', - title: 'Collection Item', - required: false, - description: 'The item to collect (all inputs must be of the same type)', - fieldKind: 'input', - input: 'connection', - ui_hidden: false, - ui_type: 'CollectionItemField', - type: { - name: 'CollectionItemField', - isCollection: false, - isCollectionOrScalar: false, - }, - }, - }, - outputs: { - collection: { - fieldKind: 'output', - name: 'collection', - title: 'Collection', - description: 'The collection of input items', - type: { - name: 'CollectionField', - isCollection: true, - isCollectionOrScalar: false, - }, - ui_hidden: false, - ui_type: 'CollectionField', - }, - }, - useCache: true, - classification: 'stable', - }, -}; - -const schema = { - openapi: '3.1.0', - info: { - title: 'Invoke - Community Edition', - description: 'An API for invoking AI image operations', - version: '1.0.0', - }, - components: { - schemas: { - AddInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - a: { - type: 'integer', - title: 'A', - description: 'The first number', - default: 0, - field_kind: 'input', - input: 'any', - orig_default: 0, - orig_required: false, - ui_hidden: false, - }, - b: { - type: 'integer', - title: 'B', - description: 'The second number', - default: 0, - field_kind: 'input', - input: 'any', - orig_default: 0, - orig_required: false, - ui_hidden: false, - }, - type: { - type: 'string', - enum: ['add'], - const: 'add', - title: 'type', - default: 'add', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['type', 'id'], - title: 'Add Integers', - description: 'Adds two numbers', - category: 'math', - classification: 'stable', - node_pack: 'invokeai', - tags: ['math', 'add'], - version: '1.0.1', - output: { - $ref: '#/components/schemas/IntegerOutput', - }, - class: 'invocation', - }, - IntegerOutput: { - description: 'Base class for nodes that output a single integer', - properties: { - value: { - description: 'The output integer', - field_kind: 'output', - title: 'Value', - type: 'integer', - ui_hidden: false, - }, - type: { - const: 'integer_output', - default: 'integer_output', - enum: ['integer_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - }, - required: ['value', 'type', 'type'], - title: 'IntegerOutput', - type: 'object', - class: 'output', - }, - SchedulerInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - scheduler: { - type: 'string', - enum: [ - 'ddim', - 'ddpm', - 'deis', - 'lms', - 'lms_k', - 'pndm', - 'heun', - 'heun_k', - 'euler', - 'euler_k', - 'euler_a', - 'kdpm_2', - 'kdpm_2_a', - 'dpmpp_2s', - 'dpmpp_2s_k', - 'dpmpp_2m', - 'dpmpp_2m_k', - 'dpmpp_2m_sde', - 'dpmpp_2m_sde_k', - 'dpmpp_sde', - 'dpmpp_sde_k', - 'unipc', - 'lcm', - 'tcd', - ], - title: 'Scheduler', - description: 'Scheduler to use during inference', - default: 'euler', - field_kind: 'input', - input: 'any', - orig_default: 'euler', - orig_required: false, - ui_hidden: false, - ui_type: 'SchedulerField', - }, - type: { - type: 'string', - enum: ['scheduler'], - const: 'scheduler', - title: 'type', - default: 'scheduler', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['type', 'id'], - title: 'Scheduler', - description: 'Selects a scheduler.', - category: 'latents', - classification: 'stable', - node_pack: 'invokeai', - tags: ['scheduler'], - version: '1.0.0', - output: { - $ref: '#/components/schemas/SchedulerOutput', - }, - class: 'invocation', - }, - SchedulerOutput: { - properties: { - scheduler: { - description: 'Scheduler to use during inference', - enum: [ - 'ddim', - 'ddpm', - 'deis', - 'lms', - 'lms_k', - 'pndm', - 'heun', - 'heun_k', - 'euler', - 'euler_k', - 'euler_a', - 'kdpm_2', - 'kdpm_2_a', - 'dpmpp_2s', - 'dpmpp_2s_k', - 'dpmpp_2m', - 'dpmpp_2m_k', - 'dpmpp_2m_sde', - 'dpmpp_2m_sde_k', - 'dpmpp_sde', - 'dpmpp_sde_k', - 'unipc', - 'lcm', - 'tcd', - ], - field_kind: 'output', - title: 'Scheduler', - type: 'string', - ui_hidden: false, - ui_type: 'SchedulerField', - }, - type: { - const: 'scheduler_output', - default: 'scheduler_output', - enum: ['scheduler_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - }, - required: ['scheduler', 'type', 'type'], - title: 'SchedulerOutput', - type: 'object', - class: 'output', - }, - MainModelLoaderInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - model: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Main model (UNet, VAE, CLIP) to load', - field_kind: 'input', - input: 'direct', - orig_required: true, - ui_hidden: false, - ui_type: 'MainModelField', - }, - type: { - type: 'string', - enum: ['main_model_loader'], - const: 'main_model_loader', - title: 'type', - default: 'main_model_loader', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['model', 'type', 'id'], - title: 'Main Model', - description: 'Loads a main model, outputting its submodels.', - category: 'model', - classification: 'stable', - node_pack: 'invokeai', - tags: ['model'], - version: '1.0.2', - output: { - $ref: '#/components/schemas/ModelLoaderOutput', - }, - class: 'invocation', - }, - ModelIdentifierField: { - properties: { - key: { - description: "The model's unique key", - title: 'Key', - type: 'string', - }, - hash: { - description: "The model's BLAKE3 hash", - title: 'Hash', - type: 'string', - }, - name: { - description: "The model's name", - title: 'Name', - type: 'string', - }, - base: { - allOf: [ - { - $ref: '#/components/schemas/BaseModelType', - }, - ], - description: "The model's base model type", - }, - type: { - allOf: [ - { - $ref: '#/components/schemas/ModelType', - }, - ], - description: "The model's type", - }, - submodel_type: { - anyOf: [ - { - $ref: '#/components/schemas/SubModelType', - }, - { - type: 'null', - }, - ], - default: null, - description: 'The submodel to load, if this is a main model', - }, - }, - required: ['key', 'hash', 'name', 'base', 'type'], - title: 'ModelIdentifierField', - type: 'object', - }, - BaseModelType: { - description: 'Base model type.', - enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], - title: 'BaseModelType', - type: 'string', - }, - ModelType: { - description: 'Model type.', - enum: ['onnx', 'main', 'vae', 'lora', 'controlnet', 'embedding', 'ip_adapter', 'clip_vision', 't2i_adapter'], - title: 'ModelType', - type: 'string', - }, - SubModelType: { - description: 'Submodel type.', - enum: [ - 'unet', - 'text_encoder', - 'text_encoder_2', - 'tokenizer', - 'tokenizer_2', - 'vae', - 'vae_decoder', - 'vae_encoder', - 'scheduler', - 'safety_checker', - ], - title: 'SubModelType', - type: 'string', - }, - ModelLoaderOutput: { - description: 'Model loader output', - properties: { - vae: { - allOf: [ - { - $ref: '#/components/schemas/VAEField', - }, - ], - description: 'VAE', - field_kind: 'output', - title: 'VAE', - ui_hidden: false, - }, - type: { - const: 'model_loader_output', - default: 'model_loader_output', - enum: ['model_loader_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - clip: { - allOf: [ - { - $ref: '#/components/schemas/CLIPField', - }, - ], - description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', - field_kind: 'output', - title: 'CLIP', - ui_hidden: false, - }, - unet: { - allOf: [ - { - $ref: '#/components/schemas/UNetField', - }, - ], - description: 'UNet (scheduler, LoRAs)', - field_kind: 'output', - title: 'UNet', - ui_hidden: false, - }, - }, - required: ['vae', 'type', 'clip', 'unet', 'type'], - title: 'ModelLoaderOutput', - type: 'object', - class: 'output', - }, - UNetField: { - properties: { - unet: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load unet submodel', - }, - scheduler: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load scheduler submodel', - }, - loras: { - description: 'LoRAs to apply on model loading', - items: { - $ref: '#/components/schemas/LoRAField', - }, - title: 'Loras', - type: 'array', - }, - seamless_axes: { - description: 'Axes("x" and "y") to which apply seamless', - items: { - type: 'string', - }, - title: 'Seamless Axes', - type: 'array', - }, - freeu_config: { - anyOf: [ - { - $ref: '#/components/schemas/FreeUConfig', - }, - { - type: 'null', - }, - ], - default: null, - description: 'FreeU configuration', - }, - }, - required: ['unet', 'scheduler', 'loras'], - title: 'UNetField', - type: 'object', - class: 'output', - }, - LoRAField: { - properties: { - lora: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load lora model', - }, - weight: { - description: 'Weight to apply to lora model', - title: 'Weight', - type: 'number', - }, - }, - required: ['lora', 'weight'], - title: 'LoRAField', - type: 'object', - class: 'output', - }, - FreeUConfig: { - description: - 'Configuration for the FreeU hyperparameters.\n- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu\n- https://github.com/ChenyangSi/FreeU', - properties: { - s1: { - description: - 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', - maximum: 3.0, - minimum: -1.0, - title: 'S1', - type: 'number', - }, - s2: { - description: - 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', - maximum: 3.0, - minimum: -1.0, - title: 'S2', - type: 'number', - }, - b1: { - description: 'Scaling factor for stage 1 to amplify the contributions of backbone features.', - maximum: 3.0, - minimum: -1.0, - title: 'B1', - type: 'number', - }, - b2: { - description: 'Scaling factor for stage 2 to amplify the contributions of backbone features.', - maximum: 3.0, - minimum: -1.0, - title: 'B2', - type: 'number', - }, - }, - required: ['s1', 's2', 'b1', 'b2'], - title: 'FreeUConfig', - type: 'object', - class: 'output', - }, - VAEField: { - properties: { - vae: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load vae submodel', - }, - seamless_axes: { - description: 'Axes("x" and "y") to which apply seamless', - items: { - type: 'string', - }, - title: 'Seamless Axes', - type: 'array', - }, - }, - required: ['vae'], - title: 'VAEField', - type: 'object', - class: 'output', - }, - CLIPField: { - properties: { - tokenizer: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load tokenizer submodel', - }, - text_encoder: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load text_encoder submodel', - }, - skipped_layers: { - description: 'Number of skipped layers in text_encoder', - title: 'Skipped Layers', - type: 'integer', - }, - loras: { - description: 'LoRAs to apply on model loading', - items: { - $ref: '#/components/schemas/LoRAField', - }, - title: 'Loras', - type: 'array', - }, - }, - required: ['tokenizer', 'text_encoder', 'skipped_layers', 'loras'], - title: 'CLIPField', - type: 'object', - class: 'output', - }, - CollectInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - item: { - anyOf: [ - {}, - { - type: 'null', - }, - ], - title: 'Collection Item', - description: 'The item to collect (all inputs must be of the same type)', - field_kind: 'input', - input: 'connection', - orig_required: false, - ui_hidden: false, - ui_type: 'CollectionItemField', - }, - collection: { - items: {}, - type: 'array', - title: 'Collection', - description: 'The collection, will be provided on execution', - default: [], - field_kind: 'input', - input: 'any', - orig_default: [], - orig_required: false, - ui_hidden: true, - }, - type: { - type: 'string', - enum: ['collect'], - const: 'collect', - title: 'type', - default: 'collect', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['type', 'id'], - title: 'CollectInvocation', - description: 'Collects values into a collection', - classification: 'stable', - version: '1.0.0', - output: { - $ref: '#/components/schemas/CollectInvocationOutput', - }, - class: 'invocation', - }, - CollectInvocationOutput: { - properties: { - collection: { - description: 'The collection of input items', - field_kind: 'output', - items: {}, - title: 'Collection', - type: 'array', - ui_hidden: false, - ui_type: 'CollectionField', - }, - type: { - const: 'collect_output', - default: 'collect_output', - enum: ['collect_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - }, - required: ['collection', 'type', 'type'], - title: 'CollectInvocationOutput', - type: 'object', - class: 'output', - }, - }, - }, -} as OpenAPIV3_1.Document;