From 26029108f7150a1bf870b26b91205d6da2f9d0c0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 13:54:32 +1000 Subject: [PATCH] feat(ui): rework node and edge mutation logic Remove our DIY'd reducers, consolidating all node and edge mutations to use `edgesChanged` and `nodesChanged`, which are called by reactflow. This makes the API for manipulating nodes and edges less tangly and error-prone. --- .../flow/AddNodePopover/AddNodePopover.tsx | 6 +- .../features/nodes/components/flow/Flow.tsx | 50 ++++------- .../flow/nodes/Invocation/MissingFallback.tsx | 20 +++++ .../Invocation/fields/LinearViewField.tsx | 11 ++- .../sidePanel/viewMode/WorkflowField.tsx | 11 ++- .../sidePanel/workflow/WorkflowLinearTab.tsx | 10 ++- .../src/features/nodes/hooks/useConnection.ts | 25 +++--- .../features/nodes/hooks/useDoesFieldExist.ts | 20 +++++ .../src/features/nodes/store/nodesSlice.ts | 87 ++++++++++--------- .../nodes/store/util/reactFlowUtil.ts | 32 +++++++ .../src/features/nodes/store/workflowSlice.ts | 14 +-- 11 files changed, 186 insertions(+), 100 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/MissingFallback.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useDoesFieldExist.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 561890245e..12592c86da 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -14,11 +14,12 @@ import { $pendingConnection, $templates, closeAddNodePopover, - connectionMade, + edgesChanged, nodeAdded, openAddNodePopover, } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; @@ -166,7 +167,8 @@ const AddNodePopover = () => { edgePendingUpdate ); if (connection) { - dispatch(connectionMade(connection)); + const newEdge = connectionToEdge(connection); + dispatch(edgesChanged([{ type: 'add', item: newEdge }])); } } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 18bbac0b44..5327d72478 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -14,29 +14,24 @@ import { $lastEdgeUpdateMouseEvent, $pendingConnection, $viewport, - connectionMade, - edgeDeleted, edgesChanged, - edgesDeleted, nodesChanged, - nodesDeleted, redo, selectedAll, + selectionDeleted, undo, } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; -import { isString } from 'lodash-es'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import type { CSSProperties, MouseEvent } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import type { OnEdgesChange, - OnEdgesDelete, OnEdgeUpdateFunc, OnInit, OnMoveEnd, OnNodesChange, - OnNodesDelete, ProOptions, ReactFlowProps, ReactFlowState, @@ -50,8 +45,6 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode'; import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper'; import NotesNode from './nodes/Notes/NotesNode'; -const DELETE_KEYS = ['Delete', 'Backspace']; - const edgeTypes = { collapsed: InvocationCollapsedEdge, default: InvocationDefaultEdge, @@ -109,20 +102,6 @@ export const Flow = memo(() => { [dispatch] ); - const onEdgesDelete: OnEdgesDelete = useCallback( - (edges) => { - dispatch(edgesDeleted(edges)); - }, - [dispatch] - ); - - const onNodesDelete: OnNodesDelete = useCallback( - (nodes) => { - dispatch(nodesDeleted(nodes)); - }, - [dispatch] - ); - const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => { $viewport.set(viewport); }, []); @@ -167,16 +146,20 @@ export const Flow = memo(() => { }, []); const onEdgeUpdate: OnEdgeUpdateFunc = useCallback( - (edge, newConnection) => { + (oldEdge, newConnection) => { // This event is fired when an edge update is successful $didUpdateEdge.set(true); // When an edge update is successful, we need to delete the old edge and create a new one - dispatch(edgeDeleted(edge.id)); - dispatch(connectionMade(newConnection)); + const newEdge = connectionToEdge(newConnection); + dispatch( + edgesChanged([ + { type: 'remove', id: oldEdge.id }, + { type: 'add', item: newEdge }, + ]) + ); // Because we shift the position of handles depending on whether a field is connected or not, we must use // updateNodeInternals to tell reactflow to recalculate the positions of the handles - const nodesToUpdate = [edge.source, edge.target, newConnection.source, newConnection.target].filter(isString); - updateNodeInternals(nodesToUpdate); + updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]); }, [dispatch, updateNodeInternals] ); @@ -193,7 +176,7 @@ export const Flow = memo(() => { // If we got this far and did not successfully update an edge, and the mouse moved away from the handle, // the user probably intended to delete the edge if (!didUpdateEdge && didMouseMove) { - dispatch(edgeDeleted(edge.id)); + dispatch(edgesChanged([{ type: 'remove', id: edge.id }])); } $edgePendingUpdate.set(null); @@ -267,6 +250,11 @@ export const Flow = memo(() => { }, [cancelConnection]); useHotkeys('esc', onEscapeHotkey); + const onDeleteHotkey = useCallback(() => { + dispatch(selectionDeleted()); + }, [dispatch]); + useHotkeys(['delete', 'backspace'], onDeleteHotkey); + return ( { onMouseMove={onMouseMove} onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} - onEdgesDelete={onEdgesDelete} onEdgeUpdate={onEdgeUpdate} onEdgeUpdateStart={onEdgeUpdateStart} onEdgeUpdateEnd={onEdgeUpdateEnd} - onNodesDelete={onNodesDelete} onConnectStart={onConnectStart} onConnect={onConnect} onConnectEnd={onConnectEnd} @@ -298,7 +284,7 @@ export const Flow = memo(() => { proOptions={proOptions} style={flowStyles} onPaneClick={handlePaneClick} - deleteKeyCode={DELETE_KEYS} + deleteKeyCode={null} selectionMode={selectionMode} elevateEdgesOnSelect > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/MissingFallback.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/MissingFallback.tsx new file mode 100644 index 0000000000..ca5b74b7ff --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/MissingFallback.tsx @@ -0,0 +1,20 @@ +import { useDoesFieldExist } from 'features/nodes/hooks/useDoesFieldExist'; +import type { PropsWithChildren } from 'react'; +import { memo } from 'react'; + +type Props = PropsWithChildren<{ + nodeId: string; + fieldName?: string; +}>; + +export const MissingFallback = memo((props: Props) => { + // We must be careful here to avoid race conditions where a deleted node is still referenced as an exposed field + const exists = useDoesFieldExist(props.nodeId, props.fieldName); + if (!exists) { + return null; + } + + return props.children; +}); + +MissingFallback.displayName = 'MissingFallback'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx index 0cd199f7a4..f7ff85f479 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -3,6 +3,7 @@ import { CSS } from '@dnd-kit/utilities'; import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay'; +import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback'; import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue'; import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice'; @@ -20,7 +21,7 @@ type Props = { fieldName: string; }; -const LinearViewField = ({ nodeId, fieldName }: Props) => { +const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => { const dispatch = useAppDispatch(); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId); @@ -99,4 +100,12 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => { ); }; +const LinearViewField = ({ nodeId, fieldName }: Props) => { + return ( + + + + ); +}; + export default memo(LinearViewField); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx index e707dd4f54..a30bda354d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx @@ -1,6 +1,7 @@ import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library'; import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent'; import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer'; +import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback'; import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel'; import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue'; import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle'; @@ -14,7 +15,7 @@ type Props = { fieldName: string; }; -const WorkflowField = ({ nodeId, fieldName }: Props) => { +const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => { const label = useFieldLabel(nodeId, fieldName); const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs'); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); @@ -50,4 +51,12 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => { ); }; +const WorkflowField = ({ nodeId, fieldName }: Props) => { + return ( + + + + ); +}; + export default memo(WorkflowField); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx index fa1767138e..9b0e5bb9d6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx @@ -6,10 +6,10 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DndSortable from 'features/dnd/components/DndSortable'; import type { DragEndEvent } from 'features/dnd/types'; -import LinearViewField from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField'; +import LinearViewFieldInternal from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField'; import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice'; import type { FieldIdentifier } from 'features/nodes/types/field'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo'; @@ -40,16 +40,18 @@ const WorkflowLinearTab = () => { [dispatch, fields] ); + const items = useMemo(() => fields.map((field) => `${field.nodeId}.${field.fieldName}`), [fields]); + return ( - `${field.nodeId}.${field.fieldName}`)}> + {isLoading ? ( ) : fields.length ? ( fields.map(({ nodeId, fieldName }) => ( - + )) ) : ( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index de01c79b30..36491e80bc 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -7,13 +7,12 @@ import { $isAddNodePopoverOpen, $pendingConnection, $templates, - connectionMade, - edgeDeleted, + edgesChanged, } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; -import { isString } from 'lodash-es'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { useCallback, useMemo } from 'react'; -import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; +import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; import { useUpdateNodeInternals } from 'reactflow'; import { assert } from 'tsafe'; @@ -50,9 +49,9 @@ export const useConnection = () => { const onConnect = useCallback( (connection) => { const { dispatch } = store; - dispatch(connectionMade(connection)); - const nodesToUpdate = [connection.source, connection.target].filter(isString); - updateNodeInternals(nodesToUpdate); + const newEdge = connectionToEdge(connection); + dispatch(edgesChanged([{ type: 'add', item: newEdge }])); + updateNodeInternals([newEdge.source, newEdge.target]); $pendingConnection.set(null); }, [store, updateNodeInternals] @@ -92,13 +91,17 @@ export const useConnection = () => { edgePendingUpdate ); if (connection) { - dispatch(connectionMade(connection)); - const nodesToUpdate = [connection.source, connection.target].filter(isString); - updateNodeInternals(nodesToUpdate); + const newEdge = connectionToEdge(connection); + const changes: EdgeChange[] = [{ type: 'add', item: newEdge }]; + + const nodesToUpdate = [newEdge.source, newEdge.target]; if (edgePendingUpdate) { - dispatch(edgeDeleted(edgePendingUpdate.id)); $didUpdateEdge.set(true); + changes.push({ type: 'remove', id: edgePendingUpdate.id }); + nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target); } + dispatch(edgesChanged(changes)); + updateNodeInternals(nodesToUpdate); } $pendingConnection.set(null); } else { diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesFieldExist.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesFieldExist.ts new file mode 100644 index 0000000000..4e97b1689c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesFieldExist.ts @@ -0,0 +1,20 @@ +import { useAppSelector } from 'app/store/storeHooks'; +import { isInvocationNode } from 'features/nodes/types/invocation'; + +export const useDoesFieldExist = (nodeId: string, fieldName?: string) => { + const doesFieldExist = useAppSelector((s) => { + const node = s.nodes.present.nodes.find((n) => n.id === nodeId); + if (!isInvocationNode(node)) { + return false; + } + if (fieldName === undefined) { + return true; + } + if (!node.data.inputs[fieldName]) { + return false; + } + return true; + }); + + return doesFieldExist; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 7915d3608c..a1e32a72fe 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,6 +1,7 @@ import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; +import { deepClone } from 'common/util/deepClone'; import { workflowLoaded } from 'features/nodes/store/actions'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { @@ -48,8 +49,8 @@ import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocatio import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import { atom } from 'nanostores'; import type { MouseEvent } from 'react'; -import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; -import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; +import type { Edge, EdgeChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; +import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; import type { UndoableOptions } from 'redux-undo'; import type { z } from 'zod'; @@ -124,10 +125,27 @@ export const nodesSlice = createSlice({ state.nodes.push(node); }, edgesChanged: (state, action: PayloadAction) => { - state.edges = applyEdgeChanges(action.payload, state.edges); - }, - connectionMade: (state, action: PayloadAction) => { - state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges); + const changes = deepClone(action.payload); + action.payload.forEach((change) => { + if (change.type === 'remove' || change.type === 'select') { + const edge = state.edges.find((e) => e.id === change.id); + // If we deleted or selected a collapsed edge, we need to find its "hidden" edges and do the same to them + if (edge && edge.type === 'collapsed') { + const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target); + if (change.type === 'remove') { + hiddenEdges.forEach((e) => { + changes.push({ type: 'remove', id: e.id }); + }); + } + if (change.type === 'select') { + hiddenEdges.forEach((e) => { + changes.push({ type: 'select', id: e.id, selected: change.selected }); + }); + } + } + } + }); + state.edges = applyEdgeChanges(changes, state.edges); }, fieldLabelChanged: ( state, @@ -264,33 +282,6 @@ export const nodesSlice = createSlice({ } } }, - edgeDeleted: (state, action: PayloadAction) => { - state.edges = state.edges.filter((e) => e.id !== action.payload); - }, - edgesDeleted: (state, action: PayloadAction) => { - const edges = action.payload; - const collapsedEdges = edges.filter((e) => e.type === 'collapsed'); - - // if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes - if (collapsedEdges.length) { - const edgeChanges: EdgeRemoveChange[] = []; - collapsedEdges.forEach((collapsedEdge) => { - state.edges.forEach((edge) => { - if (edge.source === collapsedEdge.source && edge.target === collapsedEdge.target) { - edgeChanges.push({ id: edge.id, type: 'remove' }); - } - }); - }); - state.edges = applyEdgeChanges(edgeChanges, state.edges); - } - }, - nodesDeleted: (state, action: PayloadAction) => { - action.payload.forEach((node) => { - if (!isInvocationNode(node)) { - return; - } - }); - }, nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => { const { nodeId, label } = action.payload; const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); @@ -435,6 +426,23 @@ export const nodesSlice = createSlice({ state.nodes = applyNodeChanges(nodeChanges, state.nodes); state.edges = applyEdgeChanges(edgeChanges, state.edges); }, + selectionDeleted: (state) => { + const selectedNodes = state.nodes.filter((n) => n.selected); + const selectedEdges = state.edges.filter((e) => e.selected); + + const nodeChanges: NodeChange[] = selectedNodes.map((n) => ({ + id: n.id, + type: 'remove', + })); + + const edgeChanges: EdgeChange[] = selectedEdges.map((e) => ({ + id: e.id, + type: 'remove', + })); + + state.nodes = applyNodeChanges(nodeChanges, state.nodes); + state.edges = applyEdgeChanges(edgeChanges, state.edges); + }, undo: (state) => state, redo: (state) => state, }, @@ -457,10 +465,7 @@ export const nodesSlice = createSlice({ }); export const { - connectionMade, - edgeDeleted, edgesChanged, - edgesDeleted, fieldValueReset, fieldBoardValueChanged, fieldBooleanValueChanged, @@ -488,11 +493,11 @@ export const { nodeLabelChanged, nodeNotesChanged, nodesChanged, - nodesDeleted, nodeUseCacheChanged, notesNodeValueChanged, selectedAll, selectionPasted, + selectionDeleted, undo, redo, } = nodesSlice.actions; @@ -580,10 +585,7 @@ export const nodesUndoableConfig: UndoableOptions = { // This is used for tracking `state.workflow.isTouched` export const isAnyNodeOrEdgeMutation = isAnyOf( - connectionMade, - edgeDeleted, edgesChanged, - edgesDeleted, fieldBoardValueChanged, fieldBooleanValueChanged, fieldColorValueChanged, @@ -601,13 +603,14 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( fieldStringValueChanged, fieldVaeModelValueChanged, nodeAdded, + nodesChanged, nodeReplaced, nodeIsIntermediateChanged, nodeIsOpenChanged, nodeLabelChanged, nodeNotesChanged, - nodesDeleted, nodeUseCacheChanged, notesNodeValueChanged, - selectionPasted + selectionPasted, + selectionDeleted ); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts new file mode 100644 index 0000000000..89be7951a2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts @@ -0,0 +1,32 @@ +import type { Connection, Edge } from 'reactflow'; +import { assert } from 'tsafe'; + +/** + * Gets the edge id for a connection + * Copied from: https://github.com/xyflow/xyflow/blob/v11/packages/core/src/utils/graph.ts#L44-L45 + * Requested for this to be exported in: https://github.com/xyflow/xyflow/issues/4290 + * @param connection The connection to get the id for + * @returns The edge id + */ +const getEdgeId = (connection: Connection): string => { + const { source, sourceHandle, target, targetHandle } = connection; + return `reactflow__edge-${source}${sourceHandle || ''}-${target}${targetHandle || ''}`; +}; + +/** + * Converts a connection to an edge + * @param connection The connection to convert to an edge + * @returns The edge + * @throws If the connection is invalid (e.g. missing source, sourcehandle, target, or targetHandle) + */ +export const connectionToEdge = (connection: Connection): Edge => { + const { source, sourceHandle, target, targetHandle } = connection; + assert(source && sourceHandle && target && targetHandle, 'Invalid connection'); + return { + source, + sourceHandle, + target, + targetHandle, + id: getEdgeId({ source, sourceHandle, target, targetHandle }), + }; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts index 6293d3cce5..b3ec4f0614 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts @@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import { deepClone } from 'common/util/deepClone'; import { workflowLoaded } from 'features/nodes/store/actions'; -import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice'; +import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice'; import type { FieldIdentifierWithValue, WorkflowMode, @@ -139,16 +139,16 @@ export const workflowSlice = createSlice({ }; }); - builder.addCase(nodesDeleted, (state, action) => { - action.payload.forEach((node) => { - state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== node.id); - }); - }); - builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState)); builder.addCase(nodesChanged, (state, action) => { // Not all changes to nodes should result in the workflow being marked touched + action.payload.forEach((change) => { + if (change.type === 'remove') { + state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== change.id); + } + }); + const filteredChanges = action.payload.filter((change) => { // We always want to mark the workflow as touched if a node is added, removed, or reset if (['add', 'remove', 'reset'].includes(change.type)) {