feat(ui): connection validation for collection items types

This commit is contained in:
psychedelicious 2024-05-16 23:42:54 +10:00
parent 76825f4261
commit a8b042177d
7 changed files with 70 additions and 13 deletions

View File

@ -774,6 +774,7 @@
"cannotConnectOutputToOutput": "Cannot connect output to output",
"cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"nodePack": "Node pack",
"collection": "Collection",
"collectionFieldType": "{{name}} Collection",

View File

@ -148,7 +148,7 @@ const AddNodePopover = () => {
const template = templates[node.data.type];
assert(template, 'Template not found');
const { nodes, edges } = store.getState().nodes.present;
const connection = getFirstValidConnection(nodes, edges, pendingConnection, node, template);
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
if (connection) {
dispatch(connectionMade(connection));
}

View File

@ -70,7 +70,7 @@ export const useConnection = () => {
}
const candidateTemplate = templates[candidateNode.data.type];
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
const connection = getFirstValidConnection(nodes, edges, pendingConnection, candidateNode, candidateTemplate);
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, candidateNode, candidateTemplate);
if (connection) {
dispatch(connectionMade(connection));
}

View File

@ -1,7 +1,7 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $pendingConnection, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { useMemo } from 'react';
@ -15,6 +15,7 @@ type UseConnectionStateProps = {
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
const pendingConnection = useStore($pendingConnection);
const templates = useStore($templates);
const fieldType = useFieldType(nodeId, fieldName, kind);
const selectIsConnected = useMemo(
@ -35,13 +36,14 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
const selectConnectionError = useMemo(
() =>
makeConnectionErrorSelector(
templates,
pendingConnection,
nodeId,
fieldName,
kind === 'inputs' ? 'target' : 'source',
fieldType
),
[pendingConnection, nodeId, fieldName, kind, fieldType]
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
);
const isConnected = useAppSelector(selectIsConnected);

View File

@ -3,8 +3,10 @@ import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { isEqual } from 'lodash-es';
import { useCallback } from 'react';
import type { Connection, Node } from 'reactflow';
@ -60,6 +62,14 @@ export const useIsValidConnection = () => {
return false;
}
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
return isEqual(sourceFieldTemplate.type, collectItemType);
}
}
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {

View File

@ -1,8 +1,9 @@
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { differenceWith, map } from 'lodash-es';
import { differenceWith, isEqual, map } from 'lodash-es';
import type { Connection, Edge, HandleType, Node } from 'reactflow';
import { assert } from 'tsafe';
@ -115,6 +116,7 @@ export const findConnectionToValidHandle = (
};
export const getFirstValidConnection = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
@ -170,12 +172,11 @@ export const getFirstValidConnection = (
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
if (
pendingConnection.node.data.type !== 'collect' &&
edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
)
) {
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null;
}
@ -184,7 +185,14 @@ export const getFirstValidConnection = (
}
// Sources/outputs can have any number of edges, we can take the first matching output field
const candidateFields = map(candidateTemplate.outputs);
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);

View File

@ -1,19 +1,44 @@
import { createSelector } from '@reduxjs/toolkit';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { PendingConnection } from 'features/nodes/store/types';
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import i18n from 'i18next';
import { isEqual } from 'lodash-es';
import type { HandleType } from 'reactflow';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
export const getCollectItemType = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
return fieldType;
};
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
* TODO: Figure out how to do this without duplicating all the logic
*/
export const makeConnectionErrorSelector = (
templates: Templates,
pendingConnection: PendingConnection | null,
nodeId: string,
fieldName: string,
@ -72,6 +97,17 @@ export const makeConnectionErrorSelector = (
return i18n.t('nodes.cannotDuplicateConnection');
}
const targetNode = nodes.find((node) => node.id === target);
if (targetNode?.data.type === 'collect' && targetHandle === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
if (!isEqual(sourceType, collectItemType)) {
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
}
}
}
if (
edges.find((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;