mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests(ui): add tests for consolidated connection validation
This commit is contained in:
parent
6f7160b9fd
commit
3fcb2720d7
@ -775,6 +775,9 @@
|
|||||||
"cannotConnectToSelf": "Cannot connect to self",
|
"cannotConnectToSelf": "Cannot connect to self",
|
||||||
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
||||||
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
|
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
|
||||||
|
"missingNode": "Missing invocation node",
|
||||||
|
"missingInvocationTemplate": "Missing invocation template",
|
||||||
|
"missingFieldTemplate": "Missing field template",
|
||||||
"nodePack": "Node pack",
|
"nodePack": "Node pack",
|
||||||
"collection": "Collection",
|
"collection": "Collection",
|
||||||
"collectionFieldType": "{{name}} Collection",
|
"collectionFieldType": "{{name}} Collection",
|
||||||
|
@ -17,7 +17,8 @@ import {
|
|||||||
nodeAdded,
|
nodeAdded,
|
||||||
openAddNodePopover,
|
openAddNodePopover,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { getFirstValidConnection, validateSourceAndTargetTypes } from 'features/nodes/store/util/connectionValidation';
|
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||||
|
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { filter, map, memoize, some } from 'lodash-es';
|
import { filter, map, memoize, some } from 'lodash-es';
|
||||||
@ -77,7 +78,7 @@ const AddNodePopover = () => {
|
|||||||
return some(fields, (field) => {
|
return some(fields, (field) => {
|
||||||
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
|
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
|
||||||
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
|
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
|
||||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
return validateConnectionTypes(sourceType, targetType);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}, [templates, pendingConnection]);
|
}, [templates, pendingConnection]);
|
||||||
|
@ -8,7 +8,7 @@ import {
|
|||||||
$templates,
|
$templates,
|
||||||
connectionMade,
|
connectionMade,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { getFirstValidConnection } from 'features/nodes/store/util/connectionValidation';
|
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { isString } from 'lodash-es';
|
import { isString } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
|
@ -2,12 +2,10 @@
|
|||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||||
areTypesEqual,
|
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||||
getCollectItemType,
|
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||||
getHasCycles,
|
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||||
validateSourceAndTargetTypes,
|
|
||||||
} from 'features/nodes/store/util/connectionValidation';
|
|
||||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import type { Connection, Node } from 'reactflow';
|
import type { Connection, Node } from 'reactflow';
|
||||||
@ -88,7 +86,7 @@ export const useIsValidConnection = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Must use the originalType here if it exists
|
// Must use the originalType here if it exists
|
||||||
if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,101 @@
|
|||||||
|
import { describe, expect, it } from 'vitest';
|
||||||
|
|
||||||
|
import { areTypesEqual } from './areTypesEqual';
|
||||||
|
|
||||||
|
describe(areTypesEqual.name, () => {
|
||||||
|
it('should handle equal source and target type', () => {
|
||||||
|
const sourceType = {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
originalType: {
|
||||||
|
name: 'Foo',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
const targetType = {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
originalType: {
|
||||||
|
name: 'Bar',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle equal source type and original target type', () => {
|
||||||
|
const sourceType = {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
originalType: {
|
||||||
|
name: 'Foo',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
const targetType = {
|
||||||
|
name: 'Bar',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
originalType: {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle equal original source type and target type', () => {
|
||||||
|
const sourceType = {
|
||||||
|
name: 'Foo',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
originalType: {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
const targetType = {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
originalType: {
|
||||||
|
name: 'Bar',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle equal original source type and original target type', () => {
|
||||||
|
const sourceType = {
|
||||||
|
name: 'Foo',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
originalType: {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
const targetType = {
|
||||||
|
name: 'Bar',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
originalType: {
|
||||||
|
name: 'IntegerField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
@ -0,0 +1,30 @@
|
|||||||
|
import type { FieldType } from 'features/nodes/types/field';
|
||||||
|
import { isEqual, omit } from 'lodash-es';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if two types are equal. If the field types have original types, those are also compared. Any match is
|
||||||
|
* considered equal. For example, if the source type and original target type match, the types are considered equal.
|
||||||
|
* @param sourceType The type of the source field.
|
||||||
|
* @param targetType The type of the target field.
|
||||||
|
* @returns True if the types are equal, false otherwise.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => {
|
||||||
|
const _sourceType = 'originalType' in sourceType ? omit(sourceType, 'originalType') : sourceType;
|
||||||
|
const _targetType = 'originalType' in targetType ? omit(targetType, 'originalType') : targetType;
|
||||||
|
const _sourceTypeOriginal = 'originalType' in sourceType ? sourceType.originalType : null;
|
||||||
|
const _targetTypeOriginal = 'originalType' in targetType ? targetType.originalType : null;
|
||||||
|
if (isEqual(_sourceType, _targetType)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
@ -1,179 +1,16 @@
|
|||||||
import graphlib from '@dagrejs/graphlib';
|
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
||||||
import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field';
|
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
import type { FieldType } from 'features/nodes/types/field';
|
||||||
import i18n from 'i18next';
|
import i18n from 'i18next';
|
||||||
import { differenceWith, isEqual, map, omit } from 'lodash-es';
|
import type { HandleType } from 'reactflow';
|
||||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
/**
|
import { areTypesEqual } from './areTypesEqual';
|
||||||
* Finds the first valid field for a pending connection between two nodes.
|
import { getCollectItemType } from './getCollectItemType';
|
||||||
* @param templates The invocation templates
|
import { getHasCycles } from './getHasCycles';
|
||||||
* @param nodes The current nodes
|
|
||||||
* @param edges The current edges
|
|
||||||
* @param pendingConnection The pending connection
|
|
||||||
* @param candidateNode The candidate node to which the connection is being made
|
|
||||||
* @param candidateTemplate The candidate template for the candidate node
|
|
||||||
* @returns The first valid connection, or null if no valid connection is found
|
|
||||||
*/
|
|
||||||
export const getFirstValidConnection = (
|
|
||||||
templates: Templates,
|
|
||||||
nodes: AnyNode[],
|
|
||||||
edges: InvocationNodeEdge[],
|
|
||||||
pendingConnection: PendingConnection,
|
|
||||||
candidateNode: InvocationNode,
|
|
||||||
candidateTemplate: InvocationTemplate
|
|
||||||
): Connection | null => {
|
|
||||||
if (pendingConnection.node.id === candidateNode.id) {
|
|
||||||
// Cannot connect to self
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
|
||||||
|
|
||||||
if (pendingFieldKind === 'source') {
|
|
||||||
// Connecting from a source to a target
|
|
||||||
if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
if (candidateNode.data.type === 'collect') {
|
|
||||||
// Special handling for collect node - the `item` field takes any number of connections
|
|
||||||
return {
|
|
||||||
source: pendingConnection.node.id,
|
|
||||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
|
||||||
target: candidateNode.id,
|
|
||||||
targetHandle: 'item',
|
|
||||||
};
|
|
||||||
}
|
|
||||||
// Only one connection per target field is allowed - look for an unconnected target field
|
|
||||||
const candidateFields = map(candidateTemplate.inputs);
|
|
||||||
const candidateConnectedFields = edges
|
|
||||||
.filter((edge) => edge.target === candidateNode.id)
|
|
||||||
.map((edge) => {
|
|
||||||
// Edges must always have a targetHandle, safe to assert here
|
|
||||||
assert(edge.targetHandle);
|
|
||||||
return edge.targetHandle;
|
|
||||||
});
|
|
||||||
const candidateUnconnectedFields = differenceWith(
|
|
||||||
candidateFields,
|
|
||||||
candidateConnectedFields,
|
|
||||||
(field, connectedFieldName) => field.name === connectedFieldName
|
|
||||||
);
|
|
||||||
const candidateField = candidateUnconnectedFields.find((field) =>
|
|
||||||
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
|
|
||||||
);
|
|
||||||
if (candidateField) {
|
|
||||||
return {
|
|
||||||
source: pendingConnection.node.id,
|
|
||||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
|
||||||
target: candidateNode.id,
|
|
||||||
targetHandle: candidateField.name,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Connecting from a target to a source
|
|
||||||
// Ensure we there is not already an edge to the target, except for collect nodes
|
|
||||||
const isCollect = pendingConnection.node.data.type === 'collect';
|
|
||||||
const isTargetAlreadyConnected = edges.some(
|
|
||||||
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
|
||||||
);
|
|
||||||
if (!isCollect && isTargetAlreadyConnected) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sources/outputs can have any number of edges, we can take the first matching output field
|
|
||||||
let candidateFields = map(candidateTemplate.outputs);
|
|
||||||
if (isCollect) {
|
|
||||||
// Narrow candidates to same field type as already is connected to the collect node
|
|
||||||
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
|
|
||||||
if (collectItemType) {
|
|
||||||
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const candidateField = candidateFields.find((field) => {
|
|
||||||
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
|
|
||||||
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
|
|
||||||
return isValid && !isAlreadyConnected;
|
|
||||||
});
|
|
||||||
if (candidateField) {
|
|
||||||
return {
|
|
||||||
source: candidateNode.id,
|
|
||||||
sourceHandle: candidateField.name,
|
|
||||||
target: pendingConnection.node.id,
|
|
||||||
targetHandle: pendingConnection.fieldTemplate.name,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if adding an edge between the source and target nodes would create a cycle in the graph.
|
|
||||||
* @param source The source node id
|
|
||||||
* @param target The target node id
|
|
||||||
* @param nodes The graph's current nodes
|
|
||||||
* @param edges The graph's current edges
|
|
||||||
* @returns True if the graph would be acyclic after adding the edge, false otherwise
|
|
||||||
*/
|
|
||||||
export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
|
|
||||||
// construct graphlib graph from editor state
|
|
||||||
const g = new graphlib.Graph();
|
|
||||||
|
|
||||||
nodes.forEach((n) => {
|
|
||||||
g.setNode(n.id);
|
|
||||||
});
|
|
||||||
|
|
||||||
edges.forEach((e) => {
|
|
||||||
g.setEdge(e.source, e.target);
|
|
||||||
});
|
|
||||||
|
|
||||||
// add the candidate edge
|
|
||||||
g.setEdge(source, target);
|
|
||||||
|
|
||||||
// check if the graph is acyclic
|
|
||||||
return !graphlib.alg.isAcyclic(g);
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
|
|
||||||
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
|
|
||||||
* input field.
|
|
||||||
* @param templates The current invocation templates
|
|
||||||
* @param nodes The current nodes
|
|
||||||
* @param edges The current edges
|
|
||||||
* @param nodeId The collect node's id
|
|
||||||
* @returns The type of the items the collect node collects, or null if there is no input field
|
|
||||||
*/
|
|
||||||
export const getCollectItemType = (
|
|
||||||
templates: Templates,
|
|
||||||
nodes: AnyNode[],
|
|
||||||
edges: InvocationNodeEdge[],
|
|
||||||
nodeId: string
|
|
||||||
): FieldType | null => {
|
|
||||||
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
|
|
||||||
if (!firstEdgeToCollect?.sourceHandle) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
|
|
||||||
if (!node) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
const template = templates[node.data.type];
|
|
||||||
if (!template) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
|
|
||||||
return fieldType;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a selector that validates a pending connection.
|
* Creates a selector that validates a pending connection.
|
||||||
@ -276,7 +113,7 @@ export const makeConnectionErrorSelector = (
|
|||||||
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
|
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
|
if (!validateConnectionTypes(sourceType, targetType)) {
|
||||||
return i18n.t('nodes.fieldTypesMustMatch');
|
return i18n.t('nodes.fieldTypesMustMatch');
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,97 +132,3 @@ export const makeConnectionErrorSelector = (
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
|
||||||
* Validates that the source and target types are compatible for a connection.
|
|
||||||
* @param sourceType The type of the source field.
|
|
||||||
* @param targetType The type of the target field.
|
|
||||||
* @returns True if the connection is valid, false otherwise.
|
|
||||||
*/
|
|
||||||
export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => {
|
|
||||||
// TODO: There's a bug with Collect -> Iterate nodes:
|
|
||||||
// https://github.com/invoke-ai/InvokeAI/issues/3956
|
|
||||||
// Once this is resolved, we can remove this check.
|
|
||||||
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (areTypesEqual(sourceType, targetType)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Connection types must be the same for a connection, with exceptions:
|
|
||||||
* - CollectionItem can connect to any non-Collection
|
|
||||||
* - Non-Collections can connect to CollectionItem
|
|
||||||
* - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type
|
|
||||||
* - Generic Collection can connect to any other Collection or CollectionOrScalar
|
|
||||||
* - Any Collection can connect to a Generic Collection
|
|
||||||
*/
|
|
||||||
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection;
|
|
||||||
|
|
||||||
const isNonCollectionToCollectionItem =
|
|
||||||
targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar;
|
|
||||||
|
|
||||||
const isAnythingToCollectionOrScalarOfSameBaseType =
|
|
||||||
targetType.isCollectionOrScalar && sourceType.name === targetType.name;
|
|
||||||
|
|
||||||
const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
|
|
||||||
sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar);
|
|
||||||
|
|
||||||
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection;
|
|
||||||
|
|
||||||
const areBothTypesSingle =
|
|
||||||
!sourceType.isCollection &&
|
|
||||||
!sourceType.isCollectionOrScalar &&
|
|
||||||
!targetType.isCollection &&
|
|
||||||
!targetType.isCollectionOrScalar;
|
|
||||||
|
|
||||||
const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
|
|
||||||
|
|
||||||
const isIntOrFloatToString =
|
|
||||||
areBothTypesSingle &&
|
|
||||||
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
|
|
||||||
targetType.name === 'StringField';
|
|
||||||
|
|
||||||
const isTargetAnyType = targetType.name === 'AnyField';
|
|
||||||
|
|
||||||
// One of these must be true for the connection to be valid
|
|
||||||
return (
|
|
||||||
isCollectionItemToNonCollection ||
|
|
||||||
isNonCollectionToCollectionItem ||
|
|
||||||
isAnythingToCollectionOrScalarOfSameBaseType ||
|
|
||||||
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
|
|
||||||
isCollectionToGenericCollection ||
|
|
||||||
isIntToFloat ||
|
|
||||||
isIntOrFloatToString ||
|
|
||||||
isTargetAnyType
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if two types are equal. If the field types have original types, those are also compared. Any match is
|
|
||||||
* considered equal. For example, if the source type and original target type match, the types are considered equal.
|
|
||||||
* @param sourceType The type of the source field.
|
|
||||||
* @param targetType The type of the target field.
|
|
||||||
* @returns True if the types are equal, false otherwise.
|
|
||||||
*/
|
|
||||||
export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => {
|
|
||||||
const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType;
|
|
||||||
const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType;
|
|
||||||
const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType;
|
|
||||||
const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType;
|
|
||||||
if (isEqual(_sourceType, _targetType)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||||
|
import { add, buildEdge, collect, position, templates } from 'features/nodes/store/util/testUtils';
|
||||||
|
import type { FieldType } from 'features/nodes/types/field';
|
||||||
|
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||||
|
import { describe, expect, it } from 'vitest';
|
||||||
|
|
||||||
|
describe(getCollectItemType.name, () => {
|
||||||
|
it('should return the type of the items the collect node collects', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, collect);
|
||||||
|
const nodes = [n1, n2];
|
||||||
|
const edges = [buildEdge(n1.id, 'value', n2.id, 'item')];
|
||||||
|
const result = getCollectItemType(templates, nodes, edges, n2.id);
|
||||||
|
expect(result).toEqual<FieldType>({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false });
|
||||||
|
});
|
||||||
|
});
|
@ -0,0 +1,35 @@
|
|||||||
|
import type { Templates } from 'features/nodes/store/types';
|
||||||
|
import type { FieldType } from 'features/nodes/types/field';
|
||||||
|
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
|
||||||
|
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
|
||||||
|
* input field.
|
||||||
|
* @param templates The current invocation templates
|
||||||
|
* @param nodes The current nodes
|
||||||
|
* @param edges The current edges
|
||||||
|
* @param nodeId The collect node's id
|
||||||
|
* @returns The type of the items the collect node collects, or null if there is no input field
|
||||||
|
*/
|
||||||
|
export const getCollectItemType = (
|
||||||
|
templates: Templates,
|
||||||
|
nodes: AnyNode[],
|
||||||
|
edges: InvocationNodeEdge[],
|
||||||
|
nodeId: string
|
||||||
|
): FieldType | null => {
|
||||||
|
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
|
||||||
|
if (!firstEdgeToCollect?.sourceHandle) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
|
||||||
|
if (!node) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const template = templates[node.data.type];
|
||||||
|
if (!template) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
|
||||||
|
return fieldType;
|
||||||
|
};
|
@ -0,0 +1,116 @@
|
|||||||
|
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||||
|
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||||
|
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||||
|
import { differenceWith, map } from 'lodash-es';
|
||||||
|
import type { Connection } from 'reactflow';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
import { areTypesEqual } from './areTypesEqual';
|
||||||
|
import { getCollectItemType } from './getCollectItemType';
|
||||||
|
import { getHasCycles } from './getHasCycles';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Finds the first valid field for a pending connection between two nodes.
|
||||||
|
* @param templates The invocation templates
|
||||||
|
* @param nodes The current nodes
|
||||||
|
* @param edges The current edges
|
||||||
|
* @param pendingConnection The pending connection
|
||||||
|
* @param candidateNode The candidate node to which the connection is being made
|
||||||
|
* @param candidateTemplate The candidate template for the candidate node
|
||||||
|
* @returns The first valid connection, or null if no valid connection is found
|
||||||
|
*/
|
||||||
|
|
||||||
|
export const getFirstValidConnection = (
|
||||||
|
templates: Templates,
|
||||||
|
nodes: AnyNode[],
|
||||||
|
edges: InvocationNodeEdge[],
|
||||||
|
pendingConnection: PendingConnection,
|
||||||
|
candidateNode: InvocationNode,
|
||||||
|
candidateTemplate: InvocationTemplate
|
||||||
|
): Connection | null => {
|
||||||
|
if (pendingConnection.node.id === candidateNode.id) {
|
||||||
|
// Cannot connect to self
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||||
|
|
||||||
|
if (pendingFieldKind === 'source') {
|
||||||
|
// Connecting from a source to a target
|
||||||
|
if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (candidateNode.data.type === 'collect') {
|
||||||
|
// Special handling for collect node - the `item` field takes any number of connections
|
||||||
|
return {
|
||||||
|
source: pendingConnection.node.id,
|
||||||
|
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||||
|
target: candidateNode.id,
|
||||||
|
targetHandle: 'item',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
// Only one connection per target field is allowed - look for an unconnected target field
|
||||||
|
const candidateFields = map(candidateTemplate.inputs);
|
||||||
|
const candidateConnectedFields = edges
|
||||||
|
.filter((edge) => edge.target === candidateNode.id)
|
||||||
|
.map((edge) => {
|
||||||
|
// Edges must always have a targetHandle, safe to assert here
|
||||||
|
assert(edge.targetHandle);
|
||||||
|
return edge.targetHandle;
|
||||||
|
});
|
||||||
|
const candidateUnconnectedFields = differenceWith(
|
||||||
|
candidateFields,
|
||||||
|
candidateConnectedFields,
|
||||||
|
(field, connectedFieldName) => field.name === connectedFieldName
|
||||||
|
);
|
||||||
|
const candidateField = candidateUnconnectedFields.find((field) => validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type)
|
||||||
|
);
|
||||||
|
if (candidateField) {
|
||||||
|
return {
|
||||||
|
source: pendingConnection.node.id,
|
||||||
|
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||||
|
target: candidateNode.id,
|
||||||
|
targetHandle: candidateField.name,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Connecting from a target to a source
|
||||||
|
// Ensure we there is not already an edge to the target, except for collect nodes
|
||||||
|
const isCollect = pendingConnection.node.data.type === 'collect';
|
||||||
|
const isTargetAlreadyConnected = edges.some(
|
||||||
|
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
||||||
|
);
|
||||||
|
if (!isCollect && isTargetAlreadyConnected) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sources/outputs can have any number of edges, we can take the first matching output field
|
||||||
|
let candidateFields = map(candidateTemplate.outputs);
|
||||||
|
if (isCollect) {
|
||||||
|
// Narrow candidates to same field type as already is connected to the collect node
|
||||||
|
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
|
||||||
|
if (collectItemType) {
|
||||||
|
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const candidateField = candidateFields.find((field) => {
|
||||||
|
const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type);
|
||||||
|
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
|
||||||
|
return isValid && !isAlreadyConnected;
|
||||||
|
});
|
||||||
|
if (candidateField) {
|
||||||
|
return {
|
||||||
|
source: candidateNode.id,
|
||||||
|
sourceHandle: candidateField.name,
|
||||||
|
target: pendingConnection.node.id,
|
||||||
|
targetHandle: pendingConnection.fieldTemplate.name,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
@ -0,0 +1,23 @@
|
|||||||
|
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||||
|
import { add, buildEdge, position } from 'features/nodes/store/util/testUtils';
|
||||||
|
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||||
|
import { describe, expect, it } from 'vitest';
|
||||||
|
|
||||||
|
describe(getHasCycles.name, () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, add);
|
||||||
|
const n3 = buildInvocationNode(position, add);
|
||||||
|
const nodes = [n1, n2, n3];
|
||||||
|
|
||||||
|
it('should return true if the graph WOULD have cycles after adding the edge', () => {
|
||||||
|
const edges = [buildEdge(n1.id, 'value', n2.id, 'a'), buildEdge(n2.id, 'value', n3.id, 'a')];
|
||||||
|
const result = getHasCycles(n3.id, n1.id, nodes, edges);
|
||||||
|
expect(result).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return false if the graph WOULD NOT have cycles after adding the edge', () => {
|
||||||
|
const edges = [buildEdge(n1.id, 'value', n2.id, 'a')];
|
||||||
|
const result = getHasCycles(n2.id, n3.id, nodes, edges);
|
||||||
|
expect(result).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
@ -0,0 +1,30 @@
|
|||||||
|
import graphlib from '@dagrejs/graphlib';
|
||||||
|
import type { Edge, Node } from 'reactflow';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if adding an edge between the source and target nodes would create a cycle in the graph.
|
||||||
|
* @param source The source node id
|
||||||
|
* @param target The target node id
|
||||||
|
* @param nodes The graph's current nodes
|
||||||
|
* @param edges The graph's current edges
|
||||||
|
* @returns True if the graph would be acyclic after adding the edge, false otherwise
|
||||||
|
*/
|
||||||
|
|
||||||
|
export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
|
||||||
|
// construct graphlib graph from editor state
|
||||||
|
const g = new graphlib.Graph();
|
||||||
|
|
||||||
|
nodes.forEach((n) => {
|
||||||
|
g.setNode(n.id);
|
||||||
|
});
|
||||||
|
|
||||||
|
edges.forEach((e) => {
|
||||||
|
g.setEdge(e.source, e.target);
|
||||||
|
});
|
||||||
|
|
||||||
|
// add the candidate edge
|
||||||
|
g.setEdge(source, target);
|
||||||
|
|
||||||
|
// check if the graph is acyclic
|
||||||
|
return !graphlib.alg.isAcyclic(g);
|
||||||
|
};
|
1073
invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts
Normal file
1073
invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,149 @@
|
|||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
|
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||||
|
import { set } from 'lodash-es';
|
||||||
|
import { describe, expect, it } from 'vitest';
|
||||||
|
|
||||||
|
import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils';
|
||||||
|
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
|
||||||
|
|
||||||
|
describe(validateConnection.name, () => {
|
||||||
|
it('should reject invalid connection to self', () => {
|
||||||
|
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
|
||||||
|
const r = validateConnection(c, [], [], templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('missing nodes', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, sub);
|
||||||
|
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||||
|
|
||||||
|
it('should reject missing source node', () => {
|
||||||
|
const r = validateConnection(c, [n2], [], templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject missing target node', () => {
|
||||||
|
const r = validateConnection(c, [n1], [], templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('missing invocation templates', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, sub);
|
||||||
|
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||||
|
const nodes = [n1, n2];
|
||||||
|
|
||||||
|
it('should reject missing source template', () => {
|
||||||
|
const r = validateConnection(c, nodes, [], { sub }, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject missing target template', () => {
|
||||||
|
const r = validateConnection(c, nodes, [], { add }, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('missing field templates', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, sub);
|
||||||
|
const nodes = [n1, n2];
|
||||||
|
|
||||||
|
it('should reject missing source field template', () => {
|
||||||
|
const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' };
|
||||||
|
const r = validateConnection(c, nodes, [], templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject missing target field template', () => {
|
||||||
|
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' };
|
||||||
|
const r = validateConnection(c, nodes, [], templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('duplicate connections', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, sub);
|
||||||
|
it('should accept non-duplicate connections', () => {
|
||||||
|
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||||
|
const r = validateConnection(c, [n1, n2], [], templates, null);
|
||||||
|
expect(r).toEqual(buildAcceptResult());
|
||||||
|
});
|
||||||
|
it('should reject duplicate connections', () => {
|
||||||
|
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||||
|
const e = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||||
|
const r = validateConnection(c, [n1, n2], [e], templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection'));
|
||||||
|
});
|
||||||
|
it('should accept duplicate connections if the duplicate is an ignored edge', () => {
|
||||||
|
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||||
|
const e = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||||
|
const r = validateConnection(c, [n1, n2], [e], templates, e);
|
||||||
|
expect(r).toEqual(buildAcceptResult());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject connection to direct input', () => {
|
||||||
|
// Create cloned add template w/ a direct input
|
||||||
|
const addWithDirectAField = deepClone(add);
|
||||||
|
set(addWithDirectAField, 'inputs.a.input', 'direct');
|
||||||
|
set(addWithDirectAField, 'type', 'addWithDirectAField');
|
||||||
|
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, addWithDirectAField);
|
||||||
|
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||||
|
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject connection to a collect node with mismatched item types', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, collect);
|
||||||
|
const n3 = buildInvocationNode(position, main_model_loader);
|
||||||
|
const nodes = [n1, n2, n3];
|
||||||
|
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
|
||||||
|
const edges = [e1];
|
||||||
|
const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' };
|
||||||
|
const r = validateConnection(c, nodes, edges, templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should accept connection to a collect node with matching item types', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, collect);
|
||||||
|
const n3 = buildInvocationNode(position, sub);
|
||||||
|
const nodes = [n1, n2, n3];
|
||||||
|
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
|
||||||
|
const edges = [e1];
|
||||||
|
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' };
|
||||||
|
const r = validateConnection(c, nodes, edges, templates, null);
|
||||||
|
expect(r).toEqual(buildAcceptResult());
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject connections to target field that is already connected', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, add);
|
||||||
|
const n3 = buildInvocationNode(position, add);
|
||||||
|
const nodes = [n1, n2, n3];
|
||||||
|
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||||
|
const edges = [e1];
|
||||||
|
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||||
|
const r = validateConnection(c, nodes, edges, templates, null);
|
||||||
|
expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection'));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should accept connections to target field that is already connected (ignored edge)', () => {
|
||||||
|
const n1 = buildInvocationNode(position, add);
|
||||||
|
const n2 = buildInvocationNode(position, add);
|
||||||
|
const n3 = buildInvocationNode(position, add);
|
||||||
|
const nodes = [n1, n2, n3];
|
||||||
|
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||||
|
const edges = [e1];
|
||||||
|
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||||
|
const r = validateConnection(c, nodes, edges, templates, e1);
|
||||||
|
expect(r).toEqual(buildAcceptResult());
|
||||||
|
});
|
||||||
|
});
|
@ -0,0 +1,109 @@
|
|||||||
|
import type { Templates } from 'features/nodes/store/types';
|
||||||
|
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||||
|
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||||
|
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||||
|
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
||||||
|
import type { O } from 'ts-toolbelt';
|
||||||
|
|
||||||
|
type Connection = O.NonNullable<NullableConnection>;
|
||||||
|
|
||||||
|
export type ValidateConnectionResult = {
|
||||||
|
isValid: boolean;
|
||||||
|
messageTKey?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ValidateConnectionFunc = (
|
||||||
|
connection: Connection,
|
||||||
|
nodes: AnyNode[],
|
||||||
|
edges: Edge[],
|
||||||
|
templates: Templates,
|
||||||
|
ignoreEdge: Edge | null
|
||||||
|
) => ValidateConnectionResult;
|
||||||
|
|
||||||
|
export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => ({
|
||||||
|
isValid,
|
||||||
|
messageTKey,
|
||||||
|
});
|
||||||
|
|
||||||
|
const getEqualityPredicate =
|
||||||
|
(c: Connection) =>
|
||||||
|
(e: Edge): boolean => {
|
||||||
|
return (
|
||||||
|
e.target === c.target &&
|
||||||
|
e.targetHandle === c.targetHandle &&
|
||||||
|
e.source === c.source &&
|
||||||
|
e.sourceHandle === c.sourceHandle
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true });
|
||||||
|
export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey });
|
||||||
|
|
||||||
|
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => {
|
||||||
|
if (c.source === c.target) {
|
||||||
|
return buildRejectResult('nodes.cannotConnectToSelf');
|
||||||
|
}
|
||||||
|
|
||||||
|
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
|
||||||
|
|
||||||
|
if (filteredEdges.some(getEqualityPredicate(c))) {
|
||||||
|
// We already have a connection from this source to this target
|
||||||
|
return buildRejectResult('nodes.cannotDuplicateConnection');
|
||||||
|
}
|
||||||
|
|
||||||
|
const sourceNode = nodes.find((n) => n.id === c.source);
|
||||||
|
if (!sourceNode) {
|
||||||
|
return buildRejectResult('nodes.missingNode');
|
||||||
|
}
|
||||||
|
|
||||||
|
const targetNode = nodes.find((n) => n.id === c.target);
|
||||||
|
if (!targetNode) {
|
||||||
|
return buildRejectResult('nodes.missingNode');
|
||||||
|
}
|
||||||
|
|
||||||
|
const sourceTemplate = templates[sourceNode.data.type];
|
||||||
|
if (!sourceTemplate) {
|
||||||
|
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||||
|
}
|
||||||
|
|
||||||
|
const targetTemplate = templates[targetNode.data.type];
|
||||||
|
if (!targetTemplate) {
|
||||||
|
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||||
|
}
|
||||||
|
|
||||||
|
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||||
|
if (!sourceFieldTemplate) {
|
||||||
|
return buildRejectResult('nodes.missingFieldTemplate');
|
||||||
|
}
|
||||||
|
|
||||||
|
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||||
|
if (!targetFieldTemplate) {
|
||||||
|
return buildRejectResult('nodes.missingFieldTemplate');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (targetFieldTemplate.input === 'direct') {
|
||||||
|
return buildRejectResult('nodes.cannotConnectToDirectInput');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
|
||||||
|
// Collect nodes shouldn't mix and match field types
|
||||||
|
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||||
|
if (collectItemType) {
|
||||||
|
if (!areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
|
||||||
|
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
edges.find((e) => {
|
||||||
|
return e.target === c.target && e.targetHandle === c.targetHandle;
|
||||||
|
}) &&
|
||||||
|
// except CollectionItem inputs can have multiples
|
||||||
|
targetFieldTemplate.type.name !== 'CollectionItemField'
|
||||||
|
) {
|
||||||
|
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildAcceptResult();
|
||||||
|
};
|
@ -0,0 +1,222 @@
|
|||||||
|
import { describe, expect, it } from 'vitest';
|
||||||
|
|
||||||
|
import { validateConnectionTypes } from './validateConnectionTypes';
|
||||||
|
|
||||||
|
describe(validateConnectionTypes.name, () => {
|
||||||
|
describe('generic cases', () => {
|
||||||
|
it('should accept Scalar to Scalar of same type', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept Collection to Collection of same type', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
|
||||||
|
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept Scalar to CollectionOrScalar of same type', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept Collection to CollectionOrScalar of same type', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should reject Collection to Scalar of same type', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: true, isCollectionOrScalar: false },
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(false);
|
||||||
|
});
|
||||||
|
it('should reject CollectionOrScalar to Scalar of same type', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: true },
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(false);
|
||||||
|
});
|
||||||
|
it('should reject mismatched types', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'BarField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('special cases', () => {
|
||||||
|
it('should reject a collection input to a collection input', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'CollectionField', isCollection: true, isCollectionOrScalar: false },
|
||||||
|
{ name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should accept equal types', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('CollectionItemField', () => {
|
||||||
|
it('should accept CollectionItemField to any Scalar target', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept CollectionItemField to any CollectionOrScalar target', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept any non-Collection to CollectionItemField', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should reject any Collection to CollectionItemField', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
|
||||||
|
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(false);
|
||||||
|
});
|
||||||
|
it('should reject any CollectionOrScalar to CollectionItemField', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
|
||||||
|
{ name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('CollectionOrScalar', () => {
|
||||||
|
it('should accept any Scalar of same type to CollectionOrScalar', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept any Collection of same type to CollectionOrScalar', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept any CollectionOrScalar of same type to CollectionOrScalar', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('CollectionField', () => {
|
||||||
|
it('should accept any CollectionField to any Collection type', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'CollectionField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept any CollectionField to any CollectionOrScalar type', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'CollectionField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('subtype handling', () => {
|
||||||
|
type TypePair = { t1: string; t2: string };
|
||||||
|
const typePairs = [
|
||||||
|
{ t1: 'IntegerField', t2: 'FloatField' },
|
||||||
|
{ t1: 'IntegerField', t2: 'StringField' },
|
||||||
|
{ t1: 'FloatField', t2: 'StringField' },
|
||||||
|
];
|
||||||
|
it.each(typePairs)('should accept Scalar $t1 to Scalar $t2', ({ t1, t2 }: TypePair) => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: t1, isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: t2, isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it.each(typePairs)('should accept Scalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: t1, isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: t2, isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it.each(typePairs)('should accept Collection $t1 to Collection $t2', ({ t1, t2 }: TypePair) => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: t1, isCollection: true, isCollectionOrScalar: false },
|
||||||
|
{ name: t2, isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it.each(typePairs)('should accept Collection $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: t1, isCollection: true, isCollectionOrScalar: false },
|
||||||
|
{ name: t2, isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it.each(typePairs)('should accept CollectionOrScalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: t1, isCollection: false, isCollectionOrScalar: true },
|
||||||
|
{ name: t2, isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('AnyField', () => {
|
||||||
|
it('should accept any Scalar type to AnyField', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'AnyField', isCollection: false, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept any Collection type to AnyField', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'AnyField', isCollection: true, isCollectionOrScalar: false }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
it('should accept any CollectionOrScalar type to AnyField', () => {
|
||||||
|
const r = validateConnectionTypes(
|
||||||
|
{ name: 'FooField', isCollection: false, isCollectionOrScalar: false },
|
||||||
|
{ name: 'AnyField', isCollection: false, isCollectionOrScalar: true }
|
||||||
|
);
|
||||||
|
expect(r).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
@ -0,0 +1,69 @@
|
|||||||
|
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||||
|
import type { FieldType } from 'features/nodes/types/field';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates that the source and target types are compatible for a connection.
|
||||||
|
* @param sourceType The type of the source field.
|
||||||
|
* @param targetType The type of the target field.
|
||||||
|
* @returns True if the connection is valid, false otherwise.
|
||||||
|
*/
|
||||||
|
export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => {
|
||||||
|
// TODO: There's a bug with Collect -> Iterate nodes:
|
||||||
|
// https://github.com/invoke-ai/InvokeAI/issues/3956
|
||||||
|
// Once this is resolved, we can remove this check.
|
||||||
|
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (areTypesEqual(sourceType, targetType)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Connection types must be the same for a connection, with exceptions:
|
||||||
|
* - CollectionItem can connect to any non-Collection
|
||||||
|
* - Non-Collections can connect to CollectionItem
|
||||||
|
* - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type
|
||||||
|
* - Generic Collection can connect to any other Collection or CollectionOrScalar
|
||||||
|
* - Any Collection can connect to a Generic Collection
|
||||||
|
*/
|
||||||
|
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection;
|
||||||
|
|
||||||
|
const isNonCollectionToCollectionItem =
|
||||||
|
targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar;
|
||||||
|
|
||||||
|
const isAnythingToCollectionOrScalarOfSameBaseType =
|
||||||
|
targetType.isCollectionOrScalar && sourceType.name === targetType.name;
|
||||||
|
|
||||||
|
const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
|
||||||
|
sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar);
|
||||||
|
|
||||||
|
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection;
|
||||||
|
|
||||||
|
const areBothTypesSingle =
|
||||||
|
!sourceType.isCollection &&
|
||||||
|
!sourceType.isCollectionOrScalar &&
|
||||||
|
!targetType.isCollection &&
|
||||||
|
!targetType.isCollectionOrScalar;
|
||||||
|
|
||||||
|
const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
|
||||||
|
|
||||||
|
const isIntOrFloatToString =
|
||||||
|
areBothTypesSingle &&
|
||||||
|
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
|
||||||
|
targetType.name === 'StringField';
|
||||||
|
|
||||||
|
const isTargetAnyType = targetType.name === 'AnyField';
|
||||||
|
|
||||||
|
// One of these must be true for the connection to be valid
|
||||||
|
return (
|
||||||
|
isCollectionItemToNonCollection ||
|
||||||
|
isNonCollectionToCollectionItem ||
|
||||||
|
isAnythingToCollectionOrScalarOfSameBaseType ||
|
||||||
|
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
|
||||||
|
isCollectionToGenericCollection ||
|
||||||
|
isIntToFloat ||
|
||||||
|
isIntOrFloatToString ||
|
||||||
|
isTargetAnyType
|
||||||
|
);
|
||||||
|
};
|
@ -188,7 +188,6 @@ const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zIntegerFieldType,
|
type: zIntegerFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
|
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
|
||||||
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
|
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
|
||||||
@ -217,7 +216,6 @@ const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zFloatFieldType,
|
type: zFloatFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
|
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
|
||||||
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
|
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
|
||||||
@ -243,7 +241,6 @@ const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zStringFieldType,
|
type: zStringFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
||||||
@ -268,7 +265,6 @@ const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zBooleanFieldType,
|
type: zBooleanFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
|
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
|
||||||
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
|
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
|
||||||
@ -294,7 +290,6 @@ const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zEnumFieldType,
|
type: zEnumFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
|
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
|
||||||
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
|
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
|
||||||
@ -318,7 +313,6 @@ const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zImageFieldType,
|
type: zImageFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
|
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
|
||||||
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
|
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
|
||||||
@ -342,7 +336,6 @@ const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zBoardFieldType,
|
type: zBoardFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
|
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
|
||||||
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
|
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
|
||||||
@ -366,7 +359,6 @@ const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zColorFieldType,
|
type: zColorFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
|
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
|
||||||
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
|
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
|
||||||
@ -390,7 +382,6 @@ const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zMainModelFieldType,
|
type: zMainModelFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
|
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
|
||||||
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
|
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
|
||||||
@ -413,7 +404,6 @@ const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zModelIdentifierFieldType,
|
type: zModelIdentifierFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
|
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
|
||||||
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
|
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
|
||||||
@ -437,7 +427,6 @@ const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zSDXLMainModelFieldType,
|
type: zSDXLMainModelFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
|
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
|
||||||
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
||||||
@ -461,7 +450,6 @@ const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zSDXLRefinerModelFieldType,
|
type: zSDXLRefinerModelFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
||||||
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
||||||
@ -485,7 +473,6 @@ const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zVAEModelFieldType,
|
type: zVAEModelFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
||||||
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
||||||
@ -509,7 +496,6 @@ const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zLoRAModelFieldType,
|
type: zLoRAModelFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
||||||
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
||||||
@ -533,7 +519,6 @@ const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zControlNetModelFieldType,
|
type: zControlNetModelFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
||||||
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
||||||
@ -557,7 +542,6 @@ const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zIPAdapterModelFieldType,
|
type: zIPAdapterModelFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
||||||
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
||||||
@ -581,7 +565,6 @@ const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zT2IAdapterModelFieldType,
|
type: zT2IAdapterModelFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
||||||
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
||||||
@ -605,7 +588,6 @@ const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zSchedulerFieldType,
|
type: zSchedulerFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
|
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
|
||||||
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
|
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
|
||||||
@ -641,7 +623,6 @@ const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
|
|||||||
});
|
});
|
||||||
const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
type: zStatelessFieldType,
|
type: zStatelessFieldType,
|
||||||
originalType: zFieldType.optional(),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
|
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
|
||||||
|
@ -1,942 +1,19 @@
|
|||||||
|
import { schema, templates } from 'features/nodes/store/util/testUtils';
|
||||||
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
||||||
import { omit, pick } from 'lodash-es';
|
import { omit, pick } from 'lodash-es';
|
||||||
import type { OpenAPIV3_1 } from 'openapi-types';
|
|
||||||
import { describe, expect, it } from 'vitest';
|
import { describe, expect, it } from 'vitest';
|
||||||
|
|
||||||
describe('parseSchema', () => {
|
describe('parseSchema', () => {
|
||||||
it('should parse the schema', () => {
|
it('should parse the schema', () => {
|
||||||
const templates = parseSchema(schema);
|
const parsed = parseSchema(schema);
|
||||||
expect(templates).toEqual(expected);
|
expect(parsed).toEqual(templates);
|
||||||
});
|
});
|
||||||
it('should omit denied nodes', () => {
|
it('should omit denied nodes', () => {
|
||||||
const templates = parseSchema(schema, undefined, ['add']);
|
const parsed = parseSchema(schema, undefined, ['add']);
|
||||||
expect(templates).toEqual(omit(expected, 'add'));
|
expect(parsed).toEqual(omit(templates, 'add'));
|
||||||
});
|
});
|
||||||
it('should include only allowed nodes', () => {
|
it('should include only allowed nodes', () => {
|
||||||
const templates = parseSchema(schema, ['add']);
|
const parsed = parseSchema(schema, ['add']);
|
||||||
expect(templates).toEqual(pick(expected, 'add'));
|
expect(parsed).toEqual(pick(templates, 'add'));
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
const expected = {
|
|
||||||
add: {
|
|
||||||
title: 'Add Integers',
|
|
||||||
type: 'add',
|
|
||||||
version: '1.0.1',
|
|
||||||
tags: ['math', 'add'],
|
|
||||||
description: 'Adds two numbers',
|
|
||||||
outputType: 'integer_output',
|
|
||||||
inputs: {
|
|
||||||
a: {
|
|
||||||
name: 'a',
|
|
||||||
title: 'A',
|
|
||||||
required: false,
|
|
||||||
description: 'The first number',
|
|
||||||
fieldKind: 'input',
|
|
||||||
input: 'any',
|
|
||||||
ui_hidden: false,
|
|
||||||
type: {
|
|
||||||
name: 'IntegerField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
default: 0,
|
|
||||||
},
|
|
||||||
b: {
|
|
||||||
name: 'b',
|
|
||||||
title: 'B',
|
|
||||||
required: false,
|
|
||||||
description: 'The second number',
|
|
||||||
fieldKind: 'input',
|
|
||||||
input: 'any',
|
|
||||||
ui_hidden: false,
|
|
||||||
type: {
|
|
||||||
name: 'IntegerField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
default: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
outputs: {
|
|
||||||
value: {
|
|
||||||
fieldKind: 'output',
|
|
||||||
name: 'value',
|
|
||||||
title: 'Value',
|
|
||||||
description: 'The output integer',
|
|
||||||
type: {
|
|
||||||
name: 'IntegerField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
useCache: true,
|
|
||||||
nodePack: 'invokeai',
|
|
||||||
classification: 'stable',
|
|
||||||
},
|
|
||||||
scheduler: {
|
|
||||||
title: 'Scheduler',
|
|
||||||
type: 'scheduler',
|
|
||||||
version: '1.0.0',
|
|
||||||
tags: ['scheduler'],
|
|
||||||
description: 'Selects a scheduler.',
|
|
||||||
outputType: 'scheduler_output',
|
|
||||||
inputs: {
|
|
||||||
scheduler: {
|
|
||||||
name: 'scheduler',
|
|
||||||
title: 'Scheduler',
|
|
||||||
required: false,
|
|
||||||
description: 'Scheduler to use during inference',
|
|
||||||
fieldKind: 'input',
|
|
||||||
input: 'any',
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'SchedulerField',
|
|
||||||
type: {
|
|
||||||
name: 'SchedulerField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
originalType: {
|
|
||||||
name: 'EnumField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
default: 'euler',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
outputs: {
|
|
||||||
scheduler: {
|
|
||||||
fieldKind: 'output',
|
|
||||||
name: 'scheduler',
|
|
||||||
title: 'Scheduler',
|
|
||||||
description: 'Scheduler to use during inference',
|
|
||||||
type: {
|
|
||||||
name: 'SchedulerField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
originalType: {
|
|
||||||
name: 'EnumField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'SchedulerField',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
useCache: true,
|
|
||||||
nodePack: 'invokeai',
|
|
||||||
classification: 'stable',
|
|
||||||
},
|
|
||||||
main_model_loader: {
|
|
||||||
title: 'Main Model',
|
|
||||||
type: 'main_model_loader',
|
|
||||||
version: '1.0.2',
|
|
||||||
tags: ['model'],
|
|
||||||
description: 'Loads a main model, outputting its submodels.',
|
|
||||||
outputType: 'model_loader_output',
|
|
||||||
inputs: {
|
|
||||||
model: {
|
|
||||||
name: 'model',
|
|
||||||
title: 'Model',
|
|
||||||
required: true,
|
|
||||||
description: 'Main model (UNet, VAE, CLIP) to load',
|
|
||||||
fieldKind: 'input',
|
|
||||||
input: 'direct',
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'MainModelField',
|
|
||||||
type: {
|
|
||||||
name: 'MainModelField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
originalType: {
|
|
||||||
name: 'ModelIdentifierField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
outputs: {
|
|
||||||
vae: {
|
|
||||||
fieldKind: 'output',
|
|
||||||
name: 'vae',
|
|
||||||
title: 'VAE',
|
|
||||||
description: 'VAE',
|
|
||||||
type: {
|
|
||||||
name: 'VAEField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
clip: {
|
|
||||||
fieldKind: 'output',
|
|
||||||
name: 'clip',
|
|
||||||
title: 'CLIP',
|
|
||||||
description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count',
|
|
||||||
type: {
|
|
||||||
name: 'CLIPField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
unet: {
|
|
||||||
fieldKind: 'output',
|
|
||||||
name: 'unet',
|
|
||||||
title: 'UNet',
|
|
||||||
description: 'UNet (scheduler, LoRAs)',
|
|
||||||
type: {
|
|
||||||
name: 'UNetField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
useCache: true,
|
|
||||||
nodePack: 'invokeai',
|
|
||||||
classification: 'stable',
|
|
||||||
},
|
|
||||||
collect: {
|
|
||||||
title: 'Collect',
|
|
||||||
type: 'collect',
|
|
||||||
version: '1.0.0',
|
|
||||||
tags: [],
|
|
||||||
description: 'Collects values into a collection',
|
|
||||||
outputType: 'collect_output',
|
|
||||||
inputs: {
|
|
||||||
item: {
|
|
||||||
name: 'item',
|
|
||||||
title: 'Collection Item',
|
|
||||||
required: false,
|
|
||||||
description: 'The item to collect (all inputs must be of the same type)',
|
|
||||||
fieldKind: 'input',
|
|
||||||
input: 'connection',
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'CollectionItemField',
|
|
||||||
type: {
|
|
||||||
name: 'CollectionItemField',
|
|
||||||
isCollection: false,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
outputs: {
|
|
||||||
collection: {
|
|
||||||
fieldKind: 'output',
|
|
||||||
name: 'collection',
|
|
||||||
title: 'Collection',
|
|
||||||
description: 'The collection of input items',
|
|
||||||
type: {
|
|
||||||
name: 'CollectionField',
|
|
||||||
isCollection: true,
|
|
||||||
isCollectionOrScalar: false,
|
|
||||||
},
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'CollectionField',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
useCache: true,
|
|
||||||
classification: 'stable',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
const schema = {
|
|
||||||
openapi: '3.1.0',
|
|
||||||
info: {
|
|
||||||
title: 'Invoke - Community Edition',
|
|
||||||
description: 'An API for invoking AI image operations',
|
|
||||||
version: '1.0.0',
|
|
||||||
},
|
|
||||||
components: {
|
|
||||||
schemas: {
|
|
||||||
AddInvocation: {
|
|
||||||
properties: {
|
|
||||||
id: {
|
|
||||||
type: 'string',
|
|
||||||
title: 'Id',
|
|
||||||
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
is_intermediate: {
|
|
||||||
type: 'boolean',
|
|
||||||
title: 'Is Intermediate',
|
|
||||||
description: 'Whether or not this is an intermediate invocation.',
|
|
||||||
default: false,
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
ui_type: 'IsIntermediate',
|
|
||||||
},
|
|
||||||
use_cache: {
|
|
||||||
type: 'boolean',
|
|
||||||
title: 'Use Cache',
|
|
||||||
description: 'Whether or not to use the cache',
|
|
||||||
default: true,
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
a: {
|
|
||||||
type: 'integer',
|
|
||||||
title: 'A',
|
|
||||||
description: 'The first number',
|
|
||||||
default: 0,
|
|
||||||
field_kind: 'input',
|
|
||||||
input: 'any',
|
|
||||||
orig_default: 0,
|
|
||||||
orig_required: false,
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
b: {
|
|
||||||
type: 'integer',
|
|
||||||
title: 'B',
|
|
||||||
description: 'The second number',
|
|
||||||
default: 0,
|
|
||||||
field_kind: 'input',
|
|
||||||
input: 'any',
|
|
||||||
orig_default: 0,
|
|
||||||
orig_required: false,
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
type: {
|
|
||||||
type: 'string',
|
|
||||||
enum: ['add'],
|
|
||||||
const: 'add',
|
|
||||||
title: 'type',
|
|
||||||
default: 'add',
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
type: 'object',
|
|
||||||
required: ['type', 'id'],
|
|
||||||
title: 'Add Integers',
|
|
||||||
description: 'Adds two numbers',
|
|
||||||
category: 'math',
|
|
||||||
classification: 'stable',
|
|
||||||
node_pack: 'invokeai',
|
|
||||||
tags: ['math', 'add'],
|
|
||||||
version: '1.0.1',
|
|
||||||
output: {
|
|
||||||
$ref: '#/components/schemas/IntegerOutput',
|
|
||||||
},
|
|
||||||
class: 'invocation',
|
|
||||||
},
|
|
||||||
IntegerOutput: {
|
|
||||||
description: 'Base class for nodes that output a single integer',
|
|
||||||
properties: {
|
|
||||||
value: {
|
|
||||||
description: 'The output integer',
|
|
||||||
field_kind: 'output',
|
|
||||||
title: 'Value',
|
|
||||||
type: 'integer',
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
type: {
|
|
||||||
const: 'integer_output',
|
|
||||||
default: 'integer_output',
|
|
||||||
enum: ['integer_output'],
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
title: 'type',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['value', 'type', 'type'],
|
|
||||||
title: 'IntegerOutput',
|
|
||||||
type: 'object',
|
|
||||||
class: 'output',
|
|
||||||
},
|
|
||||||
SchedulerInvocation: {
|
|
||||||
properties: {
|
|
||||||
id: {
|
|
||||||
type: 'string',
|
|
||||||
title: 'Id',
|
|
||||||
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
is_intermediate: {
|
|
||||||
type: 'boolean',
|
|
||||||
title: 'Is Intermediate',
|
|
||||||
description: 'Whether or not this is an intermediate invocation.',
|
|
||||||
default: false,
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
ui_type: 'IsIntermediate',
|
|
||||||
},
|
|
||||||
use_cache: {
|
|
||||||
type: 'boolean',
|
|
||||||
title: 'Use Cache',
|
|
||||||
description: 'Whether or not to use the cache',
|
|
||||||
default: true,
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
scheduler: {
|
|
||||||
type: 'string',
|
|
||||||
enum: [
|
|
||||||
'ddim',
|
|
||||||
'ddpm',
|
|
||||||
'deis',
|
|
||||||
'lms',
|
|
||||||
'lms_k',
|
|
||||||
'pndm',
|
|
||||||
'heun',
|
|
||||||
'heun_k',
|
|
||||||
'euler',
|
|
||||||
'euler_k',
|
|
||||||
'euler_a',
|
|
||||||
'kdpm_2',
|
|
||||||
'kdpm_2_a',
|
|
||||||
'dpmpp_2s',
|
|
||||||
'dpmpp_2s_k',
|
|
||||||
'dpmpp_2m',
|
|
||||||
'dpmpp_2m_k',
|
|
||||||
'dpmpp_2m_sde',
|
|
||||||
'dpmpp_2m_sde_k',
|
|
||||||
'dpmpp_sde',
|
|
||||||
'dpmpp_sde_k',
|
|
||||||
'unipc',
|
|
||||||
'lcm',
|
|
||||||
'tcd',
|
|
||||||
],
|
|
||||||
title: 'Scheduler',
|
|
||||||
description: 'Scheduler to use during inference',
|
|
||||||
default: 'euler',
|
|
||||||
field_kind: 'input',
|
|
||||||
input: 'any',
|
|
||||||
orig_default: 'euler',
|
|
||||||
orig_required: false,
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'SchedulerField',
|
|
||||||
},
|
|
||||||
type: {
|
|
||||||
type: 'string',
|
|
||||||
enum: ['scheduler'],
|
|
||||||
const: 'scheduler',
|
|
||||||
title: 'type',
|
|
||||||
default: 'scheduler',
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
type: 'object',
|
|
||||||
required: ['type', 'id'],
|
|
||||||
title: 'Scheduler',
|
|
||||||
description: 'Selects a scheduler.',
|
|
||||||
category: 'latents',
|
|
||||||
classification: 'stable',
|
|
||||||
node_pack: 'invokeai',
|
|
||||||
tags: ['scheduler'],
|
|
||||||
version: '1.0.0',
|
|
||||||
output: {
|
|
||||||
$ref: '#/components/schemas/SchedulerOutput',
|
|
||||||
},
|
|
||||||
class: 'invocation',
|
|
||||||
},
|
|
||||||
SchedulerOutput: {
|
|
||||||
properties: {
|
|
||||||
scheduler: {
|
|
||||||
description: 'Scheduler to use during inference',
|
|
||||||
enum: [
|
|
||||||
'ddim',
|
|
||||||
'ddpm',
|
|
||||||
'deis',
|
|
||||||
'lms',
|
|
||||||
'lms_k',
|
|
||||||
'pndm',
|
|
||||||
'heun',
|
|
||||||
'heun_k',
|
|
||||||
'euler',
|
|
||||||
'euler_k',
|
|
||||||
'euler_a',
|
|
||||||
'kdpm_2',
|
|
||||||
'kdpm_2_a',
|
|
||||||
'dpmpp_2s',
|
|
||||||
'dpmpp_2s_k',
|
|
||||||
'dpmpp_2m',
|
|
||||||
'dpmpp_2m_k',
|
|
||||||
'dpmpp_2m_sde',
|
|
||||||
'dpmpp_2m_sde_k',
|
|
||||||
'dpmpp_sde',
|
|
||||||
'dpmpp_sde_k',
|
|
||||||
'unipc',
|
|
||||||
'lcm',
|
|
||||||
'tcd',
|
|
||||||
],
|
|
||||||
field_kind: 'output',
|
|
||||||
title: 'Scheduler',
|
|
||||||
type: 'string',
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'SchedulerField',
|
|
||||||
},
|
|
||||||
type: {
|
|
||||||
const: 'scheduler_output',
|
|
||||||
default: 'scheduler_output',
|
|
||||||
enum: ['scheduler_output'],
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
title: 'type',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['scheduler', 'type', 'type'],
|
|
||||||
title: 'SchedulerOutput',
|
|
||||||
type: 'object',
|
|
||||||
class: 'output',
|
|
||||||
},
|
|
||||||
MainModelLoaderInvocation: {
|
|
||||||
properties: {
|
|
||||||
id: {
|
|
||||||
type: 'string',
|
|
||||||
title: 'Id',
|
|
||||||
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
is_intermediate: {
|
|
||||||
type: 'boolean',
|
|
||||||
title: 'Is Intermediate',
|
|
||||||
description: 'Whether or not this is an intermediate invocation.',
|
|
||||||
default: false,
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
ui_type: 'IsIntermediate',
|
|
||||||
},
|
|
||||||
use_cache: {
|
|
||||||
type: 'boolean',
|
|
||||||
title: 'Use Cache',
|
|
||||||
description: 'Whether or not to use the cache',
|
|
||||||
default: true,
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
model: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/ModelIdentifierField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'Main model (UNet, VAE, CLIP) to load',
|
|
||||||
field_kind: 'input',
|
|
||||||
input: 'direct',
|
|
||||||
orig_required: true,
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'MainModelField',
|
|
||||||
},
|
|
||||||
type: {
|
|
||||||
type: 'string',
|
|
||||||
enum: ['main_model_loader'],
|
|
||||||
const: 'main_model_loader',
|
|
||||||
title: 'type',
|
|
||||||
default: 'main_model_loader',
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
type: 'object',
|
|
||||||
required: ['model', 'type', 'id'],
|
|
||||||
title: 'Main Model',
|
|
||||||
description: 'Loads a main model, outputting its submodels.',
|
|
||||||
category: 'model',
|
|
||||||
classification: 'stable',
|
|
||||||
node_pack: 'invokeai',
|
|
||||||
tags: ['model'],
|
|
||||||
version: '1.0.2',
|
|
||||||
output: {
|
|
||||||
$ref: '#/components/schemas/ModelLoaderOutput',
|
|
||||||
},
|
|
||||||
class: 'invocation',
|
|
||||||
},
|
|
||||||
ModelIdentifierField: {
|
|
||||||
properties: {
|
|
||||||
key: {
|
|
||||||
description: "The model's unique key",
|
|
||||||
title: 'Key',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
hash: {
|
|
||||||
description: "The model's BLAKE3 hash",
|
|
||||||
title: 'Hash',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
name: {
|
|
||||||
description: "The model's name",
|
|
||||||
title: 'Name',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
base: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/BaseModelType',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: "The model's base model type",
|
|
||||||
},
|
|
||||||
type: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/ModelType',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: "The model's type",
|
|
||||||
},
|
|
||||||
submodel_type: {
|
|
||||||
anyOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/SubModelType',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
type: 'null',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
default: null,
|
|
||||||
description: 'The submodel to load, if this is a main model',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['key', 'hash', 'name', 'base', 'type'],
|
|
||||||
title: 'ModelIdentifierField',
|
|
||||||
type: 'object',
|
|
||||||
},
|
|
||||||
BaseModelType: {
|
|
||||||
description: 'Base model type.',
|
|
||||||
enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
|
||||||
title: 'BaseModelType',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
ModelType: {
|
|
||||||
description: 'Model type.',
|
|
||||||
enum: ['onnx', 'main', 'vae', 'lora', 'controlnet', 'embedding', 'ip_adapter', 'clip_vision', 't2i_adapter'],
|
|
||||||
title: 'ModelType',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
SubModelType: {
|
|
||||||
description: 'Submodel type.',
|
|
||||||
enum: [
|
|
||||||
'unet',
|
|
||||||
'text_encoder',
|
|
||||||
'text_encoder_2',
|
|
||||||
'tokenizer',
|
|
||||||
'tokenizer_2',
|
|
||||||
'vae',
|
|
||||||
'vae_decoder',
|
|
||||||
'vae_encoder',
|
|
||||||
'scheduler',
|
|
||||||
'safety_checker',
|
|
||||||
],
|
|
||||||
title: 'SubModelType',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
ModelLoaderOutput: {
|
|
||||||
description: 'Model loader output',
|
|
||||||
properties: {
|
|
||||||
vae: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/VAEField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'VAE',
|
|
||||||
field_kind: 'output',
|
|
||||||
title: 'VAE',
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
type: {
|
|
||||||
const: 'model_loader_output',
|
|
||||||
default: 'model_loader_output',
|
|
||||||
enum: ['model_loader_output'],
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
title: 'type',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
clip: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/CLIPField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count',
|
|
||||||
field_kind: 'output',
|
|
||||||
title: 'CLIP',
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
unet: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/UNetField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'UNet (scheduler, LoRAs)',
|
|
||||||
field_kind: 'output',
|
|
||||||
title: 'UNet',
|
|
||||||
ui_hidden: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['vae', 'type', 'clip', 'unet', 'type'],
|
|
||||||
title: 'ModelLoaderOutput',
|
|
||||||
type: 'object',
|
|
||||||
class: 'output',
|
|
||||||
},
|
|
||||||
UNetField: {
|
|
||||||
properties: {
|
|
||||||
unet: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/ModelIdentifierField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'Info to load unet submodel',
|
|
||||||
},
|
|
||||||
scheduler: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/ModelIdentifierField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'Info to load scheduler submodel',
|
|
||||||
},
|
|
||||||
loras: {
|
|
||||||
description: 'LoRAs to apply on model loading',
|
|
||||||
items: {
|
|
||||||
$ref: '#/components/schemas/LoRAField',
|
|
||||||
},
|
|
||||||
title: 'Loras',
|
|
||||||
type: 'array',
|
|
||||||
},
|
|
||||||
seamless_axes: {
|
|
||||||
description: 'Axes("x" and "y") to which apply seamless',
|
|
||||||
items: {
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
title: 'Seamless Axes',
|
|
||||||
type: 'array',
|
|
||||||
},
|
|
||||||
freeu_config: {
|
|
||||||
anyOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/FreeUConfig',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
type: 'null',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
default: null,
|
|
||||||
description: 'FreeU configuration',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['unet', 'scheduler', 'loras'],
|
|
||||||
title: 'UNetField',
|
|
||||||
type: 'object',
|
|
||||||
class: 'output',
|
|
||||||
},
|
|
||||||
LoRAField: {
|
|
||||||
properties: {
|
|
||||||
lora: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/ModelIdentifierField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'Info to load lora model',
|
|
||||||
},
|
|
||||||
weight: {
|
|
||||||
description: 'Weight to apply to lora model',
|
|
||||||
title: 'Weight',
|
|
||||||
type: 'number',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['lora', 'weight'],
|
|
||||||
title: 'LoRAField',
|
|
||||||
type: 'object',
|
|
||||||
class: 'output',
|
|
||||||
},
|
|
||||||
FreeUConfig: {
|
|
||||||
description:
|
|
||||||
'Configuration for the FreeU hyperparameters.\n- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu\n- https://github.com/ChenyangSi/FreeU',
|
|
||||||
properties: {
|
|
||||||
s1: {
|
|
||||||
description:
|
|
||||||
'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.',
|
|
||||||
maximum: 3.0,
|
|
||||||
minimum: -1.0,
|
|
||||||
title: 'S1',
|
|
||||||
type: 'number',
|
|
||||||
},
|
|
||||||
s2: {
|
|
||||||
description:
|
|
||||||
'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.',
|
|
||||||
maximum: 3.0,
|
|
||||||
minimum: -1.0,
|
|
||||||
title: 'S2',
|
|
||||||
type: 'number',
|
|
||||||
},
|
|
||||||
b1: {
|
|
||||||
description: 'Scaling factor for stage 1 to amplify the contributions of backbone features.',
|
|
||||||
maximum: 3.0,
|
|
||||||
minimum: -1.0,
|
|
||||||
title: 'B1',
|
|
||||||
type: 'number',
|
|
||||||
},
|
|
||||||
b2: {
|
|
||||||
description: 'Scaling factor for stage 2 to amplify the contributions of backbone features.',
|
|
||||||
maximum: 3.0,
|
|
||||||
minimum: -1.0,
|
|
||||||
title: 'B2',
|
|
||||||
type: 'number',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['s1', 's2', 'b1', 'b2'],
|
|
||||||
title: 'FreeUConfig',
|
|
||||||
type: 'object',
|
|
||||||
class: 'output',
|
|
||||||
},
|
|
||||||
VAEField: {
|
|
||||||
properties: {
|
|
||||||
vae: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/ModelIdentifierField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'Info to load vae submodel',
|
|
||||||
},
|
|
||||||
seamless_axes: {
|
|
||||||
description: 'Axes("x" and "y") to which apply seamless',
|
|
||||||
items: {
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
title: 'Seamless Axes',
|
|
||||||
type: 'array',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['vae'],
|
|
||||||
title: 'VAEField',
|
|
||||||
type: 'object',
|
|
||||||
class: 'output',
|
|
||||||
},
|
|
||||||
CLIPField: {
|
|
||||||
properties: {
|
|
||||||
tokenizer: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/ModelIdentifierField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'Info to load tokenizer submodel',
|
|
||||||
},
|
|
||||||
text_encoder: {
|
|
||||||
allOf: [
|
|
||||||
{
|
|
||||||
$ref: '#/components/schemas/ModelIdentifierField',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
description: 'Info to load text_encoder submodel',
|
|
||||||
},
|
|
||||||
skipped_layers: {
|
|
||||||
description: 'Number of skipped layers in text_encoder',
|
|
||||||
title: 'Skipped Layers',
|
|
||||||
type: 'integer',
|
|
||||||
},
|
|
||||||
loras: {
|
|
||||||
description: 'LoRAs to apply on model loading',
|
|
||||||
items: {
|
|
||||||
$ref: '#/components/schemas/LoRAField',
|
|
||||||
},
|
|
||||||
title: 'Loras',
|
|
||||||
type: 'array',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['tokenizer', 'text_encoder', 'skipped_layers', 'loras'],
|
|
||||||
title: 'CLIPField',
|
|
||||||
type: 'object',
|
|
||||||
class: 'output',
|
|
||||||
},
|
|
||||||
CollectInvocation: {
|
|
||||||
properties: {
|
|
||||||
id: {
|
|
||||||
type: 'string',
|
|
||||||
title: 'Id',
|
|
||||||
description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.',
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
is_intermediate: {
|
|
||||||
type: 'boolean',
|
|
||||||
title: 'Is Intermediate',
|
|
||||||
description: 'Whether or not this is an intermediate invocation.',
|
|
||||||
default: false,
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
ui_type: 'IsIntermediate',
|
|
||||||
},
|
|
||||||
use_cache: {
|
|
||||||
type: 'boolean',
|
|
||||||
title: 'Use Cache',
|
|
||||||
description: 'Whether or not to use the cache',
|
|
||||||
default: true,
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
item: {
|
|
||||||
anyOf: [
|
|
||||||
{},
|
|
||||||
{
|
|
||||||
type: 'null',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
title: 'Collection Item',
|
|
||||||
description: 'The item to collect (all inputs must be of the same type)',
|
|
||||||
field_kind: 'input',
|
|
||||||
input: 'connection',
|
|
||||||
orig_required: false,
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'CollectionItemField',
|
|
||||||
},
|
|
||||||
collection: {
|
|
||||||
items: {},
|
|
||||||
type: 'array',
|
|
||||||
title: 'Collection',
|
|
||||||
description: 'The collection, will be provided on execution',
|
|
||||||
default: [],
|
|
||||||
field_kind: 'input',
|
|
||||||
input: 'any',
|
|
||||||
orig_default: [],
|
|
||||||
orig_required: false,
|
|
||||||
ui_hidden: true,
|
|
||||||
},
|
|
||||||
type: {
|
|
||||||
type: 'string',
|
|
||||||
enum: ['collect'],
|
|
||||||
const: 'collect',
|
|
||||||
title: 'type',
|
|
||||||
default: 'collect',
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
type: 'object',
|
|
||||||
required: ['type', 'id'],
|
|
||||||
title: 'CollectInvocation',
|
|
||||||
description: 'Collects values into a collection',
|
|
||||||
classification: 'stable',
|
|
||||||
version: '1.0.0',
|
|
||||||
output: {
|
|
||||||
$ref: '#/components/schemas/CollectInvocationOutput',
|
|
||||||
},
|
|
||||||
class: 'invocation',
|
|
||||||
},
|
|
||||||
CollectInvocationOutput: {
|
|
||||||
properties: {
|
|
||||||
collection: {
|
|
||||||
description: 'The collection of input items',
|
|
||||||
field_kind: 'output',
|
|
||||||
items: {},
|
|
||||||
title: 'Collection',
|
|
||||||
type: 'array',
|
|
||||||
ui_hidden: false,
|
|
||||||
ui_type: 'CollectionField',
|
|
||||||
},
|
|
||||||
type: {
|
|
||||||
const: 'collect_output',
|
|
||||||
default: 'collect_output',
|
|
||||||
enum: ['collect_output'],
|
|
||||||
field_kind: 'node_attribute',
|
|
||||||
title: 'type',
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['collection', 'type', 'type'],
|
|
||||||
title: 'CollectInvocationOutput',
|
|
||||||
type: 'object',
|
|
||||||
class: 'output',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
} as OpenAPIV3_1.Document;
|
|
||||||
|
Loading…
Reference in New Issue
Block a user