tests(ui): add tests for consolidated connection validation

This commit is contained in:
psychedelicious 2024-05-19 00:11:15 +10:00
parent 6f7160b9fd
commit 3fcb2720d7
19 changed files with 1999 additions and 1223 deletions

View File

@ -775,6 +775,9 @@
"cannotConnectToSelf": "Cannot connect to self", "cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections", "cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types", "cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"missingNode": "Missing invocation node",
"missingInvocationTemplate": "Missing invocation template",
"missingFieldTemplate": "Missing field template",
"nodePack": "Node pack", "nodePack": "Node pack",
"collection": "Collection", "collection": "Collection",
"collectionFieldType": "{{name}} Collection", "collectionFieldType": "{{name}} Collection",

View File

@ -17,7 +17,8 @@ import {
nodeAdded, nodeAdded,
openAddNodePopover, openAddNodePopover,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection, validateSourceAndTargetTypes } from 'features/nodes/store/util/connectionValidation'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode } from 'features/nodes/types/invocation'; import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { filter, map, memoize, some } from 'lodash-es'; import { filter, map, memoize, some } from 'lodash-es';
@ -77,7 +78,7 @@ const AddNodePopover = () => {
return some(fields, (field) => { return some(fields, (field) => {
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type; const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type; const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
return validateSourceAndTargetTypes(sourceType, targetType); return validateConnectionTypes(sourceType, targetType);
}); });
}); });
}, [templates, pendingConnection]); }, [templates, pendingConnection]);

View File

@ -8,7 +8,7 @@ import {
$templates, $templates,
connectionMade, connectionMade,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/connectionValidation'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { isString } from 'lodash-es'; import { isString } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';

View File

@ -2,12 +2,10 @@
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice'; import { $templates } from 'features/nodes/store/nodesSlice';
import { import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
areTypesEqual, import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
getCollectItemType, import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
getHasCycles, import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
validateSourceAndTargetTypes,
} from 'features/nodes/store/util/connectionValidation';
import type { InvocationNodeData } from 'features/nodes/types/invocation'; import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { useCallback } from 'react'; import { useCallback } from 'react';
import type { Connection, Node } from 'reactflow'; import type { Connection, Node } from 'reactflow';
@ -88,7 +86,7 @@ export const useIsValidConnection = () => {
} }
// Must use the originalType here if it exists // Must use the originalType here if it exists
if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
return false; return false;
} }

View File

@ -0,0 +1,101 @@
import { describe, expect, it } from 'vitest';
import { areTypesEqual } from './areTypesEqual';
describe(areTypesEqual.name, () => {
it('should handle equal source and target type', () => {
const sourceType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
},
};
const targetType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal source type and original target type', () => {
const sourceType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
},
};
const targetType = {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal original source type and target type', () => {
const sourceType = {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
};
const targetType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal original source type and original target type', () => {
const sourceType = {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
};
const targetType = {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
});

View File

@ -0,0 +1,30 @@
import type { FieldType } from 'features/nodes/types/field';
import { isEqual, omit } from 'lodash-es';
/**
* Checks if two types are equal. If the field types have original types, those are also compared. Any match is
* considered equal. For example, if the source type and original target type match, the types are considered equal.
* @param sourceType The type of the source field.
* @param targetType The type of the target field.
* @returns True if the types are equal, false otherwise.
*/
export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => {
const _sourceType = 'originalType' in sourceType ? omit(sourceType, 'originalType') : sourceType;
const _targetType = 'originalType' in targetType ? omit(targetType, 'originalType') : targetType;
const _sourceTypeOriginal = 'originalType' in sourceType ? sourceType.originalType : null;
const _targetTypeOriginal = 'originalType' in targetType ? targetType.originalType : null;
if (isEqual(_sourceType, _targetType)) {
return true;
}
if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) {
return true;
}
if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) {
return true;
}
if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) {
return true;
}
return false;
};

View File

