Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
tests(ui): add tests for consolidated connection validation
parent 6f7160b9fd
commit 3fcb2720d7
@@ -775,6 +775,9 @@
"cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"missingNode": "Missing invocation node",
"missingInvocationTemplate": "Missing invocation template",
"missingFieldTemplate": "Missing field template",
"nodePack": "Node pack",
"collection": "Collection",
"collectionFieldType": "{{name}} Collection",
@@ -17,7 +17,8 @@ import {
nodeAdded,
openAddNodePopover,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection, validateSourceAndTargetTypes } from 'features/nodes/store/util/connectionValidation';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { filter, map, memoize, some } from 'lodash-es';
@@ -77,7 +78,7 @@ const AddNodePopover = () => {
return some(fields, (field) => {
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
return validateSourceAndTargetTypes(sourceType, targetType);
return validateConnectionTypes(sourceType, targetType);
});
});
}, [templates, pendingConnection]);
@@ -8,7 +8,7 @@ import {
$templates,
connectionMade,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/connectionValidation';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { isString } from 'lodash-es';
import { useCallback, useMemo } from 'react';
@@ -2,12 +2,10 @@
import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import {
areTypesEqual,
getCollectItemType,
getHasCycles,
validateSourceAndTargetTypes,
} from 'features/nodes/store/util/connectionValidation';
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { useCallback } from 'react';
import type { Connection, Node } from 'reactflow';
@@ -88,7 +86,7 @@ export const useIsValidConnection = () => {
}

// Must use the originalType here if it exists
if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
return false;
}
@@ -0,0 +1,101 @@
import { describe, expect, it } from 'vitest';

import { areTypesEqual } from './areTypesEqual';

describe(areTypesEqual.name, () => {
it('should handle equal source and target type', () => {
const sourceType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
},
};
const targetType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});

it('should handle equal source type and original target type', () => {
const sourceType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
},
};
const targetType = {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});

it('should handle equal original source type and target type', () => {
const sourceType = {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
};
const targetType = {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});

it('should handle equal original source type and original target type', () => {
const sourceType = {
name: 'Foo',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
};
const targetType = {
name: 'Bar',
isCollection: false,
isCollectionOrScalar: false,
originalType: {
name: 'IntegerField',
isCollection: false,
isCollectionOrScalar: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
});
@@ -0,0 +1,30 @@
import type { FieldType } from 'features/nodes/types/field';
import { isEqual, omit } from 'lodash-es';

/**
 * Checks if two types are equal. If the field types have original types, those are also compared. Any match is
 * considered equal. For example, if the source type and original target type match, the types are considered equal.
 * @param sourceType The type of the source field.
 * @param targetType The type of the target field.
 * @returns True if the types are equal, false otherwise.
 */
export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => {
const _sourceType = 'originalType' in sourceType ? omit(sourceType, 'originalType') : sourceType;
const _targetType = 'originalType' in targetType ? omit(targetType, 'originalType') : targetType;
const _sourceTypeOriginal = 'originalType' in sourceType ? sourceType.originalType : null;
const _targetTypeOriginal = 'originalType' in targetType ? targetType.originalType : null;
if (isEqual(_sourceType, _targetType)) {
return true;
}
if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) {
return true;
}
if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) {
return true;
}
if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) {
return true;
}
return false;
};
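For context (not part of the diff): a minimal sketch of how the originalType fallback behaves, using illustrative field shapes in the style of the test fixtures above.

import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';

// Hypothetical shapes: the source keeps a record of its original type after being narrowed.
const sourceType = {
  name: 'Foo',
  isCollection: true,
  isCollectionOrScalar: false,
  originalType: { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
};
const targetType = { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false };

// Direct comparison fails on the names, but the source's originalType matches the target.
areTypesEqual(sourceType, targetType); // => true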
@@ -1,179 +1,16 @@
import graphlib from '@dagrejs/graphlib';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { RootState } from 'app/store/store';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { FieldType } from 'features/nodes/types/field';
import i18n from 'i18next';
import { differenceWith, isEqual, map, omit } from 'lodash-es';
import type { Connection, Edge, HandleType, Node } from 'reactflow';
import type { HandleType } from 'reactflow';
import { assert } from 'tsafe';

/**
 * Finds the first valid field for a pending connection between two nodes.
 * @param templates The invocation templates
 * @param nodes The current nodes
 * @param edges The current edges
 * @param pendingConnection The pending connection
 * @param candidateNode The candidate node to which the connection is being made
 * @param candidateTemplate The candidate template for the candidate node
 * @returns The first valid connection, or null if no valid connection is found
 */
export const getFirstValidConnection = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate
): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
return null;
}

const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';

if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs);
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
);
const candidateField = candidateUnconnectedFields.find((field) =>
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null;
}

if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return null;
}

// Sources/outputs can have any number of edges, we can take the first matching output field
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
}

return null;
};

/**
 * Check if adding an edge between the source and target nodes would create a cycle in the graph.
 * @param source The source node id
 * @param target The target node id
 * @param nodes The graph's current nodes
 * @param edges The graph's current edges
 * @returns True if the graph would be acyclic after adding the edge, false otherwise
 */
export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();

nodes.forEach((n) => {
g.setNode(n.id);
});

edges.forEach((e) => {
g.setEdge(e.source, e.target);
});

// add the candidate edge
g.setEdge(source, target);

// check if the graph is acyclic
return !graphlib.alg.isAcyclic(g);
};

/**
 * Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
 * field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
 * input field.
 * @param templates The current invocation templates
 * @param nodes The current nodes
 * @param edges The current edges
 * @param nodeId The collect node's id
 * @returns The type of the items the collect node collects, or null if there is no input field
 */
export const getCollectItemType = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
return fieldType;
};
import { areTypesEqual } from './areTypesEqual';
import { getCollectItemType } from './getCollectItemType';
import { getHasCycles } from './getHasCycles';

/**
 * Creates a selector that validates a pending connection.
@@ -276,7 +113,7 @@ export const makeConnectionErrorSelector = (
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
}

if (!validateSourceAndTargetTypes(sourceType, targetType)) {
if (!validateConnectionTypes(sourceType, targetType)) {
return i18n.t('nodes.fieldTypesMustMatch');
}
@@ -295,97 +132,3 @@ export const makeConnectionErrorSelector = (
}
);
};

/**
 * Validates that the source and target types are compatible for a connection.
 * @param sourceType The type of the source field.
 * @param targetType The type of the target field.
 * @returns True if the connection is valid, false otherwise.
 */
export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
return false;
}

if (areTypesEqual(sourceType, targetType)) {
return true;
}

/**
 * Connection types must be the same for a connection, with exceptions:
 * - CollectionItem can connect to any non-Collection
 * - Non-Collections can connect to CollectionItem
 * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type
 * - Generic Collection can connect to any other Collection or CollectionOrScalar
 * - Any Collection can connect to a Generic Collection
 */
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection;

const isNonCollectionToCollectionItem =
targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar;

const isAnythingToCollectionOrScalarOfSameBaseType =
targetType.isCollectionOrScalar && sourceType.name === targetType.name;

const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar);

const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection;

const areBothTypesSingle =
!sourceType.isCollection &&
!sourceType.isCollectionOrScalar &&
!targetType.isCollection &&
!targetType.isCollectionOrScalar;

const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField';

const isIntOrFloatToString =
areBothTypesSingle &&
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
targetType.name === 'StringField';

const isTargetAnyType = targetType.name === 'AnyField';

// One of these must be true for the connection to be valid
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToCollectionOrScalarOfSameBaseType ||
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
isCollectionToGenericCollection ||
isIntToFloat ||
isIntOrFloatToString ||
isTargetAnyType
);
};

/**
 * Checks if two types are equal. If the field types have original types, those are also compared. Any match is
 * considered equal. For example, if the source type and original target type match, the types are considered equal.
 * @param sourceType The type of the source field.
 * @param targetType The type of the target field.
 * @returns True if the types are equal, false otherwise.
 */
export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => {
const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType;
const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType;
const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType;
const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType;
if (isEqual(_sourceType, _targetType)) {
return true;
}
if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) {
return true;
}
if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) {
return true;
}
if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) {
return true;
}
return false;
};
@@ -0,0 +1,16 @@
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { add, buildEdge, collect, position, templates } from 'features/nodes/store/util/testUtils';
import type { FieldType } from 'features/nodes/types/field';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { describe, expect, it } from 'vitest';

describe(getCollectItemType.name, () => {
it('should return the type of the items the collect node collects', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, collect);
const nodes = [n1, n2];
const edges = [buildEdge(n1.id, 'value', n2.id, 'item')];
const result = getCollectItemType(templates, nodes, edges, n2.id);
expect(result).toEqual<FieldType>({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false });
});
});
@@ -0,0 +1,35 @@
import type { Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';

/**
 * Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
 * field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
 * input field.
 * @param templates The current invocation templates
 * @param nodes The current nodes
 * @param edges The current edges
 * @param nodeId The collect node's id
 * @returns The type of the items the collect node collects, or null if there is no input field
 */
export const getCollectItemType = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
return fieldType;
};
@@ -0,0 +1,116 @@
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { differenceWith, map } from 'lodash-es';
import type { Connection } from 'reactflow';
import { assert } from 'tsafe';

import { areTypesEqual } from './areTypesEqual';
import { getCollectItemType } from './getCollectItemType';
import { getHasCycles } from './getHasCycles';

/**
 * Finds the first valid field for a pending connection between two nodes.
 * @param templates The invocation templates
 * @param nodes The current nodes
 * @param edges The current edges
 * @param pendingConnection The pending connection
 * @param candidateNode The candidate node to which the connection is being made
 * @param candidateTemplate The candidate template for the candidate node
 * @returns The first valid connection, or null if no valid connection is found
 */
export const getFirstValidConnection = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate
): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
return null;
}

const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';

if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs);
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
);
const candidateField = candidateUnconnectedFields.find((field) =>
validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null;
}

if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return null;
}

// Sources/outputs can have any number of edges, we can take the first matching output field
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to the same field type as is already connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
}

return null;
};
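A rough usage sketch (not from the diff) of how a caller such as AddNodePopover might consume this helper once a candidate node has been chosen; templates, nodes, edges, pendingConnection, addedNode, and addedTemplate are assumed to be in scope.

import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';

const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, addedNode, addedTemplate);
if (connection) {
  // e.g. dispatch(connectionMade(connection)) to commit the new edge to the store
}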
@@ -0,0 +1,23 @@
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { add, buildEdge, position } from 'features/nodes/store/util/testUtils';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { describe, expect, it } from 'vitest';

describe(getHasCycles.name, () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, add);
const n3 = buildInvocationNode(position, add);
const nodes = [n1, n2, n3];

it('should return true if the graph WOULD have cycles after adding the edge', () => {
const edges = [buildEdge(n1.id, 'value', n2.id, 'a'), buildEdge(n2.id, 'value', n3.id, 'a')];
const result = getHasCycles(n3.id, n1.id, nodes, edges);
expect(result).toBe(true);
});

it('should return false if the graph WOULD NOT have cycles after adding the edge', () => {
const edges = [buildEdge(n1.id, 'value', n2.id, 'a')];
const result = getHasCycles(n2.id, n3.id, nodes, edges);
expect(result).toBe(false);
});
});
@@ -0,0 +1,30 @@
import graphlib from '@dagrejs/graphlib';
import type { Edge, Node } from 'reactflow';

/**
 * Check if adding an edge between the source and target nodes would create a cycle in the graph.
 * @param source The source node id
 * @param target The target node id
 * @param nodes The graph's current nodes
 * @param edges The graph's current edges
 * @returns True if the graph would have a cycle after adding the edge, false otherwise
 */
export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();

nodes.forEach((n) => {
g.setNode(n.id);
});

edges.forEach((e) => {
g.setEdge(e.source, e.target);
});

// add the candidate edge
g.setEdge(source, target);

// check if the graph is acyclic
return !graphlib.alg.isAcyclic(g);
};
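A standalone illustration of the same graphlib approach (the node ids are made up): build the graph, add the candidate edge, and invert isAcyclic.

import graphlib from '@dagrejs/graphlib';

const g = new graphlib.Graph();
['a', 'b', 'c'].forEach((id) => g.setNode(id));
g.setEdge('a', 'b');
g.setEdge('b', 'c');
g.setEdge('c', 'a'); // the candidate edge closes the loop a -> b -> c -> a

graphlib.alg.isAcyclic(g); // false, so getHasCycles('c', 'a', ...) would return true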
invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts (new file, 1073 lines)
File diff suppressed because it is too large
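The suppressed file supplies the fixtures used by the tests in this commit (position, templates, and invocation templates such as add, sub, collect, and main_model_loader) plus a buildEdge helper. A hypothetical sketch of what such a helper could look like, since the actual implementation is not shown here:

import type { Edge } from 'reactflow';

// Hypothetical: build a reactflow edge from a source field to a target field.
export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({
  id: `${source}.${sourceHandle}-${target}.${targetHandle}`,
  source,
  sourceHandle,
  target,
  targetHandle,
});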
@@ -0,0 +1,149 @@
import { deepClone } from 'common/util/deepClone';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { set } from 'lodash-es';
import { describe, expect, it } from 'vitest';

import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils';
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';

describe(validateConnection.name, () => {
it('should reject invalid connection to self', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
const r = validateConnection(c, [], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
});

describe('missing nodes', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };

it('should reject missing source node', () => {
const r = validateConnection(c, [n2], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
});

it('should reject missing target node', () => {
const r = validateConnection(c, [n1], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
});
});

describe('missing invocation templates', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const nodes = [n1, n2];

it('should reject missing source template', () => {
const r = validateConnection(c, nodes, [], { sub }, null);
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
});

it('should reject missing target template', () => {
const r = validateConnection(c, nodes, [], { add }, null);
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
});
});

describe('missing field templates', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
const nodes = [n1, n2];

it('should reject missing source field template', () => {
const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
});

it('should reject missing target field template', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
});
});

describe('duplicate connections', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, sub);
it('should accept non-duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], templates, null);
expect(r).toEqual(buildAcceptResult());
});
it('should reject duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection'));
});
it('should accept duplicate connections if the duplicate is an ignored edge', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, e);
expect(r).toEqual(buildAcceptResult());
});
});

it('should reject connection to direct input', () => {
// Create cloned add template w/ a direct input
const addWithDirectAField = deepClone(add);
set(addWithDirectAField, 'inputs.a.input', 'direct');
set(addWithDirectAField, 'type', 'addWithDirectAField');

const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, addWithDirectAField);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
});

it('should reject connection to a collect node with mismatched item types', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, collect);
const n3 = buildInvocationNode(position, main_model_loader);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'));
});

it('should accept connection to a collect node with matching item types', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, collect);
const n3 = buildInvocationNode(position, sub);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildAcceptResult());
});

it('should reject connections to target field that is already connected', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, add);
const n3 = buildInvocationNode(position, add);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection'));
});

it('should accept connections to target field that is already connected (ignored edge)', () => {
const n1 = buildInvocationNode(position, add);
const n2 = buildInvocationNode(position, add);
const n3 = buildInvocationNode(position, add);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, e1);
expect(r).toEqual(buildAcceptResult());
});
});
@@ -0,0 +1,109 @@
import type { Templates } from 'features/nodes/store/types';
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import type { AnyNode } from 'features/nodes/types/invocation';
import type { Connection as NullableConnection, Edge } from 'reactflow';
import type { O } from 'ts-toolbelt';

type Connection = O.NonNullable<NullableConnection>;

export type ValidateConnectionResult = {
isValid: boolean;
messageTKey?: string;
};

export type ValidateConnectionFunc = (
connection: Connection,
nodes: AnyNode[],
edges: Edge[],
templates: Templates,
ignoreEdge: Edge | null
) => ValidateConnectionResult;

export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => ({
isValid,
messageTKey,
});

const getEqualityPredicate =
(c: Connection) =>
(e: Edge): boolean => {
return (
e.target === c.target &&
e.targetHandle === c.targetHandle &&
e.source === c.source &&
e.sourceHandle === c.sourceHandle
);
};

export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });

export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => {
if (c.source === c.target) {
return buildRejectResult('nodes.cannotConnectToSelf');
}

const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);

if (filteredEdges.some(getEqualityPredicate(c))) {
// We already have a connection from this source to this target
return buildRejectResult('nodes.cannotDuplicateConnection');
}

const sourceNode = nodes.find((n) => n.id === c.source);
if (!sourceNode) {
return buildRejectResult('nodes.missingNode');
}

const targetNode = nodes.find((n) => n.id === c.target);
if (!targetNode) {
return buildRejectResult('nodes.missingNode');
}

const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}

const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}

const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}

const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}

if (targetFieldTemplate.input === 'direct') {
return buildRejectResult('nodes.cannotConnectToDirectInput');
}

if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
if (!areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
}
}
}

if (
edges.find((e) => {
return e.target === c.target && e.targetHandle === c.targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetFieldTemplate.type.name !== 'CollectionItemField'
) {
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
}

return buildAcceptResult();
};
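A hedged sketch (not from the diff) of how a caller might surface the result; connection, nodes, edges, and templates are assumed to be in scope.

import i18n from 'i18next';
import { validateConnection } from 'features/nodes/store/util/validateConnection';

const result = validateConnection(connection, nodes, edges, templates, null);
if (!result.isValid && result.messageTKey) {
  // messageTKey is a translation key such as 'nodes.cannotConnectToSelf'
  console.warn(i18n.t(result.messageTKey));
}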
@@ -0,0 +1,222 @@
import { describe, expect, it } from 'vitest';

import { validateConnectionTypes } from './validateConnectionTypes';

describe(validateConnectionTypes.name, () => {
describe('generic cases', () => {
it('should accept Scalar to Scalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept Collection to Collection of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept Scalar to CollectionOrScalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should accept Collection to CollectionOrScalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should reject Collection to Scalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
it('should reject CollectionOrScalar to Scalar of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true },
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
it('should reject mismatched types', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'BarField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
});

describe('special cases', () => {
it('should reject a collection input to a collection input', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', isCollection: true, isCollectionOrScalar: false },
{ name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});

it('should accept equal types', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});

describe('CollectionItemField', () => {
it('should accept CollectionItemField to any Scalar target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept CollectionItemField to any CollectionOrScalar target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should accept any non-Collection to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should reject any Collection to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
it('should reject any CollectionOrScalar to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(false);
});
});

describe('CollectionOrScalar', () => {
it('should accept any Scalar of same type to CollectionOrScalar', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should accept any Collection of same type to CollectionOrScalar', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it('should accept any CollectionOrScalar of same type to CollectionOrScalar', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
});

describe('CollectionField', () => {
it('should accept any CollectionField to any Collection type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept any CollectionField to any CollectionOrScalar type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', isCollection: false, isCollectionOrScalar: false },
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
});

describe('subtype handling', () => {
type TypePair = { t1: string; t2: string };
const typePairs = [
{ t1: 'IntegerField', t2: 'FloatField' },
{ t1: 'IntegerField', t2: 'StringField' },
{ t1: 'FloatField', t2: 'StringField' },
];
it.each(typePairs)('should accept Scalar $t1 to Scalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: false, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept Scalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: false, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept Collection $t1 to Collection $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: true, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept Collection $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: true, isCollectionOrScalar: false },
{ name: t2, isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept CollectionOrScalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, isCollection: false, isCollectionOrScalar: true },
{ name: t2, isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
});

describe('AnyField', () => {
it('should accept any Scalar type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'AnyField', isCollection: false, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept any Collection type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'AnyField', isCollection: true, isCollectionOrScalar: false }
);
expect(r).toBe(true);
});
it('should accept any CollectionOrScalar type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
{ name: 'AnyField', isCollection: false, isCollectionOrScalar: true }
);
expect(r).toBe(true);
});
});
});
});
@@ -0,0 +1,69 @@
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import type { FieldType } from 'features/nodes/types/field';

/**
 * Validates that the source and target types are compatible for a connection.
 * @param sourceType The type of the source field.
 * @param targetType The type of the target field.
 * @returns True if the connection is valid, false otherwise.
 */
export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
return false;
}

if (areTypesEqual(sourceType, targetType)) {
return true;
}

/**
 * Connection types must be the same for a connection, with exceptions:
 * - CollectionItem can connect to any non-Collection
 * - Non-Collections can connect to CollectionItem
 * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type
 * - Generic Collection can connect to any other Collection or CollectionOrScalar
 * - Any Collection can connect to a Generic Collection
 */
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection;

const isNonCollectionToCollectionItem =
targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar;

const isAnythingToCollectionOrScalarOfSameBaseType =
targetType.isCollectionOrScalar && sourceType.name === targetType.name;

const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar);

const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection;

const areBothTypesSingle =
!sourceType.isCollection &&
!sourceType.isCollectionOrScalar &&
!targetType.isCollection &&
!targetType.isCollectionOrScalar;

const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField';

const isIntOrFloatToString =
areBothTypesSingle &&
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
targetType.name === 'StringField';

const isTargetAnyType = targetType.name === 'AnyField';

// One of these must be true for the connection to be valid
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToCollectionOrScalarOfSameBaseType ||
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
isCollectionToGenericCollection ||
isIntToFloat ||
isIntOrFloatToString ||
isTargetAnyType
);
};
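For example, the one-way integer-to-float promotion falls out of the isIntToFloat check (a usage sketch, not part of the diff):

import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';

const integer = { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false };
const float = { name: 'FloatField', isCollection: false, isCollectionOrScalar: false };

validateConnectionTypes(integer, float); // true - an integer output may feed a float input
validateConnectionTypes(float, integer); // false - the promotion is one-way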
@@ -188,7 +188,6 @@ const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIntegerFieldType,
originalType: zFieldType.optional(),
});
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
@@ -217,7 +216,6 @@ const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFloatFieldType,
originalType: zFieldType.optional(),
});
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
@@ -243,7 +241,6 @@ const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringFieldType,
originalType: zFieldType.optional(),
});

export type StringFieldValue = z.infer<typeof zStringFieldValue>;
@@ -268,7 +265,6 @@ const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zBooleanFieldType,
originalType: zFieldType.optional(),
});
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
@@ -294,7 +290,6 @@ const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zEnumFieldType,
originalType: zFieldType.optional(),
});
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
@@ -318,7 +313,6 @@ const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zImageFieldType,
originalType: zFieldType.optional(),
});
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
@@ -342,7 +336,6 @@ const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zBoardFieldType,
originalType: zFieldType.optional(),
});
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
@@ -366,7 +359,6 @@ const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zColorFieldType,
originalType: zFieldType.optional(),
});
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
@@ -390,7 +382,6 @@ const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zMainModelFieldType,
originalType: zFieldType.optional(),
});
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
@@ -413,7 +404,6 @@ const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zModelIdentifierFieldType,
originalType: zFieldType.optional(),
});
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
@@ -437,7 +427,6 @@ const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
});
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSDXLMainModelFieldType,
originalType: zFieldType.optional(),
});
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
||||
@ -461,7 +450,6 @@ const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
||||
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
||||
@ -485,7 +473,6 @@ const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
||||
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
||||
@ -509,7 +496,6 @@ const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
||||
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
||||
@ -533,7 +519,6 @@ const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
||||
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
||||
@ -557,7 +542,6 @@ const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
||||
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
||||
@ -581,7 +565,6 @@ const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
||||
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
||||
@ -605,7 +588,6 @@ const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
|
||||
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
|
||||
@ -641,7 +623,6 @@ const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
});
|
||||
const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
});
|
||||
|
||||
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
|
||||
|
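Each hunk above touches a zod output template that pairs a concrete type with an originalType marked .optional(). A simplified sketch of how such an optional field behaves; the zFieldType shape here is an assumption reduced to the three flags used in this diff, not the real zFieldType or zFieldOutputTemplateBase:

// Simplified sketch; the real zFieldType / zFieldOutputTemplateBase are richer than this.
import { z } from 'zod';

const zFieldType = z.object({
  name: z.string(),
  isCollection: z.boolean(),
  isCollectionOrScalar: z.boolean(),
});

const zExampleFieldOutputTemplate = z.object({
  type: zFieldType,
  originalType: zFieldType.optional(), // only present when the field was remapped from another type
});

// Parses with or without the original type recorded:
zExampleFieldOutputTemplate.parse({
  type: { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false },
  originalType: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
});
zExampleFieldOutputTemplate.parse({
  type: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
});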
@ -1,942 +1,19 @@
import { schema, templates } from 'features/nodes/store/util/testUtils';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { omit, pick } from 'lodash-es';
import type { OpenAPIV3_1 } from 'openapi-types';
import { describe, expect, it } from 'vitest';

describe('parseSchema', () => {
  it('should parse the schema', () => {
    const templates = parseSchema(schema);
    expect(templates).toEqual(expected);
    const parsed = parseSchema(schema);
    expect(parsed).toEqual(templates);
  });
  it('should omit denied nodes', () => {
    const templates = parseSchema(schema, undefined, ['add']);
    expect(templates).toEqual(omit(expected, 'add'));
    const parsed = parseSchema(schema, undefined, ['add']);
    expect(parsed).toEqual(omit(templates, 'add'));
  });
  it('should include only allowed nodes', () => {
    const templates = parseSchema(schema, ['add']);
    expect(templates).toEqual(pick(expected, 'add'));
    const parsed = parseSchema(schema, ['add']);
    expect(parsed).toEqual(pick(templates, 'add'));
  });
});
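The three tests above pin down parseSchema's allow/deny behaviour against the shared templates fixture. As a rough sketch of that contract (this helper is hypothetical and is not the actual parseSchema implementation):

// Hypothetical helper illustrating the allow/deny contract exercised above; not the real implementation.
const filterTemplates = <T>(all: Record<string, T>, allow?: string[], deny?: string[]): Record<string, T> =>
  Object.fromEntries(
    Object.entries(all).filter(([type]) => (allow === undefined || allow.includes(type)) && !(deny ?? []).includes(type))
  );

// filterTemplates(templates, ['add'])            -> behaves like pick(templates, 'add')
// filterTemplates(templates, undefined, ['add']) -> behaves like omit(templates, 'add')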

const expected = {
  add: {
    title: 'Add Integers',
    type: 'add',
    version: '1.0.1',
    tags: ['math', 'add'],
    description: 'Adds two numbers',
    outputType: 'integer_output',
    inputs: {
      a: {
        name: 'a',
        title: 'A',
        required: false,
        description: 'The first number',
        fieldKind: 'input',
        input: 'any',
        ui_hidden: false,
        type: {
          name: 'IntegerField',
          isCollection: false,
          isCollectionOrScalar: false,
        },
        default: 0,
      },
      b: {
        name: 'b',
        title: 'B',
        required: false,
        description: 'The second number',
        fieldKind: 'input',
        input: 'any',
        ui_hidden: false,
        type: {
          name: 'IntegerField',
          isCollection: false,
          isCollectionOrScalar: false,
        },
        default: 0,
      },
    },
    outputs: {
      value: {
        fieldKind: 'output',
        name: 'value',
        title: 'Value',
        description: 'The output integer',
        type: {
          name: 'IntegerField',
          isCollection: false,
          isCollectionOrScalar: false,
        },
        ui_hidden: false,
      },
    },
    useCache: true,
    nodePack: 'invokeai',
    classification: 'stable',
  },
  scheduler: {
    title: 'Scheduler',
    type: 'scheduler',
    version: '1.0.0',
    tags: ['scheduler'],
    description: 'Selects a scheduler.',
    outputType: 'scheduler_output',
    inputs: {
      scheduler: {
        name: 'scheduler',
        title: 'Scheduler',
        required: false,
        description: 'Scheduler to use during inference',
        fieldKind: 'input',
        input: 'any',
        ui_hidden: false,
        ui_type: 'SchedulerField',
        type: {
          name: 'SchedulerField',
          isCollection: false,
          isCollectionOrScalar: false,
          originalType: {
            name: 'EnumField',
            isCollection: false,
            isCollectionOrScalar: false,
          },
        },
        default: 'euler',
      },
    },
    outputs: {
      scheduler: {
        fieldKind: 'output',
        name: 'scheduler',
        title: 'Scheduler',
        description: 'Scheduler to use during inference',
        type: {
          name: 'SchedulerField',
          isCollection: false,
          isCollectionOrScalar: false,
          originalType: {
            name: 'EnumField',
            isCollection: false,
            isCollectionOrScalar: false,
          },
        },
        ui_hidden: false,
        ui_type: 'SchedulerField',
      },
    },
    useCache: true,
    nodePack: 'invokeai',
    classification: 'stable',
  },
  main_model_loader: {
    title: 'Main Model',
    type: 'main_model_loader',
    version: '1.0.2',
    tags: ['model'],
    description: 'Loads a main model, outputting its submodels.',
    outputType: 'model_loader_output',
    inputs: {
      model: {
        name: 'model',
        title: 'Model',
        required: true,
        description: 'Main model (UNet, VAE, CLIP) to load',
        fieldKind: 'input',
        input: 'direct',
        ui_hidden: false,
        ui_type: 'MainModelField',
        type: {
          name: 'MainModelField',
          isCollection: false,
          isCollectionOrScalar: false,
          originalType: {
            name: 'ModelIdentifierField',
            isCollection: false,
            isCollectionOrScalar: false,
          },
        },
      },
    },
    outputs: {
      vae: {
        fieldKind: 'output',
        name: 'vae',
        title: 'VAE',
        description: 'VAE',
        type: {
          name: 'VAEField',
          isCollection: false,
          isCollectionOrScalar: false,
        },
        ui_hidden: false,
      },
      clip: {
        fieldKind: 'output',
        name: 'clip',
        title: 'CLIP',
        description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count',
        type: {
          name: 'CLIPField',
          isCollection: false,
          isCollectionOrScalar: false,
        },
        ui_hidden: false,
      },
      unet: {
        fieldKind: 'output',
        name: 'unet',
        title: 'UNet',
        description: 'UNet (scheduler, LoRAs)',
        type: {
          name: 'UNetField',
          isCollection: false,
          isCollectionOrScalar: false,
        },
        ui_hidden: false,
      },
    },
    useCache: true,
    nodePack: 'invokeai',
    classification: 'stable',
  },
  collect: {
    title: 'Collect',
    type: 'collect',
    version: '1.0.0',
    tags: [],
    description: 'Collects values into a collection',
    outputType: 'collect_output',
    inputs: {
      item: {
        name: 'item',
        title: 'Collection Item',
        required: false,
        description: 'The item to collect (all inputs must be of the same type)',
        fieldKind: 'input',
        input: 'connection',
        ui_hidden: false,
        ui_type: 'CollectionItemField',
        type: {
          name: 'CollectionItemField',
          isCollection: false,
          isCollectionOrScalar: false,
        },
      },
    },
    outputs: {
      collection: {
        fieldKind: 'output',
        name: 'collection',
        title: 'Collection',
        description: 'The collection of input items',
        type: {
          name: 'CollectionField',
          isCollection: true,
          isCollectionOrScalar: false,
        },
        ui_hidden: false,
        ui_type: 'CollectionField',
      },
    },
    useCache: true,
    classification: 'stable',
  },
};

const schema = {
  openapi: '3.1.0',
  info: {
    title: 'Invoke - Community Edition',
    description: 'An API for invoking AI image operations',
    version: '1.0.0',
  },
  components: {
    schemas: {
      AddInvocation: {
        properties: {
          id: {
            type: 'string',
            title: 'Id',
            description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
            field_kind: 'node_attribute',
          },
          is_intermediate: {
            type: 'boolean',
            title: 'Is Intermediate',
            description: 'Whether or not this is an intermediate invocation.',
            default: false,
            field_kind: 'node_attribute',
            ui_type: 'IsIntermediate',
          },
          use_cache: {
            type: 'boolean',
            title: 'Use Cache',
            description: 'Whether or not to use the cache',
            default: true,
            field_kind: 'node_attribute',
          },
          a: {
            type: 'integer',
            title: 'A',
            description: 'The first number',
            default: 0,
            field_kind: 'input',
            input: 'any',
            orig_default: 0,
            orig_required: false,
            ui_hidden: false,
          },
          b: {
            type: 'integer',
            title: 'B',
            description: 'The second number',
            default: 0,
            field_kind: 'input',
            input: 'any',
            orig_default: 0,
            orig_required: false,
            ui_hidden: false,
          },
          type: {
            type: 'string',
            enum: ['add'],
            const: 'add',
            title: 'type',
            default: 'add',
            field_kind: 'node_attribute',
          },
        },
        type: 'object',
        required: ['type', 'id'],
        title: 'Add Integers',
        description: 'Adds two numbers',
        category: 'math',
        classification: 'stable',
        node_pack: 'invokeai',
        tags: ['math', 'add'],
        version: '1.0.1',
        output: {
          $ref: '#/components/schemas/IntegerOutput',
        },
        class: 'invocation',
      },
      IntegerOutput: {
        description: 'Base class for nodes that output a single integer',
        properties: {
          value: {
            description: 'The output integer',
            field_kind: 'output',
            title: 'Value',
            type: 'integer',
            ui_hidden: false,
          },
          type: {
            const: 'integer_output',
            default: 'integer_output',
            enum: ['integer_output'],
            field_kind: 'node_attribute',
            title: 'type',
            type: 'string',
          },
        },
        required: ['value', 'type', 'type'],
        title: 'IntegerOutput',
        type: 'object',
        class: 'output',
      },
      SchedulerInvocation: {
        properties: {
          id: {
            type: 'string',
            title: 'Id',
            description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
            field_kind: 'node_attribute',
          },
          is_intermediate: {
            type: 'boolean',
            title: 'Is Intermediate',
            description: 'Whether or not this is an intermediate invocation.',
            default: false,
            field_kind: 'node_attribute',
            ui_type: 'IsIntermediate',
          },
          use_cache: {
            type: 'boolean',
            title: 'Use Cache',
            description: 'Whether or not to use the cache',
            default: true,
            field_kind: 'node_attribute',
          },
          scheduler: {
            type: 'string',
            enum: [
              'ddim',
              'ddpm',
              'deis',
              'lms',
              'lms_k',
              'pndm',
              'heun',
              'heun_k',
              'euler',
              'euler_k',
              'euler_a',
              'kdpm_2',
              'kdpm_2_a',
              'dpmpp_2s',
              'dpmpp_2s_k',
              'dpmpp_2m',
              'dpmpp_2m_k',
              'dpmpp_2m_sde',
              'dpmpp_2m_sde_k',
              'dpmpp_sde',
              'dpmpp_sde_k',
              'unipc',
              'lcm',
              'tcd',
            ],
            title: 'Scheduler',
            description: 'Scheduler to use during inference',
            default: 'euler',
            field_kind: 'input',
            input: 'any',
            orig_default: 'euler',
            orig_required: false,
            ui_hidden: false,
            ui_type: 'SchedulerField',
          },
          type: {
            type: 'string',
            enum: ['scheduler'],
            const: 'scheduler',
            title: 'type',
            default: 'scheduler',
            field_kind: 'node_attribute',
          },
        },
        type: 'object',
        required: ['type', 'id'],
        title: 'Scheduler',
        description: 'Selects a scheduler.',
        category: 'latents',
        classification: 'stable',
        node_pack: 'invokeai',
        tags: ['scheduler'],
        version: '1.0.0',
        output: {
          $ref: '#/components/schemas/SchedulerOutput',
        },
        class: 'invocation',
      },
      SchedulerOutput: {
        properties: {
          scheduler: {
            description: 'Scheduler to use during inference',
            enum: [
              'ddim',
              'ddpm',
              'deis',
              'lms',
              'lms_k',
              'pndm',
              'heun',
              'heun_k',
              'euler',
              'euler_k',
              'euler_a',
              'kdpm_2',
              'kdpm_2_a',
              'dpmpp_2s',
              'dpmpp_2s_k',
              'dpmpp_2m',
              'dpmpp_2m_k',
              'dpmpp_2m_sde',
              'dpmpp_2m_sde_k',
              'dpmpp_sde',
              'dpmpp_sde_k',
              'unipc',
              'lcm',
              'tcd',
            ],
            field_kind: 'output',
            title: 'Scheduler',
            type: 'string',
            ui_hidden: false,
            ui_type: 'SchedulerField',
          },
          type: {
            const: 'scheduler_output',
            default: 'scheduler_output',
            enum: ['scheduler_output'],
            field_kind: 'node_attribute',
            title: 'type',
            type: 'string',
          },
        },
        required: ['scheduler', 'type', 'type'],
        title: 'SchedulerOutput',
        type: 'object',
        class: 'output',
      },
      MainModelLoaderInvocation: {
        properties: {
          id: {
            type: 'string',
            title: 'Id',
            description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
            field_kind: 'node_attribute',
          },
          is_intermediate: {
            type: 'boolean',
            title: 'Is Intermediate',
            description: 'Whether or not this is an intermediate invocation.',
            default: false,
            field_kind: 'node_attribute',
            ui_type: 'IsIntermediate',
          },
          use_cache: {
            type: 'boolean',
            title: 'Use Cache',
            description: 'Whether or not to use the cache',
            default: true,
            field_kind: 'node_attribute',
          },
          model: {
            allOf: [
              {
                $ref: '#/components/schemas/ModelIdentifierField',
              },
            ],
            description: 'Main model (UNet, VAE, CLIP) to load',
            field_kind: 'input',
            input: 'direct',
            orig_required: true,
            ui_hidden: false,
            ui_type: 'MainModelField',
          },
          type: {
            type: 'string',
            enum: ['main_model_loader'],
            const: 'main_model_loader',
            title: 'type',
            default: 'main_model_loader',
            field_kind: 'node_attribute',
          },
        },
        type: 'object',
        required: ['model', 'type', 'id'],
        title: 'Main Model',
        description: 'Loads a main model, outputting its submodels.',
        category: 'model',
        classification: 'stable',
        node_pack: 'invokeai',
        tags: ['model'],
        version: '1.0.2',
        output: {
          $ref: '#/components/schemas/ModelLoaderOutput',
        },
        class: 'invocation',
      },
      ModelIdentifierField: {
        properties: {
          key: {
            description: "The model's unique key",
            title: 'Key',
            type: 'string',
          },
          hash: {
            description: "The model's BLAKE3 hash",
            title: 'Hash',
            type: 'string',
          },
          name: {
            description: "The model's name",
            title: 'Name',
            type: 'string',
          },
          base: {
            allOf: [
              {
                $ref: '#/components/schemas/BaseModelType',
              },
            ],
            description: "The model's base model type",
          },
          type: {
            allOf: [
              {
                $ref: '#/components/schemas/ModelType',
              },
            ],
            description: "The model's type",
          },
          submodel_type: {
            anyOf: [
              {
                $ref: '#/components/schemas/SubModelType',
              },
              {
                type: 'null',
              },
            ],
            default: null,
            description: 'The submodel to load, if this is a main model',
          },
        },
        required: ['key', 'hash', 'name', 'base', 'type'],
        title: 'ModelIdentifierField',
        type: 'object',
      },
      BaseModelType: {
        description: 'Base model type.',
        enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
        title: 'BaseModelType',
        type: 'string',
      },
      ModelType: {
        description: 'Model type.',
        enum: ['onnx', 'main', 'vae', 'lora', 'controlnet', 'embedding', 'ip_adapter', 'clip_vision', 't2i_adapter'],
        title: 'ModelType',
        type: 'string',
      },
      SubModelType: {
        description: 'Submodel type.',
        enum: [
          'unet',
          'text_encoder',
          'text_encoder_2',
          'tokenizer',
          'tokenizer_2',
          'vae',
          'vae_decoder',
          'vae_encoder',
          'scheduler',
          'safety_checker',
        ],
        title: 'SubModelType',
        type: 'string',
      },
      ModelLoaderOutput: {
        description: 'Model loader output',
        properties: {
          vae: {
            allOf: [
              {
                $ref: '#/components/schemas/VAEField',
              },
            ],
            description: 'VAE',
            field_kind: 'output',
            title: 'VAE',
            ui_hidden: false,
          },
          type: {
            const: 'model_loader_output',
            default: 'model_loader_output',
            enum: ['model_loader_output'],
            field_kind: 'node_attribute',
            title: 'type',
            type: 'string',
          },
          clip: {
            allOf: [
              {
                $ref: '#/components/schemas/CLIPField',
              },
            ],
            description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count',
            field_kind: 'output',
            title: 'CLIP',
            ui_hidden: false,
          },
          unet: {
            allOf: [
              {
                $ref: '#/components/schemas/UNetField',
              },
            ],
            description: 'UNet (scheduler, LoRAs)',
            field_kind: 'output',
            title: 'UNet',
            ui_hidden: false,
          },
        },
        required: ['vae', 'type', 'clip', 'unet', 'type'],
        title: 'ModelLoaderOutput',
        type: 'object',
        class: 'output',
      },
      UNetField: {
        properties: {
          unet: {
            allOf: [
              {
                $ref: '#/components/schemas/ModelIdentifierField',
              },
            ],
            description: 'Info to load unet submodel',
          },
          scheduler: {
            allOf: [
              {
                $ref: '#/components/schemas/ModelIdentifierField',
              },
            ],
            description: 'Info to load scheduler submodel',
          },
          loras: {
            description: 'LoRAs to apply on model loading',
            items: {
              $ref: '#/components/schemas/LoRAField',
            },
            title: 'Loras',
            type: 'array',
          },
          seamless_axes: {
            description: 'Axes("x" and "y") to which apply seamless',
            items: {
              type: 'string',
            },
            title: 'Seamless Axes',
            type: 'array',
          },
          freeu_config: {
            anyOf: [
              {
                $ref: '#/components/schemas/FreeUConfig',
              },
              {
                type: 'null',
              },
            ],
            default: null,
            description: 'FreeU configuration',
          },
        },
        required: ['unet', 'scheduler', 'loras'],
        title: 'UNetField',
        type: 'object',
        class: 'output',
      },
      LoRAField: {
        properties: {
          lora: {
            allOf: [
              {
                $ref: '#/components/schemas/ModelIdentifierField',
              },
            ],
            description: 'Info to load lora model',
          },
          weight: {
            description: 'Weight to apply to lora model',
            title: 'Weight',
            type: 'number',
          },
        },
        required: ['lora', 'weight'],
        title: 'LoRAField',
        type: 'object',
        class: 'output',
      },
      FreeUConfig: {
        description:
          'Configuration for the FreeU hyperparameters.\n- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu\n- https://github.com/ChenyangSi/FreeU',
        properties: {
          s1: {
            description:
              'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.',
            maximum: 3.0,
            minimum: -1.0,
            title: 'S1',
            type: 'number',
          },
          s2: {
            description:
              'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.',
            maximum: 3.0,
            minimum: -1.0,
            title: 'S2',
            type: 'number',
          },
          b1: {
            description: 'Scaling factor for stage 1 to amplify the contributions of backbone features.',
            maximum: 3.0,
            minimum: -1.0,
            title: 'B1',
            type: 'number',
          },
          b2: {
            description: 'Scaling factor for stage 2 to amplify the contributions of backbone features.',
            maximum: 3.0,
            minimum: -1.0,
            title: 'B2',
            type: 'number',
          },
        },
        required: ['s1', 's2', 'b1', 'b2'],
        title: 'FreeUConfig',
        type: 'object',
        class: 'output',
      },
      VAEField: {
        properties: {
          vae: {
            allOf: [
              {
                $ref: '#/components/schemas/ModelIdentifierField',
              },
            ],
            description: 'Info to load vae submodel',
          },
          seamless_axes: {
            description: 'Axes("x" and "y") to which apply seamless',
            items: {
              type: 'string',
            },
            title: 'Seamless Axes',
            type: 'array',
          },
        },
        required: ['vae'],
        title: 'VAEField',
        type: 'object',
        class: 'output',
      },
      CLIPField: {
        properties: {
          tokenizer: {
            allOf: [
              {
                $ref: '#/components/schemas/ModelIdentifierField',
              },
            ],
            description: 'Info to load tokenizer submodel',
          },
          text_encoder: {
            allOf: [
              {
                $ref: '#/components/schemas/ModelIdentifierField',
              },
            ],
            description: 'Info to load text_encoder submodel',
          },
          skipped_layers: {
            description: 'Number of skipped layers in text_encoder',
            title: 'Skipped Layers',
            type: 'integer',
          },
          loras: {
            description: 'LoRAs to apply on model loading',
            items: {
              $ref: '#/components/schemas/LoRAField',
            },
            title: 'Loras',
            type: 'array',
          },
        },
        required: ['tokenizer', 'text_encoder', 'skipped_layers', 'loras'],
        title: 'CLIPField',
        type: 'object',
        class: 'output',
      },
      CollectInvocation: {
        properties: {
          id: {
            type: 'string',
            title: 'Id',
            description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
            field_kind: 'node_attribute',
          },
          is_intermediate: {
            type: 'boolean',
            title: 'Is Intermediate',
            description: 'Whether or not this is an intermediate invocation.',
            default: false,
            field_kind: 'node_attribute',
            ui_type: 'IsIntermediate',
          },
          use_cache: {
            type: 'boolean',
            title: 'Use Cache',
            description: 'Whether or not to use the cache',
            default: true,
            field_kind: 'node_attribute',
          },
          item: {
            anyOf: [
              {},
              {
                type: 'null',
              },
            ],
            title: 'Collection Item',
            description: 'The item to collect (all inputs must be of the same type)',
            field_kind: 'input',
            input: 'connection',
            orig_required: false,
            ui_hidden: false,
            ui_type: 'CollectionItemField',
          },
          collection: {
            items: {},
            type: 'array',
            title: 'Collection',
            description: 'The collection, will be provided on execution',
            default: [],
            field_kind: 'input',
            input: 'any',
            orig_default: [],
            orig_required: false,
            ui_hidden: true,
          },
          type: {
            type: 'string',
            enum: ['collect'],
            const: 'collect',
            title: 'type',
            default: 'collect',
            field_kind: 'node_attribute',
          },
        },
        type: 'object',
        required: ['type', 'id'],
        title: 'CollectInvocation',
        description: 'Collects values into a collection',
        classification: 'stable',
        version: '1.0.0',
        output: {
          $ref: '#/components/schemas/CollectInvocationOutput',
        },
        class: 'invocation',
      },
      CollectInvocationOutput: {
        properties: {
          collection: {
            description: 'The collection of input items',
            field_kind: 'output',
            items: {},
            title: 'Collection',
            type: 'array',
            ui_hidden: false,
            ui_type: 'CollectionField',
          },
          type: {
            const: 'collect_output',
            default: 'collect_output',
            enum: ['collect_output'],
            field_kind: 'node_attribute',
            title: 'type',
            type: 'string',
          },
        },
        required: ['collection', 'type', 'type'],
        title: 'CollectInvocationOutput',
        type: 'object',
        class: 'output',
      },
    },
  },
} as OpenAPIV3_1.Document;