diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 9571ce2ee2..5dcb7a28b5 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -2,11 +2,9 @@ import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { makeConnectionErrorSelector } from 'features/nodes/store/util/connectionValidation.js'; +import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector'; import { useMemo } from 'react'; -import { useFieldType } from './useFieldType.ts'; - type UseConnectionStateProps = { nodeId: string; fieldName: string; @@ -16,7 +14,6 @@ type UseConnectionStateProps = { export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { const pendingConnection = useStore($pendingConnection); const templates = useStore($templates); - const fieldType = useFieldType(nodeId, fieldName, kind); const selectIsConnected = useMemo( () => @@ -34,8 +31,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta ); const selectConnectionError = useMemo( - () => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType), - [templates, nodeId, fieldName, kind, fieldType] + () => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'), + [templates, nodeId, fieldName, kind] ); const isConnected = useAppSelector(selectIsConnected); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 77c4e3c75b..0f8609d2ff 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -2,13 +2,9 @@ import { useStore } from '@nanostores/react'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { $templates } from 'features/nodes/store/nodesSlice'; -import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; -import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; -import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; -import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; -import type { InvocationNodeData } from 'features/nodes/types/invocation'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { useCallback } from 'react'; -import type { Connection, Node } from 'reactflow'; +import type { Connection } from 'reactflow'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` @@ -26,74 +22,20 @@ export const useIsValidConnection = () => { return false; } - if (source === target) { - // Don't allow nodes to connect to themselves, even if validation is disabled - return false; - } + const { nodes, edges } = store.getState().nodes.present; - const state = store.getState(); - const { nodes, edges } = state.nodes.present; + const validationResult = validateConnection( + { source, sourceHandle, target, targetHandle }, + nodes, + edges, + templates, + null, + shouldValidateGraph + ); - // Find the source and target nodes - const sourceNode = nodes.find((node) => node.id === source) as Node; - const targetNode = nodes.find((node) => node.id === target) as Node; - const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle]; - const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle]; - - // Conditional guards against undefined nodes/handles - if (!(sourceFieldTemplate && targetFieldTemplate)) { - return false; - } - - if (targetFieldTemplate.input === 'direct') { - return false; - } - - if (!shouldValidateGraph) { - // manual override! - return true; - } - - if ( - edges.find((edge) => { - edge.target === target && - edge.targetHandle === targetHandle && - edge.source === source && - edge.sourceHandle === sourceHandle; - }) - ) { - // We already have a connection from this source to this target - return false; - } - - if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') { - // Collect nodes shouldn't mix and match field types - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - return areTypesEqual(sourceFieldTemplate.type, collectItemType); - } - } - - // Connection is invalid if target already has a connection - if ( - edges.find((edge) => { - return edge.target === target && edge.targetHandle === targetHandle; - }) && - // except CollectionItem inputs can have multiples - targetFieldTemplate.type.name !== 'CollectionItemField' - ) { - return false; - } - - // Must use the originalType here if it exists - if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { - return false; - } - - // Graphs much be acyclic (no loops!) - return !getHasCycles(source, target, nodes, edges); + return validationResult.isValid; }, - [shouldValidateGraph, templates, store] + [templates, shouldValidateGraph, store] ); return isValidConnection; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts deleted file mode 100644 index 7819221f8a..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts +++ /dev/null @@ -1,134 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import type { RootState } from 'app/store/store'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; -import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; -import type { FieldType } from 'features/nodes/types/field'; -import i18n from 'i18next'; -import type { HandleType } from 'reactflow'; -import { assert } from 'tsafe'; - -import { areTypesEqual } from './areTypesEqual'; -import { getCollectItemType } from './getCollectItemType'; -import { getHasCycles } from './getHasCycles'; - -/** - * Creates a selector that validates a pending connection. - * - * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` - * TODO: Figure out how to do this without duplicating all the logic - * - * @param templates The invocation templates - * @param pendingConnection The current pending connection (if there is one) - * @param nodeId The id of the node for which the selector is being created - * @param fieldName The name of the field for which the selector is being created - * @param handleType The type of the handle for which the selector is being created - * @param fieldType The type of the field for which the selector is being created - * @returns - */ -export const makeConnectionErrorSelector = ( - templates: Templates, - nodeId: string, - fieldName: string, - handleType: HandleType, - fieldType: FieldType -) => { - return createMemoizedSelector( - selectNodesSlice, - (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, - (nodesSlice: NodesState, pendingConnection: PendingConnection | null) => { - const { nodes, edges } = nodesSlice; - - if (!pendingConnection) { - return i18n.t('nodes.noConnectionInProgress'); - } - - const connectionNodeId = pendingConnection.node.id; - const connectionFieldName = pendingConnection.fieldTemplate.name; - const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - const connectionStartFieldType = pendingConnection.fieldTemplate.type; - - if (!connectionHandleType || !connectionNodeId || !connectionFieldName) { - return i18n.t('nodes.noConnectionData'); - } - - const targetType = handleType === 'target' ? fieldType : connectionStartFieldType; - const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType; - - if (nodeId === connectionNodeId) { - return i18n.t('nodes.cannotConnectToSelf'); - } - - if (handleType === connectionHandleType) { - if (handleType === 'source') { - return i18n.t('nodes.cannotConnectOutputToOutput'); - } - return i18n.t('nodes.cannotConnectInputToInput'); - } - - // we have to figure out which is the target and which is the source - const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId; - const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName; - const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId; - const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName; - - if ( - edges.find((edge) => { - edge.target === targetNodeId && - edge.targetHandle === targetFieldName && - edge.source === sourceNodeId && - edge.sourceHandle === sourceFieldName; - }) - ) { - // We already have a connection from this source to this target - return i18n.t('nodes.cannotDuplicateConnection'); - } - - const targetNode = nodes.find((node) => node.id === targetNodeId); - assert(targetNode, `Target node not found: ${targetNodeId}`); - const targetTemplate = templates[targetNode.data.type]; - assert(targetTemplate, `Target template not found: ${targetNode.data.type}`); - - if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { - return i18n.t('nodes.cannotConnectToDirectInput'); - } - - if (targetNode.data.type === 'collect' && targetFieldName === 'item') { - // Collect nodes shouldn't mix and match field types - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - if (!areTypesEqual(sourceType, collectItemType)) { - return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); - } - } - } - - if ( - edges.find((edge) => { - return edge.target === targetNodeId && edge.targetHandle === targetFieldName; - }) && - // except CollectionItem inputs can have multiples - targetType.name !== 'CollectionItemField' - ) { - return i18n.t('nodes.inputMayOnlyHaveOneConnection'); - } - - if (!validateConnectionTypes(sourceType, targetType)) { - return i18n.t('nodes.fieldTypesMustMatch'); - } - - const hasCycles = getHasCycles( - connectionHandleType === 'source' ? connectionNodeId : nodeId, - connectionHandleType === 'source' ? nodeId : connectionNodeId, - nodes, - edges - ); - - if (hasCycles) { - return i18n.t('nodes.connectionWouldCreateCycle'); - } - - return; - } - ); -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts new file mode 100644 index 0000000000..3cefb6815f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -0,0 +1,72 @@ +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import type { RootState } from 'app/store/store'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; +import i18n from 'i18next'; +import type { HandleType } from 'reactflow'; + +/** + * Creates a selector that validates a pending connection. + * + * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` + * TODO: Figure out how to do this without duplicating all the logic + * + * @param templates The invocation templates + * @param nodeId The id of the node for which the selector is being created + * @param fieldName The name of the field for which the selector is being created + * @param handleType The type of the handle for which the selector is being created + * @returns + */ +export const makeConnectionErrorSelector = ( + templates: Templates, + nodeId: string, + fieldName: string, + handleType: HandleType +) => { + return createMemoizedSelector( + selectNodesSlice, + (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, + (nodesSlice: NodesState, pendingConnection: PendingConnection | null) => { + const { nodes, edges } = nodesSlice; + + if (!pendingConnection) { + return i18n.t('nodes.noConnectionInProgress'); + } + + const connectionNodeId = pendingConnection.node.id; + const connectionFieldName = pendingConnection.fieldTemplate.name; + const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + + if (handleType === connectionHandleType) { + if (handleType === 'source') { + return i18n.t('nodes.cannotConnectOutputToOutput'); + } + return i18n.t('nodes.cannotConnectInputToInput'); + } + + // we have to figure out which is the target and which is the source + const source = handleType === 'source' ? nodeId : connectionNodeId; + const sourceHandle = handleType === 'source' ? fieldName : connectionFieldName; + const target = handleType === 'target' ? nodeId : connectionNodeId; + const targetHandle = handleType === 'target' ? fieldName : connectionFieldName; + + const validationResult = validateConnection( + { + source, + sourceHandle, + target, + targetHandle, + }, + nodes, + edges, + templates, + null + ); + + if (!validationResult.isValid) { + return i18n.t(validationResult.messageTKey); + } + } + ); +};