@ -1,179 +1,16 @@
import graphlib from '@dagrejs/graphlib';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; import type { FieldType } from 'features/nodes/types/field';
import i18n from 'i18next'; import i18n from 'i18next';
import { differenceWith, isEqual, map, omit } from 'lodash-es'; import type { HandleType } from 'reactflow';
import type { Connection, Edge, HandleType, Node } from 'reactflow';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
/** import { areTypesEqual } from './areTypesEqual';
* Finds the first valid field for a pending connection between two nodes. import { getCollectItemType } from './getCollectItemType';
* @param templates The invocation templates import { getHasCycles } from './getHasCycles';
* @param nodes The current nodes
* @param edges The current edges
* @param pendingConnection The pending connection
* @param candidateNode The candidate node to which the connection is being made
* @param candidateTemplate The candidate template for the candidate node
* @returns The first valid connection, or null if no valid connection is found
*/
export const getFirstValidConnection = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate
): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
return null;
}
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs);
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
);
const candidateField = candidateUnconnectedFields.find((field) =>
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null;
}
if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return null;
}
// Sources/outputs can have any number of edges, we can take the first matching output field
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
}
return null;
};
/**
* Check if adding an edge between the source and target nodes would create a cycle in the graph.
* @param source The source node id
* @param target The target node id
* @param nodes The graph's current nodes
* @param edges The graph's current edges
* @returns True if the graph would be acyclic after adding the edge, false otherwise
*/
export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// add the candidate edge
g.setEdge(source, target);
// check if the graph is acyclic
return !graphlib.alg.isAcyclic(g);
};
/**
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
* input field.
* @param templates The current invocation templates
* @param nodes The current nodes
* @param edges The current edges
* @param nodeId The collect node's id
* @returns The type of the items the collect node collects, or null if there is no input field
*/
export const getCollectItemType = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
return fieldType;
};
/** /**
* Creates a selector that validates a pending connection. * Creates a selector that validates a pending connection.
@ -276,7 +113,7 @@ export const makeConnectionErrorSelector = (
return i18n.t('nodes.inputMayOnlyHaveOneConnection'); return i18n.t('nodes.inputMayOnlyHaveOneConnection');
} }
if (!validateSourceAndTargetTypes(sourceType, targetType)) { if (!validateConnectionTypes(sourceType, targetType)) {
return i18n.t('nodes.fieldTypesMustMatch'); return i18n.t('nodes.fieldTypesMustMatch');
} }
@ -295,97 +132,3 @@ export const makeConnectionErrorSelector = (
} }
); );
}; };
/**
* Validates that the source and target types are compatible for a connection.
* @param sourceType The type of the source field.
* @param targetType The type of the target field.
* @returns True if the connection is valid, false otherwise.
*/
export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
return false;
}
if (areTypesEqual(sourceType, targetType)) {
return true;
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-Collection
* - Non-Collections can connect to CollectionItem
* - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type
* - Generic Collection can connect to any other Collection or CollectionOrScalar
* - Any Collection can connect to a Generic Collection
*/
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection;
const isNonCollectionToCollectionItem =
targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar;
const isAnythingToCollectionOrScalarOfSameBaseType =
targetType.isCollectionOrScalar && sourceType.name === targetType.name;
const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar);
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection;
const areBothTypesSingle =
!sourceType.isCollection &&
!sourceType.isCollectionOrScalar &&
!targetType.isCollection &&
!targetType.isCollectionOrScalar;
const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
const isIntOrFloatToString =
areBothTypesSingle &&
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
targetType.name === 'StringField';
const isTargetAnyType = targetType.name === 'AnyField';
// One of these must be true for the connection to be valid
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToCollectionOrScalarOfSameBaseType ||
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
isCollectionToGenericCollection ||
isIntToFloat ||
isIntOrFloatToString ||
isTargetAnyType
);
};
/**
* Checks if two types are equal. If the field types have original types, those are also compared. Any match is
* considered equal. For example, if the source type and original target type match, the types are considered equal.
* @param sourceType The type of the source field.
* @param targetType The type of the target field.
* @returns True if the types are equal, false otherwise.
*/
export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => {
const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType;
const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType;
const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType;
const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType;
if (isEqual(_sourceType, _targetType)) {
return true;
}
if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) {
return true;
}
if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) {
return true;
}
if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) {
return true;
}
return false;
};

View File

@ -0,0 +1,16 @@
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { add, buildEdge, collect, position, templates } from 'features/nodes/store/util/testUtils';
import type { FieldType } from 'features/nodes/types/field';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { describe, expect, it } from 'vitest';
describe(getCollectItemType.name, () => {
it('should return the type of the items the collect node collects', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, collect);
const nodes = [n1, n2];
const edges = [buildEdge(n1.id, 'value', n2.id, 'item')];
const result = getCollectItemType(templates, nodes, edges, n2.id);
expect(result).toEqual<FieldType>({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false });
});
});

View File

@ -0,0 +1,35 @@
import type { Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
/**
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
* input field.
* @param templates The current invocation templates
* @param nodes The current nodes
* @param edges The current edges
* @param nodeId The collect node's id
* @returns The type of the items the collect node collects, or null if there is no input field
*/
export const getCollectItemType = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
return fieldType;
};

View File

@ -0,0 +1,116 @@
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { differenceWith, map } from 'lodash-es';
import type { Connection } from 'reactflow';
import { assert } from 'tsafe';
import { areTypesEqual } from './areTypesEqual';
import { getCollectItemType } from './getCollectItemType';
import { getHasCycles } from './getHasCycles';
/**
* Finds the first valid field for a pending connection between two nodes.
* @param templates The invocation templates
* @param nodes The current nodes
* @param edges The current edges
* @param pendingConnection The pending connection
* @param candidateNode The candidate node to which the connection is being made
* @param candidateTemplate The candidate template for the candidate node
* @returns The first valid connection, or null if no valid connection is found
*/
export const getFirstValidConnection = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate
): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
return null;
}
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs);
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
);
const candidateField = candidateUnconnectedFields.find((field) => validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null;
}
if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return null;
}
// Sources/outputs can have any number of edges, we can take the first matching output field
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
}
return null;
};

View File

