mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): rework getFirstValidConnection with new helpers
This commit is contained in:
parent
c98205d0d7
commit
83000a4190
@ -1,117 +1,76 @@
|
|||||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
import type { Templates } from 'features/nodes/store/types';
|
||||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||||
import { differenceWith, map } from 'lodash-es';
|
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||||
|
import { map } from 'lodash-es';
|
||||||
import type { Connection, Edge } from 'reactflow';
|
import type { Connection, Edge } from 'reactflow';
|
||||||
import { assert } from 'tsafe';
|
|
||||||
|
|
||||||
import { areTypesEqual } from './areTypesEqual';
|
|
||||||
import { getCollectItemType } from './getCollectItemType';
|
|
||||||
import { getHasCycles } from './getHasCycles';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Finds the first valid field for a pending connection between two nodes.
|
*
|
||||||
* @param templates The invocation templates
|
* @param source The source (node id)
|
||||||
|
* @param sourceHandle The source handle (field name), if any
|
||||||
|
* @param target The target (node id)
|
||||||
|
* @param targetHandle The target handle (field name), if any
|
||||||
* @param nodes The current nodes
|
* @param nodes The current nodes
|
||||||
* @param edges The current edges
|
* @param edges The current edges
|
||||||
* @param pendingConnection The pending connection
|
* @param templates The current templates
|
||||||
* @param candidateNode The candidate node to which the connection is being made
|
* @param edgePendingUpdate The edge pending update, if any
|
||||||
* @param candidateTemplate The candidate template for the candidate node
|
* @returns
|
||||||
* @returns The first valid connection, or null if no valid connection is found
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export const getFirstValidConnection = (
|
export const getFirstValidConnection = (
|
||||||
templates: Templates,
|
source: string,
|
||||||
|
sourceHandle: string | null,
|
||||||
|
target: string,
|
||||||
|
targetHandle: string | null,
|
||||||
nodes: AnyNode[],
|
nodes: AnyNode[],
|
||||||
edges: InvocationNodeEdge[],
|
edges: InvocationNodeEdge[],
|
||||||
pendingConnection: PendingConnection,
|
templates: Templates,
|
||||||
candidateNode: InvocationNode,
|
|
||||||
candidateTemplate: InvocationTemplate,
|
|
||||||
edgePendingUpdate: Edge | null
|
edgePendingUpdate: Edge | null
|
||||||
): Connection | null => {
|
): Connection | null => {
|
||||||
if (pendingConnection.node.id === candidateNode.id) {
|
if (source === target) {
|
||||||
// Cannot connect to self
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
if (sourceHandle && targetHandle) {
|
||||||
|
return { source, sourceHandle, target, targetHandle };
|
||||||
|
}
|
||||||
|
|
||||||
if (pendingFieldKind === 'source') {
|
if (sourceHandle && !targetHandle) {
|
||||||
// Connecting from a source to a target
|
const candidates = getTargetCandidateFields(
|
||||||
if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
|
source,
|
||||||
return null;
|
sourceHandle,
|
||||||
}
|
target,
|
||||||
if (candidateNode.data.type === 'collect') {
|
nodes,
|
||||||
// Special handling for collect node - the `item` field takes any number of connections
|
edges,
|
||||||
return {
|
templates,
|
||||||
source: pendingConnection.node.id,
|
edgePendingUpdate
|
||||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
|
||||||
target: candidateNode.id,
|
|
||||||
targetHandle: 'item',
|
|
||||||
};
|
|
||||||
}
|
|
||||||
// Only one connection per target field is allowed - look for an unconnected target field
|
|
||||||
const candidateFields = map(candidateTemplate.inputs);
|
|
||||||
const candidateConnectedFields = edges
|
|
||||||
.filter((edge) => edge.target === candidateNode.id || edge.id === edgePendingUpdate?.id)
|
|
||||||
.map((edge) => {
|
|
||||||
// Edges must always have a targetHandle, safe to assert here
|
|
||||||
assert(edge.targetHandle);
|
|
||||||
return edge.targetHandle;
|
|
||||||
});
|
|
||||||
const candidateUnconnectedFields = differenceWith(
|
|
||||||
candidateFields,
|
|
||||||
candidateConnectedFields,
|
|
||||||
(field, connectedFieldName) => field.name === connectedFieldName
|
|
||||||
);
|
);
|
||||||
const candidateField = candidateUnconnectedFields.find((field) =>
|
|
||||||
validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type)
|
const firstCandidate = candidates[0];
|
||||||
);
|
if (!firstCandidate) {
|
||||||
if (candidateField) {
|
|
||||||
return {
|
|
||||||
source: pendingConnection.node.id,
|
|
||||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
|
||||||
target: candidateNode.id,
|
|
||||||
targetHandle: candidateField.name,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Connecting from a target to a source
|
|
||||||
// Ensure we there is not already an edge to the target, except for collect nodes
|
|
||||||
const isCollect = pendingConnection.node.data.type === 'collect';
|
|
||||||
const isTargetAlreadyConnected = edges.some(
|
|
||||||
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
|
||||||
);
|
|
||||||
if (!isCollect && isTargetAlreadyConnected) {
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
|
return { source, sourceHandle, target, targetHandle: firstCandidate.name };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sourceHandle && targetHandle) {
|
||||||
|
const candidates = getSourceCandidateFields(
|
||||||
|
target,
|
||||||
|
targetHandle,
|
||||||
|
source,
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
templates,
|
||||||
|
edgePendingUpdate
|
||||||
|
);
|
||||||
|
|
||||||
|
const firstCandidate = candidates[0];
|
||||||
|
if (!firstCandidate) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sources/outputs can have any number of edges, we can take the first matching output field
|
return { source, sourceHandle: firstCandidate.name, target, targetHandle };
|
||||||
let candidateFields = map(candidateTemplate.outputs);
|
|
||||||
if (isCollect) {
|
|
||||||
// Narrow candidates to same field type as already is connected to the collect node
|
|
||||||
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
|
|
||||||
if (collectItemType) {
|
|
||||||
candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const candidateField = candidateFields.find((field) => {
|
|
||||||
const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type);
|
|
||||||
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
|
|
||||||
return isValid && !isAlreadyConnected;
|
|
||||||
});
|
|
||||||
if (candidateField) {
|
|
||||||
return {
|
|
||||||
source: candidateNode.id,
|
|
||||||
sourceHandle: candidateField.name,
|
|
||||||
target: pendingConnection.node.id,
|
|
||||||
targetHandle: pendingConnection.fieldTemplate.name,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
Loading…
Reference in New Issue
Block a user