mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): connection validation for collection items types
This commit is contained in:
@ -774,6 +774,7 @@
|
|||||||
"cannotConnectOutputToOutput": "Cannot connect output to output",
|
"cannotConnectOutputToOutput": "Cannot connect output to output",
|
||||||
"cannotConnectToSelf": "Cannot connect to self",
|
"cannotConnectToSelf": "Cannot connect to self",
|
||||||
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
||||||
|
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
|
||||||
"nodePack": "Node pack",
|
"nodePack": "Node pack",
|
||||||
"collection": "Collection",
|
"collection": "Collection",
|
||||||
"collectionFieldType": "{{name}} Collection",
|
"collectionFieldType": "{{name}} Collection",
|
||||||
|
@ -148,7 +148,7 @@ const AddNodePopover = () => {
|
|||||||
const template = templates[node.data.type];
|
const template = templates[node.data.type];
|
||||||
assert(template, 'Template not found');
|
assert(template, 'Template not found');
|
||||||
const { nodes, edges } = store.getState().nodes.present;
|
const { nodes, edges } = store.getState().nodes.present;
|
||||||
const connection = getFirstValidConnection(nodes, edges, pendingConnection, node, template);
|
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
|
||||||
if (connection) {
|
if (connection) {
|
||||||
dispatch(connectionMade(connection));
|
dispatch(connectionMade(connection));
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ export const useConnection = () => {
|
|||||||
}
|
}
|
||||||
const candidateTemplate = templates[candidateNode.data.type];
|
const candidateTemplate = templates[candidateNode.data.type];
|
||||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||||
const connection = getFirstValidConnection(nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
||||||
if (connection) {
|
if (connection) {
|
||||||
dispatch(connectionMade(connection));
|
dispatch(connectionMade(connection));
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { $pendingConnection, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
@ -15,6 +15,7 @@ type UseConnectionStateProps = {
|
|||||||
|
|
||||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||||
const pendingConnection = useStore($pendingConnection);
|
const pendingConnection = useStore($pendingConnection);
|
||||||
|
const templates = useStore($templates);
|
||||||
const fieldType = useFieldType(nodeId, fieldName, kind);
|
const fieldType = useFieldType(nodeId, fieldName, kind);
|
||||||
|
|
||||||
const selectIsConnected = useMemo(
|
const selectIsConnected = useMemo(
|
||||||
@ -35,13 +36,14 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
|||||||
const selectConnectionError = useMemo(
|
const selectConnectionError = useMemo(
|
||||||
() =>
|
() =>
|
||||||
makeConnectionErrorSelector(
|
makeConnectionErrorSelector(
|
||||||
|
templates,
|
||||||
pendingConnection,
|
pendingConnection,
|
||||||
nodeId,
|
nodeId,
|
||||||
fieldName,
|
fieldName,
|
||||||
kind === 'inputs' ? 'target' : 'source',
|
kind === 'inputs' ? 'target' : 'source',
|
||||||
fieldType
|
fieldType
|
||||||
),
|
),
|
||||||
[pendingConnection, nodeId, fieldName, kind, fieldType]
|
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
|
||||||
);
|
);
|
||||||
|
|
||||||
const isConnected = useAppSelector(selectIsConnected);
|
const isConnected = useAppSelector(selectIsConnected);
|
||||||
|
@ -3,8 +3,10 @@ import { useStore } from '@nanostores/react';
|
|||||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||||
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
|
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
|
||||||
|
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||||
|
import { isEqual } from 'lodash-es';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import type { Connection, Node } from 'reactflow';
|
import type { Connection, Node } from 'reactflow';
|
||||||
|
|
||||||
@ -60,6 +62,14 @@ export const useIsValidConnection = () => {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
|
||||||
|
// Collect nodes shouldn't mix and match field types
|
||||||
|
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||||
|
if (collectItemType) {
|
||||||
|
return isEqual(sourceFieldTemplate.type, collectItemType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Connection is invalid if target already has a connection
|
// Connection is invalid if target already has a connection
|
||||||
if (
|
if (
|
||||||
edges.find((edge) => {
|
edges.find((edge) => {
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||||
|
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||||
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
||||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { differenceWith, map } from 'lodash-es';
|
import { differenceWith, isEqual, map } from 'lodash-es';
|
||||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
@ -115,6 +116,7 @@ export const findConnectionToValidHandle = (
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const getFirstValidConnection = (
|
export const getFirstValidConnection = (
|
||||||
|
templates: Templates,
|
||||||
nodes: AnyNode[],
|
nodes: AnyNode[],
|
||||||
edges: InvocationNodeEdge[],
|
edges: InvocationNodeEdge[],
|
||||||
pendingConnection: PendingConnection,
|
pendingConnection: PendingConnection,
|
||||||
@ -170,12 +172,11 @@ export const getFirstValidConnection = (
|
|||||||
} else {
|
} else {
|
||||||
// Connecting from a target to a source
|
// Connecting from a target to a source
|
||||||
// Ensure we there is not already an edge to the target, except for collect nodes
|
// Ensure we there is not already an edge to the target, except for collect nodes
|
||||||
if (
|
const isCollect = pendingConnection.node.data.type === 'collect';
|
||||||
pendingConnection.node.data.type !== 'collect' &&
|
const isTargetAlreadyConnected = edges.some(
|
||||||
edges.some(
|
|
||||||
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
||||||
)
|
);
|
||||||
) {
|
if (!isCollect && isTargetAlreadyConnected) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,7 +185,14 @@ export const getFirstValidConnection = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sources/outputs can have any number of edges, we can take the first matching output field
|
// Sources/outputs can have any number of edges, we can take the first matching output field
|
||||||
const candidateFields = map(candidateTemplate.outputs);
|
let candidateFields = map(candidateTemplate.outputs);
|
||||||
|
if (isCollect) {
|
||||||
|
// Narrow candidates to same field type as already is connected to the collect node
|
||||||
|
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
|
||||||
|
if (collectItemType) {
|
||||||
|
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
|
||||||
|
}
|
||||||
|
}
|
||||||
const candidateField = candidateFields.find((field) => {
|
const candidateField = candidateFields.find((field) => {
|
||||||
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
|
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
|
||||||
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
|
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
|
||||||
|
@ -1,19 +1,44 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import type { PendingConnection } from 'features/nodes/store/types';
|
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||||
import type { FieldType } from 'features/nodes/types/field';
|
import type { FieldType } from 'features/nodes/types/field';
|
||||||
|
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||||
import i18n from 'i18next';
|
import i18n from 'i18next';
|
||||||
|
import { isEqual } from 'lodash-es';
|
||||||
import type { HandleType } from 'reactflow';
|
import type { HandleType } from 'reactflow';
|
||||||
|
|
||||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||||
|
|
||||||
|
export const getCollectItemType = (
|
||||||
|
templates: Templates,
|
||||||
|
nodes: AnyNode[],
|
||||||
|
edges: InvocationNodeEdge[],
|
||||||
|
nodeId: string
|
||||||
|
): FieldType | null => {
|
||||||
|
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
|
||||||
|
if (!firstEdgeToCollect?.sourceHandle) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
|
||||||
|
if (!node) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const template = templates[node.data.type];
|
||||||
|
if (!template) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
|
||||||
|
return fieldType;
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||||
* TODO: Figure out how to do this without duplicating all the logic
|
* TODO: Figure out how to do this without duplicating all the logic
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export const makeConnectionErrorSelector = (
|
export const makeConnectionErrorSelector = (
|
||||||
|
templates: Templates,
|
||||||
pendingConnection: PendingConnection | null,
|
pendingConnection: PendingConnection | null,
|
||||||
nodeId: string,
|
nodeId: string,
|
||||||
fieldName: string,
|
fieldName: string,
|
||||||
@ -72,6 +97,17 @@ export const makeConnectionErrorSelector = (
|
|||||||
return i18n.t('nodes.cannotDuplicateConnection');
|
return i18n.t('nodes.cannotDuplicateConnection');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const targetNode = nodes.find((node) => node.id === target);
|
||||||
|
if (targetNode?.data.type === 'collect' && targetHandle === 'item') {
|
||||||
|
// Collect nodes shouldn't mix and match field types
|
||||||
|
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||||
|
if (collectItemType) {
|
||||||
|
if (!isEqual(sourceType, collectItemType)) {
|
||||||
|
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
edges.find((edge) => {
|
edges.find((edge) => {
|
||||||
return edge.target === target && edge.targetHandle === targetHandle;
|
return edge.target === target && edge.targetHandle === targetHandle;
|
||||||
|
Reference in New Issue
Block a user