mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): connection validation for collection items types
This commit is contained in:
parent
76825f4261
commit
a8b042177d
@ -774,6 +774,7 @@
|
||||
"cannotConnectOutputToOutput": "Cannot connect output to output",
|
||||
"cannotConnectToSelf": "Cannot connect to self",
|
||||
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
||||
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
|
||||
"nodePack": "Node pack",
|
||||
"collection": "Collection",
|
||||
"collectionFieldType": "{{name}} Collection",
|
||||
|
@ -148,7 +148,7 @@ const AddNodePopover = () => {
|
||||
const template = templates[node.data.type];
|
||||
assert(template, 'Template not found');
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
const connection = getFirstValidConnection(nodes, edges, pendingConnection, node, template);
|
||||
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ export const useConnection = () => {
|
||||
}
|
||||
const candidateTemplate = templates[candidateNode.data.type];
|
||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||
const connection = getFirstValidConnection(nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
||||
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $pendingConnection, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
@ -15,6 +15,7 @@ type UseConnectionStateProps = {
|
||||
|
||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const templates = useStore($templates);
|
||||
const fieldType = useFieldType(nodeId, fieldName, kind);
|
||||
|
||||
const selectIsConnected = useMemo(
|
||||
@ -35,13 +36,14 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
const selectConnectionError = useMemo(
|
||||
() =>
|
||||
makeConnectionErrorSelector(
|
||||
templates,
|
||||
pendingConnection,
|
||||
nodeId,
|
||||
fieldName,
|
||||
kind === 'inputs' ? 'target' : 'source',
|
||||
fieldType
|
||||
),
|
||||
[pendingConnection, nodeId, fieldName, kind, fieldType]
|
||||
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
const isConnected = useAppSelector(selectIsConnected);
|
||||
|
@ -3,8 +3,10 @@ import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import type { Connection, Node } from 'reactflow';
|
||||
|
||||
@ -60,6 +62,14 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
return isEqual(sourceFieldTemplate.type, collectItemType);
|
||||
}
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
|
@ -1,8 +1,9 @@
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { differenceWith, map } from 'lodash-es';
|
||||
import { differenceWith, isEqual, map } from 'lodash-es';
|
||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
@ -115,6 +116,7 @@ export const findConnectionToValidHandle = (
|
||||
};
|
||||
|
||||
export const getFirstValidConnection = (
|
||||
templates: Templates,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
pendingConnection: PendingConnection,
|
||||
@ -170,12 +172,11 @@ export const getFirstValidConnection = (
|
||||
} else {
|
||||
// Connecting from a target to a source
|
||||
// Ensure we there is not already an edge to the target, except for collect nodes
|
||||
if (
|
||||
pendingConnection.node.data.type !== 'collect' &&
|
||||
edges.some(
|
||||
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
||||
)
|
||||
) {
|
||||
const isCollect = pendingConnection.node.data.type === 'collect';
|
||||
const isTargetAlreadyConnected = edges.some(
|
||||
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
||||
);
|
||||
if (!isCollect && isTargetAlreadyConnected) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -184,7 +185,14 @@ export const getFirstValidConnection = (
|
||||
}
|
||||
|
||||
// Sources/outputs can have any number of edges, we can take the first matching output field
|
||||
const candidateFields = map(candidateTemplate.outputs);
|
||||
let candidateFields = map(candidateTemplate.outputs);
|
||||
if (isCollect) {
|
||||
// Narrow candidates to same field type as already is connected to the collect node
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
|
||||
if (collectItemType) {
|
||||
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
|
||||
}
|
||||
}
|
||||
const candidateField = candidateFields.find((field) => {
|
||||
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
|
||||
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
|
||||
|
@ -1,19 +1,44 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { PendingConnection } from 'features/nodes/store/types';
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import i18n from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { HandleType } from 'reactflow';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
export const getCollectItemType = (
|
||||
templates: Templates,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
nodeId: string
|
||||
): FieldType | null => {
|
||||
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
|
||||
if (!firstEdgeToCollect?.sourceHandle) {
|
||||
return null;
|
||||
}
|
||||
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
|
||||
return fieldType;
|
||||
};
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*/
|
||||
|
||||
export const makeConnectionErrorSelector = (
|
||||
templates: Templates,
|
||||
pendingConnection: PendingConnection | null,
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
@ -72,6 +97,17 @@ export const makeConnectionErrorSelector = (
|
||||
return i18n.t('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((node) => node.id === target);
|
||||
if (targetNode?.data.type === 'collect' && targetHandle === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
if (!isEqual(sourceType, collectItemType)) {
|
||||
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
|
Loading…
Reference in New Issue
Block a user