diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index cad4dded39..a2fce55ce3 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -69,7 +69,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { ); } - if (fieldTemplate.input === 'connection') { + if (fieldTemplate.input === 'connection' || isConnected) { return ( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index f5931db87e..7972f9eee3 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -1,4 +1,7 @@ +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppSelector } from 'app/store/storeHooks'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; @@ -6,14 +9,31 @@ import { useMemo } from 'react'; export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => { const template = useNodeTemplate(nodeId); + const selectConnectedFieldNames = useMemo( + () => + createMemoizedSelector(selectNodesSlice, (nodesSlice) => + nodesSlice.edges + .filter((e) => e.target === nodeId) + .map((e) => e.targetHandle) + .filter(Boolean) + ), + [nodeId] + ); + const connectedFieldNames = useAppSelector(selectConnectedFieldNames); + const fieldNames = useMemo(() => { - const fields = map(template.inputs).filter( - (field) => + const fields = map(template.inputs).filter((field) => { + if (connectedFieldNames.includes(field.name)) { + return false; + } + + return ( (['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) && keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) - ); + ); + }); return getSortedFilteredFieldNames(fields); - }, [template]); + }, [connectedFieldNames, template.inputs]); return fieldNames; }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 84413fc9c8..0eeb592c31 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -1,4 +1,7 @@ +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppSelector } from 'app/store/storeHooks'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; @@ -6,15 +9,31 @@ import { useMemo } from 'react'; export const useConnectionInputFieldNames = (nodeId: string): string[] => { const template = useNodeTemplate(nodeId); + const selectConnectedFieldNames = useMemo( + () => + createMemoizedSelector(selectNodesSlice, (nodesSlice) => + nodesSlice.edges + .filter((e) => e.target === nodeId) + .map((e) => e.targetHandle) + .filter(Boolean) + ), + [nodeId] + ); + const connectedFieldNames = useAppSelector(selectConnectedFieldNames); + const fieldNames = useMemo(() => { // get the visible fields - const fields = map(template.inputs).filter( - (field) => + const fields = map(template.inputs).filter((field) => { + if (connectedFieldNames.includes(field.name)) { + return true; + } + return ( (field.input === 'connection' && !field.type.isCollectionOrScalar) || !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) - ); + ); + }); return getSortedFilteredFieldNames(fields); - }, [template]); + }, [connectedFieldNames, template.inputs]); return fieldNames; };