@ -0,0 +1,23 @@
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { add, buildEdge, position } from 'features/nodes/store/util/testUtils';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { describe, expect, it } from 'vitest';
describe(getHasCycles.name, () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, add);
const n3 = buildInvocationNode(position, add);
const nodes = [n1, n2, n3];
it('should return true if the graph WOULD have cycles after adding the edge', () => {
const edges = [buildEdge(n1.id, 'value', n2.id, 'a'), buildEdge(n2.id, 'value', n3.id, 'a')];
const result = getHasCycles(n3.id, n1.id, nodes, edges);
expect(result).toBe(true);
});
it('should return false if the graph WOULD NOT have cycles after adding the edge', () => {
const edges = [buildEdge(n1.id, 'value', n2.id, 'a')];
const result = getHasCycles(n2.id, n3.id, nodes, edges);
expect(result).toBe(false);
});
});

View File

@ -0,0 +1,30 @@
import graphlib from '@dagrejs/graphlib';
import type { Edge, Node } from 'reactflow';
/**
* Check if adding an edge between the source and target nodes would create a cycle in the graph.
* @param source The source node id
* @param target The target node id
* @param nodes The graph's current nodes
* @param edges The graph's current edges
* @returns True if the graph would be acyclic after adding the edge, false otherwise
*/
export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// add the candidate edge
g.setEdge(source, target);
// check if the graph is acyclic
return !graphlib.alg.isAcyclic(g);
};

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,149 @@
import { deepClone } from 'common/util/deepClone';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { set } from 'lodash-es';
import { describe, expect, it } from 'vitest';
import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils';
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
describe(validateConnection.name, () => {
it('should reject invalid connection to self', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
const r = validateConnection(c, [], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
});
describe('missing nodes', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
it('should reject missing source node', () => {
const r = validateConnection(c, [n2], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
});
it('should reject missing target node', () => {
const r = validateConnection(c, [n1], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
});
});
describe('missing invocation templates', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const nodes = [n1, n2];
it('should reject missing source template', () => {
const r = validateConnection(c, nodes, [], { sub }, null);
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
});
it('should reject missing target template', () => {
const r = validateConnection(c, nodes, [], { add }, null);
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
});
});
describe('missing field templates', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const nodes = [n1, n2];
it('should reject missing source field template', () => {
const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
});
it('should reject missing target field template', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
});
});
describe('duplicate connections', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
it('should accept non-duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], templates, null);
expect(r).toEqual(buildAcceptResult());
});
it('should reject duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection'));
});
it('should accept duplicate connections if the duplicate is an ignored edge', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, e);
expect(r).toEqual(buildAcceptResult());
});
});
it('should reject connection to direct input', () => {
// Create cloned add template w/ a direct input
const addWithDirectAField = deepClone(add);
set(addWithDirectAField, 'inputs.a.input', 'direct');
set(addWithDirectAField, 'type', 'addWithDirectAField');
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, addWithDirectAField);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
});
it('should reject connection to a collect node with mismatched item types', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, collect);
const n3 = buildInvocationNode(position, main_model_loader);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'));
});
it('should accept connection to a collect node with matching item types', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, collect);
const n3 = buildInvocationNode(position, sub);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildAcceptResult());
});
it('should reject connections to target field that is already connected', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, add);
const n3 = buildInvocationNode(position, add);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection'));
});
it('should accept connections to target field that is already connected (ignored edge)', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, add);
const n3 = buildInvocationNode(position, add);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, e1);
expect(r).toEqual(buildAcceptResult());
});
});

View File

