diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 12592c86da..6e695561a2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -15,9 +15,10 @@ import { $templates, closeAddNodePopover, edgesChanged, - nodeAdded, + nodesChanged, openAddNodePopover, } from 'features/nodes/store/nodesSlice'; +import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; @@ -30,6 +31,7 @@ import { useHotkeys } from 'react-hotkeys-hook'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import { useTranslation } from 'react-i18next'; import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters'; +import type { EdgeChange, NodeChange } from 'reactflow'; const createRegex = memoize( (inputValue: string) => @@ -131,11 +133,29 @@ const AddNodePopover = () => { }); return null; } + + // Find a cozy spot for the node const cursorPos = $cursorPos.get(); - dispatch(nodeAdded({ node, cursorPos })); + const { nodes, edges } = store.getState().nodes.present; + node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y); + node.selected = true; + + // Deselect all other nodes and edges + const nodeChanges: NodeChange[] = [{ type: 'add', item: node }]; + const edgeChanges: EdgeChange[] = []; + nodes.forEach((n) => { + nodeChanges.push({ id: n.id, type: 'select', selected: false }); + }); + edges.forEach((e) => { + edgeChanges.push({ id: e.id, type: 'select', selected: false }); + }); + + // Onwards! + dispatch(nodesChanged(nodeChanges)); + dispatch(edgesChanged(edgeChanges)); return node; }, - [dispatch, buildInvocation, toaster, t] + [buildInvocation, store, dispatch, t, toaster] ); const onChange = useCallback( diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 05b53c518d..5f0dbb2b14 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -55,7 +55,6 @@ import type { UndoableOptions } from 'redux-undo'; import type { z } from 'zod'; import type { NodesState, PendingConnection, Templates } from './types'; -import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; const initialNodesState: NodesState = { _version: 1, @@ -102,28 +101,6 @@ export const nodesSlice = createSlice({ } state.nodes[nodeIndex] = action.payload.node; }, - nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => { - const { node, cursorPos } = action.payload; - const position = findUnoccupiedPosition( - state.nodes, - cursorPos?.x ?? node.position.x, - cursorPos?.y ?? node.position.y - ); - node.position = position; - node.selected = true; - - state.nodes = applyNodeChanges( - state.nodes.map((n) => ({ id: n.id, type: 'select', selected: false })), - state.nodes - ); - - state.edges = applyEdgeChanges( - state.edges.map((e) => ({ id: e.id, type: 'select', selected: false })), - state.edges - ); - - state.nodes.push(node); - }, edgesChanged: (state, action: PayloadAction) => { const changes = deepClone(action.payload); action.payload.forEach((change) => { @@ -486,7 +463,6 @@ export const { fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, - nodeAdded, nodeReplaced, nodeEditorReset, nodeExclusivelySelected, @@ -604,7 +580,6 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, - nodeAdded, nodesChanged, nodeReplaced, nodeIsIntermediateChanged,