diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index 81eea993be..f0dba67bf5 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -10,13 +10,15 @@ import { } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/connectionValidation'; import { isInvocationNode } from 'features/nodes/types/invocation'; +import { isString } from 'lodash-es'; import { useCallback, useMemo } from 'react'; -import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; +import { type OnConnect, type OnConnectEnd, type OnConnectStart, useUpdateNodeInternals } from 'reactflow'; import { assert } from 'tsafe'; export const useConnection = () => { const store = useAppStore(); const templates = useStore($templates); + const updateNodeInternals = useUpdateNodeInternals(); const onConnectStart = useCallback( (event, params) => { @@ -41,9 +43,11 @@ export const useConnection = () => { (connection) => { const { dispatch } = store; dispatch(connectionMade(connection)); + const nodesToUpdate = [connection.source, connection.target].filter(isString); + updateNodeInternals(nodesToUpdate); $pendingConnection.set(null); }, - [store] + [store, updateNodeInternals] ); const onConnectEnd = useCallback(() => { const { dispatch } = store; @@ -80,13 +84,15 @@ export const useConnection = () => { ); if (connection) { dispatch(connectionMade(connection)); + const nodesToUpdate = [connection.source, connection.target].filter(isString); + updateNodeInternals(nodesToUpdate); } $pendingConnection.set(null); } else { // The mouse is not over a node - we should open the add node popover $isAddNodePopoverOpen.set(true); } - }, [store, templates]); + }, [store, templates, updateNodeInternals]); const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]); return api;