@ -0,0 +1,109 @@
import type { Templates } from 'features/nodes/store/types';
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import type { AnyNode } from 'features/nodes/types/invocation';
import type { Connection as NullableConnection, Edge } from 'reactflow';
import type { O } from 'ts-toolbelt';
type Connection = O.NonNullable<NullableConnection>;
export type ValidateConnectionResult = {
isValid: boolean;
messageTKey?: string;
};
export type ValidateConnectionFunc = (
connection: Connection,
nodes: AnyNode[],
edges: Edge[],
templates: Templates,
ignoreEdge: Edge | null
) => ValidateConnectionResult;
export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => ({
isValid,
messageTKey,
});
const getEqualityPredicate =
(c: Connection) =>
(e: Edge): boolean => {
return (
e.target === c.target &&
e.targetHandle === c.targetHandle &&
e.source === c.source &&
e.sourceHandle === c.sourceHandle
);
};
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => {
if (c.source === c.target) {
return buildRejectResult('nodes.cannotConnectToSelf');
}
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
if (filteredEdges.some(getEqualityPredicate(c))) {
// We already have a connection from this source to this target
return buildRejectResult('nodes.cannotDuplicateConnection');
}
const sourceNode = nodes.find((n) => n.id === c.source);
if (!sourceNode) {
return buildRejectResult('nodes.missingNode');
}
const targetNode = nodes.find((n) => n.id === c.target);
if (!targetNode) {
return buildRejectResult('nodes.missingNode');
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
if (targetFieldTemplate.input === 'direct') {
return buildRejectResult('nodes.cannotConnectToDirectInput');
}
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
if (!areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
}
}
}
if (
edges.find((e) => {
return e.target === c.target && e.targetHandle === c.targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetFieldTemplate.type.name !== 'CollectionItemField'
) {
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
}
return buildAcceptResult();
};

View File

@ -0,0 +1,222 @@
import { describe, expect, it } from 'vitest';
import { validateConnectionTypes } from './validateConnectionTypes';
describe(validateConnectionTypes.name, () => {
describe('generic cases', () => {
it('should accept Scalar to Scalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept Collection to Collection of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept Scalar to CollectionOrScalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should accept Collection to CollectionOrScalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should reject Collection to Scalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
it('should reject CollectionOrScalar to Scalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
it('should reject mismatched types', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'BarField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
});
describe('special cases', () => {
it('should reject a collection input to a collection input', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', isCollection: true, isCollectionOrScalar: false },
{ name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
it('should accept equal types', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
describe('CollectionItemField', () => {
it('should accept CollectionItemField to any Scalar target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept CollectionItemField to any CollectionOrScalar target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should accept any non-Collection to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should reject any Collection to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
it('should reject any CollectionOrScalar to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
});
describe('CollectionOrScalar', () => {
it('should accept any Scalar of same type to CollectionOrScalar', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should accept any Collection of same type to CollectionOrScalar', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should accept any CollectionOrScalar of same type to CollectionOrScalar', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
});
describe('CollectionField', () => {
it('should accept any CollectionField to any Collection type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept any CollectionField to any CollectionOrScalar type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
});
describe('subtype handling', () => {
type TypePair = { t1: string; t2: string };
const typePairs = [
{ t1: 'IntegerField', t2: 'FloatField' },
{ t1: 'IntegerField', t2: 'StringField' },
{ t1: 'FloatField', t2: 'StringField' },
];
it.each(typePairs)('should accept Scalar $t1 to Scalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: false, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept Scalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: false, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept Collection $t1 to Collection $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: true, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept Collection $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: true, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept CollectionOrScalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: false, isCollectionOrScalar: true },
{ name: t2, isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
});
describe('AnyField', () => {
it('should accept any Scalar type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'AnyField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept any Collection type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'AnyField', isCollection: true, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept any CollectionOrScalar type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'AnyField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
});
});
});

View File

@ -0,0 +1,69 @@
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import type { FieldType } from 'features/nodes/types/field';
/**
* Validates that the source and target types are compatible for a connection.
* @param sourceType The type of the source field.
* @param targetType The type of the target field.
* @returns True if the connection is valid, false otherwise.
*/
export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
return false;
}
if (areTypesEqual(sourceType, targetType)) {
return true;
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-Collection
* - Non-Collections can connect to CollectionItem
* - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type
* - Generic Collection can connect to any other Collection or CollectionOrScalar
* - Any Collection can connect to a Generic Collection
*/
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection;
const isNonCollectionToCollectionItem =
targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar;
const isAnythingToCollectionOrScalarOfSameBaseType =
targetType.isCollectionOrScalar && sourceType.name === targetType.name;
const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar);
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection;
const areBothTypesSingle =
!sourceType.isCollection &&
!sourceType.isCollectionOrScalar &&
!targetType.isCollection &&
!targetType.isCollectionOrScalar;
const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
const isIntOrFloatToString =
areBothTypesSingle &&
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
targetType.name === 'StringField';
const isTargetAnyType = targetType.name === 'AnyField';
// One of these must be true for the connection to be valid
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToCollectionOrScalarOfSameBaseType ||
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
isCollectionToGenericCollection ||
isIntToFloat ||
isIntOrFloatToString ||
isTargetAnyType
);
};

View File

@ -188,7 +188,6 @@ const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIntegerFieldType, type: zIntegerFieldType,
originalType: zFieldType.optional(),
}); });
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>; export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>; export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
@ -217,7 +216,6 @@ const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFloatFieldType, type: zFloatFieldType,
originalType: zFieldType.optional(),
}); });
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>; export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>; export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
@ -243,7 +241,6 @@ const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringFieldType, type: zStringFieldType,
originalType: zFieldType.optional(),
}); });
export type StringFieldValue = z.infer<typeof zStringFieldValue>; export type StringFieldValue = z.infer<typeof zStringFieldValue>;
@ -268,7 +265,6 @@ const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zBooleanFieldType, type: zBooleanFieldType,
originalType: zFieldType.optional(),
}); });
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>; export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>; export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
@ -294,7 +290,6 @@ const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zEnumFieldType, type: zEnumFieldType,
originalType: zFieldType.optional(),
}); });
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>; export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>; export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
@ -318,7 +313,6 @@ const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zImageFieldType, type: zImageFieldType,
originalType: zFieldType.optional(),
}); });
export type ImageFieldValue = z.infer<typeof zImageFieldValue>; export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>; export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
@ -342,7 +336,6 @@ const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zBoardFieldType, type: zBoardFieldType,
originalType: zFieldType.optional(),
}); });
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>; export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>; export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
@ -366,7 +359,6 @@ const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zColorFieldType, type: zColorFieldType,
originalType: zFieldType.optional(),
}); });
export type ColorFieldValue = z.infer<typeof zColorFieldValue>; export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>; export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
@ -390,7 +382,6 @@ const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zMainModelFieldType, type: zMainModelFieldType,
originalType: zFieldType.optional(),
}); });
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>; export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>; export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
@ -413,7 +404,6 @@ const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zModelIdentifierFieldType, type: zModelIdentifierFieldType,
originalType: zFieldType.optional(),
}); });
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>; export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>; export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
@ -437,7 +427,6 @@ const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSDXLMainModelFieldType, type: zSDXLMainModelFieldType,
originalType: zFieldType.optional(),
}); });
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>; export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>; export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
@ -461,7 +450,6 @@ const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSDXLRefinerModelFieldType, type: zSDXLRefinerModelFieldType,
originalType: zFieldType.optional(),
}); });
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>; export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>; export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
@ -485,7 +473,6 @@ const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zVAEModelFieldType, type: zVAEModelFieldType,
originalType: zFieldType.optional(),
}); });
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>; export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>; export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
@ -509,7 +496,6 @@ const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zLoRAModelFieldType, type: zLoRAModelFieldType,
originalType: zFieldType.optional(),
}); });
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>; export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>; export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
@ -533,7 +519,6 @@ const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zControlNetModelFieldType, type: zControlNetModelFieldType,
originalType: zFieldType.optional(),
}); });
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>; export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>; export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
@ -557,7 +542,6 @@ const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIPAdapterModelFieldType, type: zIPAdapterModelFieldType,
originalType: zFieldType.optional(),
}); });
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>; export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>; export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
@ -581,7 +565,6 @@ const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zT2IAdapterModelFieldType, type: zT2IAdapterModelFieldType,
originalType: zFieldType.optional(),
}); });
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>; export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>; export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
@ -605,7 +588,6 @@ const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSchedulerFieldType, type: zSchedulerFieldType,
originalType: zFieldType.optional(),
}); });
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>; export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>; export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
@ -641,7 +623,6 @@ const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
}); });
const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStatelessFieldType, type: zStatelessFieldType,
originalType: zFieldType.optional(),
}); });
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>; export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;

