From 9d127fee6bc7ca2ac93ec91a65fbfb4fb27fbc39 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 17:28:56 +1000 Subject: [PATCH] feat(ui): makeConnectionErrorSelector now creates a parameterized selector --- .../nodes/hooks/useConnectionState.ts | 14 +- .../nodes/store/util/connectionValidation.ts | 176 +++++++++--------- 2 files changed, 93 insertions(+), 97 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index dfa8b0cf36..9571ce2ee2 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -34,16 +34,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta ); const selectConnectionError = useMemo( - () => - makeConnectionErrorSelector( - templates, - pendingConnection, - nodeId, - fieldName, - kind === 'inputs' ? 'target' : 'source', - fieldType - ), - [templates, pendingConnection, nodeId, fieldName, kind, fieldType] + () => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType), + [templates, nodeId, fieldName, kind, fieldType] ); const isConnected = useAppSelector(selectIsConnected); @@ -58,7 +50,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] ); }, [fieldName, kind, nodeId, pendingConnection]); - const connectionError = useAppSelector(selectConnectionError); + const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection)); const shouldDim = useMemo( () => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField), diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts index 98de4284ad..907426b51d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts @@ -1,7 +1,8 @@ import graphlib from '@dagrejs/graphlib'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import type { RootState } from 'app/store/store'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import type { PendingConnection, Templates } from 'features/nodes/store/types'; +import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field'; import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; import i18n from 'i18next'; @@ -190,105 +191,108 @@ export const getCollectItemType = ( */ export const makeConnectionErrorSelector = ( templates: Templates, - pendingConnection: PendingConnection | null, nodeId: string, fieldName: string, handleType: HandleType, fieldType: FieldType ) => { - return createMemoizedSelector(selectNodesSlice, (nodesSlice) => { - const { nodes, edges } = nodesSlice; + return createMemoizedSelector( + selectNodesSlice, + (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, + (nodesSlice: NodesState, pendingConnection: PendingConnection | null) => { + const { nodes, edges } = nodesSlice; - if (!pendingConnection) { - return i18n.t('nodes.noConnectionInProgress'); - } - - const connectionNodeId = pendingConnection.node.id; - const connectionFieldName = pendingConnection.fieldTemplate.name; - const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - const connectionStartFieldType = pendingConnection.fieldTemplate.type; - - if (!connectionHandleType || !connectionNodeId || !connectionFieldName) { - return i18n.t('nodes.noConnectionData'); - } - - const targetType = handleType === 'target' ? fieldType : connectionStartFieldType; - const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType; - - if (nodeId === connectionNodeId) { - return i18n.t('nodes.cannotConnectToSelf'); - } - - if (handleType === connectionHandleType) { - if (handleType === 'source') { - return i18n.t('nodes.cannotConnectOutputToOutput'); + if (!pendingConnection) { + return i18n.t('nodes.noConnectionInProgress'); } - return i18n.t('nodes.cannotConnectInputToInput'); - } - // we have to figure out which is the target and which is the source - const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId; - const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName; - const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId; - const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName; + const connectionNodeId = pendingConnection.node.id; + const connectionFieldName = pendingConnection.fieldTemplate.name; + const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + const connectionStartFieldType = pendingConnection.fieldTemplate.type; - if ( - edges.find((edge) => { - edge.target === targetNodeId && - edge.targetHandle === targetFieldName && - edge.source === sourceNodeId && - edge.sourceHandle === sourceFieldName; - }) - ) { - // We already have a connection from this source to this target - return i18n.t('nodes.cannotDuplicateConnection'); - } + if (!connectionHandleType || !connectionNodeId || !connectionFieldName) { + return i18n.t('nodes.noConnectionData'); + } - const targetNode = nodes.find((node) => node.id === targetNodeId); - assert(targetNode, `Target node not found: ${targetNodeId}`); - const targetTemplate = templates[targetNode.data.type]; - assert(targetTemplate, `Target template not found: ${targetNode.data.type}`); + const targetType = handleType === 'target' ? fieldType : connectionStartFieldType; + const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType; - if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { - return i18n.t('nodes.cannotConnectToDirectInput'); - } - if (targetNode.data.type === 'collect' && targetFieldName === 'item') { - // Collect nodes shouldn't mix and match field types - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - if (!areTypesEqual(sourceType, collectItemType)) { - return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); + if (nodeId === connectionNodeId) { + return i18n.t('nodes.cannotConnectToSelf'); + } + + if (handleType === connectionHandleType) { + if (handleType === 'source') { + return i18n.t('nodes.cannotConnectOutputToOutput'); + } + return i18n.t('nodes.cannotConnectInputToInput'); + } + + // we have to figure out which is the target and which is the source + const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId; + const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName; + const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId; + const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName; + + if ( + edges.find((edge) => { + edge.target === targetNodeId && + edge.targetHandle === targetFieldName && + edge.source === sourceNodeId && + edge.sourceHandle === sourceFieldName; + }) + ) { + // We already have a connection from this source to this target + return i18n.t('nodes.cannotDuplicateConnection'); + } + + const targetNode = nodes.find((node) => node.id === targetNodeId); + assert(targetNode, `Target node not found: ${targetNodeId}`); + const targetTemplate = templates[targetNode.data.type]; + assert(targetTemplate, `Target template not found: ${targetNode.data.type}`); + + if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { + return i18n.t('nodes.cannotConnectToDirectInput'); + } + if (targetNode.data.type === 'collect' && targetFieldName === 'item') { + // Collect nodes shouldn't mix and match field types + const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + if (collectItemType) { + if (!areTypesEqual(sourceType, collectItemType)) { + return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); + } } } + + if ( + edges.find((edge) => { + return edge.target === targetNodeId && edge.targetHandle === targetFieldName; + }) && + // except CollectionItem inputs can have multiples + targetType.name !== 'CollectionItemField' + ) { + return i18n.t('nodes.inputMayOnlyHaveOneConnection'); + } + + if (!validateSourceAndTargetTypes(sourceType, targetType)) { + return i18n.t('nodes.fieldTypesMustMatch'); + } + + const hasCycles = getHasCycles( + connectionHandleType === 'source' ? connectionNodeId : nodeId, + connectionHandleType === 'source' ? nodeId : connectionNodeId, + nodes, + edges + ); + + if (hasCycles) { + return i18n.t('nodes.connectionWouldCreateCycle'); + } + + return; } - - if ( - edges.find((edge) => { - return edge.target === targetNodeId && edge.targetHandle === targetFieldName; - }) && - // except CollectionItem inputs can have multiples - targetType.name !== 'CollectionItemField' - ) { - return i18n.t('nodes.inputMayOnlyHaveOneConnection'); - } - - if (!validateSourceAndTargetTypes(sourceType, targetType)) { - return i18n.t('nodes.fieldTypesMustMatch'); - } - - const hasCycles = getHasCycles( - connectionHandleType === 'source' ? connectionNodeId : nodeId, - connectionHandleType === 'source' ? nodeId : connectionNodeId, - nodes, - edges - ); - - if (hasCycles) { - return i18n.t('nodes.connectionWouldCreateCycle'); - } - - return; - }); + ); }; /**