View File

@ -1,942 +1,19 @@
import { schema, templates } from 'features/nodes/store/util/testUtils';
import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { omit, pick } from 'lodash-es'; import { omit, pick } from 'lodash-es';
import type { OpenAPIV3_1 } from 'openapi-types';
import { describe, expect, it } from 'vitest'; import { describe, expect, it } from 'vitest';
describe('parseSchema', () => { describe('parseSchema', () => {
it('should parse the schema', () => { it('should parse the schema', () => {
const templates = parseSchema(schema); const parsed = parseSchema(schema);
expect(templates).toEqual(expected); expect(parsed).toEqual(templates);
}); });
it('should omit denied nodes', () => { it('should omit denied nodes', () => {
const templates = parseSchema(schema, undefined, ['add']); const parsed = parseSchema(schema, undefined, ['add']);
expect(templates).toEqual(omit(expected, 'add')); expect(parsed).toEqual(omit(templates, 'add'));
}); });
it('should include only allowed nodes', () => { it('should include only allowed nodes', () => {
const templates = parseSchema(schema, ['add']); const parsed = parseSchema(schema, ['add']);
expect(templates).toEqual(pick(expected, 'add')); expect(parsed).toEqual(pick(templates, 'add'));
}); });
}); });
const expected = {
add: {
title: 'Add Integers',
type: 'add',
version: '1.0.1',
tags: ['math', 'add'],
description: 'Adds two numbers',
outputType: 'integer_output',
inputs: {
a: {
name: 'a',
title: 'A',
required: false,
description: 'The first number',
fieldKind: 'input',
input: 'any',
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
default: 0,
},
b: {
name: 'b',
title: 'B',
required: false,
description: 'The second number',
fieldKind: 'input',
input: 'any',
ui_hidden: false,
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
default: 0,
},
},
outputs: {
value: {
fieldKind: 'output',
name: 'value',
title: 'Value',
description: 'The output integer',
type: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
ui_hidden: false,
},
},
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
},
scheduler: {
title: 'Scheduler',
type: 'scheduler',
version: '1.0.0',
tags: ['scheduler'],
description: 'Selects a scheduler.',
outputType: 'scheduler_output',
inputs: {
scheduler: {
name: 'scheduler',
title: 'Scheduler',
required: false,
description: 'Scheduler to use during inference',
fieldKind: 'input',
input: 'any',
ui_hidden: false,
ui_type: 'SchedulerField',
type: {
name: 'SchedulerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
},
},
default: 'euler',
},
},
outputs: {
scheduler: {
fieldKind: 'output',
name: 'scheduler',
title: 'Scheduler',
description: 'Scheduler to use during inference',
type: {
name: 'SchedulerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'EnumField',
isCollection: false,
isCollectionOrScalar: false,
},
},
ui_hidden: false,
ui_type: 'SchedulerField',
},
},
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
},
main_model_loader: {
title: 'Main Model',
type: 'main_model_loader',
version: '1.0.2',
tags: ['model'],
description: 'Loads a main model, outputting its submodels.',
outputType: 'model_loader_output',
inputs: {
model: {
name: 'model',
title: 'Model',
required: true,
description: 'Main model (UNet, VAE, CLIP) to load',
fieldKind: 'input',
input: 'direct',
ui_hidden: false,
ui_type: 'MainModelField',
type: {
name: 'MainModelField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'ModelIdentifierField',
isCollection: false,
isCollectionOrScalar: false,
},
},
},
},
outputs: {
vae: {
fieldKind: 'output',
name: 'vae',
title: 'VAE',
description: 'VAE',
type: {
name: 'VAEField',
isCollection: false,
isCollectionOrScalar: false,
},
ui_hidden: false,
},
clip: {
fieldKind: 'output',
name: 'clip',
title: 'CLIP',
description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count',
type: {
name: 'CLIPField',
isCollection: false,
isCollectionOrScalar: false,
},
ui_hidden: false,
},
unet: {
fieldKind: 'output',
name: 'unet',
title: 'UNet',
description: 'UNet (scheduler, LoRAs)',
type: {
name: 'UNetField',
isCollection: false,
isCollectionOrScalar: false,
},
ui_hidden: false,
},
},
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
},
collect: {
title: 'Collect',
type: 'collect',
version: '1.0.0',
tags: [],
description: 'Collects values into a collection',
outputType: 'collect_output',
inputs: {
item: {
name: 'item',
title: 'Collection Item',
required: false,
description: 'The item to collect (all inputs must be of the same type)',
fieldKind: 'input',
input: 'connection',
ui_hidden: false,
ui_type: 'CollectionItemField',
type: {
name: 'CollectionItemField',
isCollection: false,
isCollectionOrScalar: false,
},
},
},
outputs: {
collection: {
fieldKind: 'output',
name: 'collection',
title: 'Collection',
description: 'The collection of input items',
type: {
name: 'CollectionField',
isCollection: true,
isCollectionOrScalar: false,
},
ui_hidden: false,
ui_type: 'CollectionField',
},
},
useCache: true,
classification: 'stable',
},
};
const schema = {
openapi: '3.1.0',
info: {
title: 'Invoke - Community Edition',
description: 'An API for invoking AI image operations',
version: '1.0.0',
},
components: {
schemas: {
AddInvocation: {
properties: {
id: {
type: 'string',
title: 'Id',
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
field_kind: 'node_attribute',
},
is_intermediate: {
type: 'boolean',
title: 'Is Intermediate',
description: 'Whether or not this is an intermediate invocation.',
default: false,
field_kind: 'node_attribute',
ui_type: 'IsIntermediate',
},
use_cache: {
type: 'boolean',
title: 'Use Cache',
description: 'Whether or not to use the cache',
default: true,
field_kind: 'node_attribute',
},
a: {
type: 'integer',
title: 'A',
description: 'The first number',
default: 0,
field_kind: 'input',
input: 'any',
orig_default: 0,
orig_required: false,
ui_hidden: false,
},
b: {
type: 'integer',
title: 'B',
description: 'The second number',
default: 0,
field_kind: 'input',
input: 'any',
orig_default: 0,
orig_required: false,
ui_hidden: false,
},
type: {
type: 'string',
enum: ['add'],
const: 'add',
title: 'type',
default: 'add',
field_kind: 'node_attribute',
},
},
type: 'object',
required: ['type', 'id'],
title: 'Add Integers',
description: 'Adds two numbers',
category: 'math',
classification: 'stable',
node_pack: 'invokeai',
tags: ['math', 'add'],
version: '1.0.1',
output: {
$ref: '#/components/schemas/IntegerOutput',
},
class: 'invocation',
},
IntegerOutput: {
description: 'Base class for nodes that output a single integer',
properties: {
value: {
description: 'The output integer',
field_kind: 'output',
title: 'Value',
type: 'integer',
ui_hidden: false,
},
type: {
const: 'integer_output',
default: 'integer_output',
enum: ['integer_output'],
field_kind: 'node_attribute',
title: 'type',
type: 'string',
},
},
required: ['value', 'type', 'type'],
title: 'IntegerOutput',
type: 'object',
class: 'output',
},
SchedulerInvocation: {
properties: {
id: {
type: 'string',
title: 'Id',
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
field_kind: 'node_attribute',
},
is_intermediate: {
type: 'boolean',
title: 'Is Intermediate',
description: 'Whether or not this is an intermediate invocation.',
default: false,
field_kind: 'node_attribute',
ui_type: 'IsIntermediate',
},
use_cache: {
type: 'boolean',
title: 'Use Cache',
description: 'Whether or not to use the cache',
default: true,
field_kind: 'node_attribute',
},
scheduler: {
type: 'string',
enum: [
'ddim',
'ddpm',
'deis',
'lms',
'lms_k',
'pndm',
'heun',
'heun_k',
'euler',
'euler_k',
'euler_a',
'kdpm_2',
'kdpm_2_a',
'dpmpp_2s',
'dpmpp_2s_k',
'dpmpp_2m',
'dpmpp_2m_k',
'dpmpp_2m_sde',
'dpmpp_2m_sde_k',
'dpmpp_sde',
'dpmpp_sde_k',
'unipc',
'lcm',
'tcd',
],
title: 'Scheduler',
description: 'Scheduler to use during inference',
default: 'euler',
field_kind: 'input',
input: 'any',
orig_default: 'euler',
orig_required: false,
ui_hidden: false,
ui_type: 'SchedulerField',
},
type: {
type: 'string',
enum: ['scheduler'],
const: 'scheduler',
title: 'type',
default: 'scheduler',
field_kind: 'node_attribute',
},
},
type: 'object',
required: ['type', 'id'],
title: 'Scheduler',
description: 'Selects a scheduler.',
category: 'latents',
classification: 'stable',
node_pack: 'invokeai',
tags: ['scheduler'],
version: '1.0.0',
output: {
$ref: '#/components/schemas/SchedulerOutput',
},
class: 'invocation',
},
SchedulerOutput: {
properties: {
scheduler: {
description: 'Scheduler to use during inference',
enum: [
'ddim',
'ddpm',
'deis',
'lms',
'lms_k',
'pndm',
'heun',
'heun_k',
'euler',
'euler_k',
'euler_a',
'kdpm_2',
'kdpm_2_a',
'dpmpp_2s',
'dpmpp_2s_k',
'dpmpp_2m',
'dpmpp_2m_k',
'dpmpp_2m_sde',
'dpmpp_2m_sde_k',
'dpmpp_sde',
'dpmpp_sde_k',
'unipc',
'lcm',
'tcd',
],
field_kind: 'output',
title: 'Scheduler',
type: 'string',
ui_hidden: false,
ui_type: 'SchedulerField',
},
type: {
const: 'scheduler_output',
default: 'scheduler_output',
enum: ['scheduler_output'],
field_kind: 'node_attribute',
title: 'type',
type: 'string',
},
},
required: ['scheduler', 'type', 'type'],
title: 'SchedulerOutput',
type: 'object',
class: 'output',
},
MainModelLoaderInvocation: {
properties: {
id: {
type: 'string',
title: 'Id',
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
field_kind: 'node_attribute',
},
is_intermediate: {
type: 'boolean',
title: 'Is Intermediate',
description: 'Whether or not this is an intermediate invocation.',
default: false,
field_kind: 'node_attribute',
ui_type: 'IsIntermediate',
},
use_cache: {
type: 'boolean',
title: 'Use Cache',
description: 'Whether or not to use the cache',
default: true,
field_kind: 'node_attribute',
},
model: {
allOf: [
{
$ref: '#/components/schemas/ModelIdentifierField',
},
],
description: 'Main model (UNet, VAE, CLIP) to load',
field_kind: 'input',
input: 'direct',
orig_required: true,
ui_hidden: false,
ui_type: 'MainModelField',
},
type: {
type: 'string',
enum: ['main_model_loader'],
const: 'main_model_loader',
title: 'type',
default: 'main_model_loader',
field_kind: 'node_attribute',
},
},
type: 'object',
required: ['model', 'type', 'id'],
title: 'Main Model',
description: 'Loads a main model, outputting its submodels.',
category: 'model',
classification: 'stable',
node_pack: 'invokeai',
tags: ['model'],
version: '1.0.2',
output: {
$ref: '#/components/schemas/ModelLoaderOutput',
},
class: 'invocation',
},
ModelIdentifierField: {
properties: {
key: {
description: "The model's unique key",
title: 'Key',
type: 'string',
},
hash: {
description: "The model's BLAKE3 hash",
title: 'Hash',
type: 'string',
},
name: {
description: "The model's name",
title: 'Name',
type: 'string',
},
base: {
allOf: [
{
$ref: '#/components/schemas/BaseModelType',
},
],
description: "The model's base model type",
},
type: {
allOf: [
{
$ref: '#/components/schemas/ModelType',
},
],
description: "The model's type",
},
submodel_type: {
anyOf: [
{
$ref: '#/components/schemas/SubModelType',
},
{
type: 'null',
},
],
default: null,
description: 'The submodel to load, if this is a main model',
},
},
required: ['key', 'hash', 'name', 'base', 'type'],
title: 'ModelIdentifierField',
type: 'object',
},
BaseModelType: {
description: 'Base model type.',
enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
title: 'BaseModelType',
type: 'string',
},
ModelType: {
description: 'Model type.',
enum: ['onnx', 'main', 'vae', 'lora', 'controlnet', 'embedding', 'ip_adapter', 'clip_vision', 't2i_adapter'],
title: 'ModelType',
type: 'string',
},
SubModelType: {
description: 'Submodel type.',
enum: [
'unet',
'text_encoder',
'text_encoder_2',
'tokenizer',
'tokenizer_2',
'vae',
'vae_decoder',
'vae_encoder',
'scheduler',
'safety_checker',
],
title: 'SubModelType',
type: 'string',
},
ModelLoaderOutput: {
description: 'Model loader output',
properties: {
vae: {
allOf: [
{
$ref: '#/components/schemas/VAEField',
},
],
description: 'VAE',
field_kind: 'output',
title: 'VAE',
ui_hidden: false,
},
type: {
const: 'model_loader_output',
default: 'model_loader_output',
enum: ['model_loader_output'],
field_kind: 'node_attribute',
title: 'type',
type: 'string',
},
clip: {
allOf: [
{
$ref: '#/components/schemas/CLIPField',
},
],
description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count',
field_kind: 'output',
title: 'CLIP',
ui_hidden: false,
},
unet: {
allOf: [
{
$ref: '#/components/schemas/UNetField',
},
],
description: 'UNet (scheduler, LoRAs)',
field_kind: 'output',
title: 'UNet',
ui_hidden: false,
},
},
required: ['vae', 'type', 'clip', 'unet', 'type'],
title: 'ModelLoaderOutput',
type: 'object',
class: 'output',
},
UNetField: {
properties: {
unet: {
allOf: [
{
$ref: '#/components/schemas/ModelIdentifierField',
},
],
description: 'Info to load unet submodel',
},
scheduler: {
allOf: [
{
$ref: '#/components/schemas/ModelIdentifierField',
},
],
description: 'Info to load scheduler submodel',
},
loras: {
description: 'LoRAs to apply on model loading',
items: {
$ref: '#/components/schemas/LoRAField',
},
title: 'Loras',
type: 'array',
},
seamless_axes: {
description: 'Axes("x" and "y") to which apply seamless',
items: {
type: 'string',
},
title: 'Seamless Axes',
type: 'array',
},
freeu_config: {
anyOf: [
{
$ref: '#/components/schemas/FreeUConfig',
},
{
type: 'null',
},
],
default: null,
description: 'FreeU configuration',
},
},
required: ['unet', 'scheduler', 'loras'],
title: 'UNetField',
type: 'object',
class: 'output',
},
LoRAField: {
properties: {
lora: {
allOf: [
{
$ref: '#/components/schemas/ModelIdentifierField',
},
],
description: 'Info to load lora model',
},
weight: {
description: 'Weight to apply to lora model',
title: 'Weight',
type: 'number',
},
},
required: ['lora', 'weight'],
title: 'LoRAField',
type: 'object',
class: 'output',
},
FreeUConfig: {
description:
'Configuration for the FreeU hyperparameters.\n- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu\n- https://github.com/ChenyangSi/FreeU',
properties: {
s1: {
description:
'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.',
maximum: 3.0,
minimum: -1.0,
title: 'S1',
type: 'number',
},
s2: {
description:
'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.',
maximum: 3.0,
minimum: -1.0,
title: 'S2',
type: 'number',
},
b1: {
description: 'Scaling factor for stage 1 to amplify the contributions of backbone features.',
maximum: 3.0,
minimum: -1.0,
title: 'B1',
type: 'number',
},
b2: {
description: 'Scaling factor for stage 2 to amplify the contributions of backbone features.',
maximum: 3.0,
minimum: -1.0,
title: 'B2',
type: 'number',
},
},
required: ['s1', 's2', 'b1', 'b2'],
title: 'FreeUConfig',
type: 'object',
class: 'output',
},
VAEField: {
properties: {
vae: {
allOf: [
{
$ref: '#/components/schemas/ModelIdentifierField',
},
],
description: 'Info to load vae submodel',
},
seamless_axes: {
description: 'Axes("x" and "y") to which apply seamless',
items: {
type: 'string',
},
title: 'Seamless Axes',
type: 'array',
},
},
required: ['vae'],
title: 'VAEField',
type: 'object',
class: 'output',
},
CLIPField: {
properties: {
tokenizer: {
allOf: [
{
$ref: '#/components/schemas/ModelIdentifierField',
},
],
description: 'Info to load tokenizer submodel',
},
text_encoder: {
allOf: [
{
$ref: '#/components/schemas/ModelIdentifierField',
},
],
description: 'Info to load text_encoder submodel',
},
skipped_layers: {
description: 'Number of skipped layers in text_encoder',
title: 'Skipped Layers',
type: 'integer',
},
loras: {
description: 'LoRAs to apply on model loading',
items: {
$ref: '#/components/schemas/LoRAField',
},
title: 'Loras',
type: 'array',
},
},
required: ['tokenizer', 'text_encoder', 'skipped_layers', 'loras'],
title: 'CLIPField',
type: 'object',
class: 'output',
},
CollectInvocation: {
properties: {
id: {
type: 'string',
title: 'Id',
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
field_kind: 'node_attribute',
},
is_intermediate: {
type: 'boolean',
title: 'Is Intermediate',
description: 'Whether or not this is an intermediate invocation.',
default: false,
field_kind: 'node_attribute',
ui_type: 'IsIntermediate',
},
use_cache: {
type: 'boolean',
title: 'Use Cache',
description: 'Whether or not to use the cache',
default: true,
field_kind: 'node_attribute',
},
item: {
anyOf: [
{},
{
type: 'null',
},
],
title: 'Collection Item',
description: 'The item to collect (all inputs must be of the same type)',
field_kind: 'input',
input: 'connection',
orig_required: false,
ui_hidden: false,
ui_type: 'CollectionItemField',
},
collection: {
items: {},
type: 'array',
title: 'Collection',
description: 'The collection, will be provided on execution',
default: [],
field_kind: 'input',
input: 'any',
orig_default: [],
orig_required: false,
ui_hidden: true,
},
type: {
type: 'string',
enum: ['collect'],
const: 'collect',
title: 'type',
default: 'collect',
field_kind: 'node_attribute',
},
},
type: 'object',
required: ['type', 'id'],
title: 'CollectInvocation',
description: 'Collects values into a collection',
classification: 'stable',
version: '1.0.0',
output: {
$ref: '#/components/schemas/CollectInvocationOutput',
},
class: 'invocation',
},
CollectInvocationOutput: {
properties: {
collection: {
description: 'The collection of input items',
field_kind: 'output',
items: {},
title: 'Collection',
type: 'array',
ui_hidden: false,
ui_type: 'CollectionField',
},
type: {
const: 'collect_output',
default: 'collect_output',
enum: ['collect_output'],
field_kind: 'node_attribute',
title: 'type',
type: 'string',
},
},
required: ['collection', 'type', 'type'],
title: 'CollectInvocationOutput',
type: 'object',
class: 'output',
},
},
},
} as OpenAPIV3_1.Document;