feat(ui): rework node and edge mutation logic

Remove our DIY'd reducers, consolidating all node and edge mutations to use `edgesChanged` and `nodesChanged`, which are called by reactflow. This makes the API for manipulating nodes and edges less tangly and error-prone.
This commit is contained in:
psychedelicious 2024-05-19 13:54:32 +10:00
parent 504ac82077
commit 26029108f7
11 changed files with 186 additions and 100 deletions

View File

@ -14,11 +14,12 @@ import {
$pendingConnection, $pendingConnection,
$templates, $templates,
closeAddNodePopover, closeAddNodePopover,
connectionMade, edgesChanged,
nodeAdded, nodeAdded,
openAddNodePopover, openAddNodePopover,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode } from 'features/nodes/types/invocation'; import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
@ -166,7 +167,8 @@ const AddNodePopover = () => {
edgePendingUpdate edgePendingUpdate
); );
if (connection) { if (connection) {
dispatch(connectionMade(connection)); const newEdge = connectionToEdge(connection);
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
} }
} }

View File

@ -14,29 +14,24 @@ import {
$lastEdgeUpdateMouseEvent, $lastEdgeUpdateMouseEvent,
$pendingConnection, $pendingConnection,
$viewport, $viewport,
connectionMade,
edgeDeleted,
edgesChanged, edgesChanged,
edgesDeleted,
nodesChanged, nodesChanged,
nodesDeleted,
redo, redo,
selectedAll, selectedAll,
selectionDeleted,
undo, undo,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance'; import { $flow } from 'features/nodes/store/reactFlowInstance';
import { isString } from 'lodash-es'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import type { CSSProperties, MouseEvent } from 'react'; import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import type { import type {
OnEdgesChange, OnEdgesChange,
OnEdgesDelete,
OnEdgeUpdateFunc, OnEdgeUpdateFunc,
OnInit, OnInit,
OnMoveEnd, OnMoveEnd,
OnNodesChange, OnNodesChange,
OnNodesDelete,
ProOptions, ProOptions,
ReactFlowProps, ReactFlowProps,
ReactFlowState, ReactFlowState,
@ -50,8 +45,6 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode';
import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper'; import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper';
import NotesNode from './nodes/Notes/NotesNode'; import NotesNode from './nodes/Notes/NotesNode';
const DELETE_KEYS = ['Delete', 'Backspace'];
const edgeTypes = { const edgeTypes = {
collapsed: InvocationCollapsedEdge, collapsed: InvocationCollapsedEdge,
default: InvocationDefaultEdge, default: InvocationDefaultEdge,
@ -109,20 +102,6 @@ export const Flow = memo(() => {
[dispatch] [dispatch]
); );
const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => {
dispatch(edgesDeleted(edges));
},
[dispatch]
);
const onNodesDelete: OnNodesDelete = useCallback(
(nodes) => {
dispatch(nodesDeleted(nodes));
},
[dispatch]
);
const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => { const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => {
$viewport.set(viewport); $viewport.set(viewport);
}, []); }, []);
@ -167,16 +146,20 @@ export const Flow = memo(() => {
}, []); }, []);
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback( const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
(edge, newConnection) => { (oldEdge, newConnection) => {
// This event is fired when an edge update is successful // This event is fired when an edge update is successful
$didUpdateEdge.set(true); $didUpdateEdge.set(true);
// When an edge update is successful, we need to delete the old edge and create a new one // When an edge update is successful, we need to delete the old edge and create a new one
dispatch(edgeDeleted(edge.id)); const newEdge = connectionToEdge(newConnection);
dispatch(connectionMade(newConnection)); dispatch(
edgesChanged([
{ type: 'remove', id: oldEdge.id },
{ type: 'add', item: newEdge },
])
);
// Because we shift the position of handles depending on whether a field is connected or not, we must use // Because we shift the position of handles depending on whether a field is connected or not, we must use
// updateNodeInternals to tell reactflow to recalculate the positions of the handles // updateNodeInternals to tell reactflow to recalculate the positions of the handles
const nodesToUpdate = [edge.source, edge.target, newConnection.source, newConnection.target].filter(isString); updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]);
updateNodeInternals(nodesToUpdate);
}, },
[dispatch, updateNodeInternals] [dispatch, updateNodeInternals]
); );
@ -193,7 +176,7 @@ export const Flow = memo(() => {
// If we got this far and did not successfully update an edge, and the mouse moved away from the handle, // If we got this far and did not successfully update an edge, and the mouse moved away from the handle,
// the user probably intended to delete the edge // the user probably intended to delete the edge
if (!didUpdateEdge && didMouseMove) { if (!didUpdateEdge && didMouseMove) {
dispatch(edgeDeleted(edge.id)); dispatch(edgesChanged([{ type: 'remove', id: edge.id }]));
} }
$edgePendingUpdate.set(null); $edgePendingUpdate.set(null);
@ -267,6 +250,11 @@ export const Flow = memo(() => {
}, [cancelConnection]); }, [cancelConnection]);
useHotkeys('esc', onEscapeHotkey); useHotkeys('esc', onEscapeHotkey);
const onDeleteHotkey = useCallback(() => {
dispatch(selectionDeleted());
}, [dispatch]);
useHotkeys(['delete', 'backspace'], onDeleteHotkey);
return ( return (
<ReactFlow <ReactFlow
id="workflow-editor" id="workflow-editor"
@ -280,11 +268,9 @@ export const Flow = memo(() => {
onMouseMove={onMouseMove} onMouseMove={onMouseMove}
onNodesChange={onNodesChange} onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange} onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}
onEdgeUpdate={onEdgeUpdate} onEdgeUpdate={onEdgeUpdate}
onEdgeUpdateStart={onEdgeUpdateStart} onEdgeUpdateStart={onEdgeUpdateStart}
onEdgeUpdateEnd={onEdgeUpdateEnd} onEdgeUpdateEnd={onEdgeUpdateEnd}
onNodesDelete={onNodesDelete}
onConnectStart={onConnectStart} onConnectStart={onConnectStart}
onConnect={onConnect} onConnect={onConnect}
onConnectEnd={onConnectEnd} onConnectEnd={onConnectEnd}
@ -298,7 +284,7 @@ export const Flow = memo(() => {
proOptions={proOptions} proOptions={proOptions}
style={flowStyles} style={flowStyles}
onPaneClick={handlePaneClick} onPaneClick={handlePaneClick}
deleteKeyCode={DELETE_KEYS} deleteKeyCode={null}
selectionMode={selectionMode} selectionMode={selectionMode}
elevateEdgesOnSelect elevateEdgesOnSelect
> >

View File

@ -0,0 +1,20 @@
import { useDoesFieldExist } from 'features/nodes/hooks/useDoesFieldExist';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
type Props = PropsWithChildren<{
nodeId: string;
fieldName?: string;
}>;
export const MissingFallback = memo((props: Props) => {
// We must be careful here to avoid race conditions where a deleted node is still referenced as an exposed field
const exists = useDoesFieldExist(props.nodeId, props.fieldName);
if (!exists) {
return null;
}
return props.children;
});
MissingFallback.displayName = 'MissingFallback';

View File

@ -3,6 +3,7 @@ import { CSS } from '@dnd-kit/utilities';
import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library'; import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay'; import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue'; import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice'; import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice';
@ -20,7 +21,7 @@ type Props = {
fieldName: string; fieldName: string;
}; };
const LinearViewField = ({ nodeId, fieldName }: Props) => { const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId); const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
@ -99,4 +100,12 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
); );
}; };
const LinearViewField = ({ nodeId, fieldName }: Props) => {
return (
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
<LinearViewFieldInternal nodeId={nodeId} fieldName={fieldName} />
</MissingFallback>
);
};
export default memo(LinearViewField); export default memo(LinearViewField);

View File

@ -1,6 +1,7 @@
import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library'; import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent'; import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent';
import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer'; import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel'; import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue'; import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle'; import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
@ -14,7 +15,7 @@ type Props = {
fieldName: string; fieldName: string;
}; };
const WorkflowField = ({ nodeId, fieldName }: Props) => { const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
const label = useFieldLabel(nodeId, fieldName); const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs'); const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs');
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
@ -50,4 +51,12 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => {
); );
}; };
const WorkflowField = ({ nodeId, fieldName }: Props) => {
return (
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
<WorkflowFieldInternal nodeId={nodeId} fieldName={fieldName} />
</MissingFallback>
);
};
export default memo(WorkflowField); export default memo(WorkflowField);

View File

@ -6,10 +6,10 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import DndSortable from 'features/dnd/components/DndSortable'; import DndSortable from 'features/dnd/components/DndSortable';
import type { DragEndEvent } from 'features/dnd/types'; import type { DragEndEvent } from 'features/dnd/types';
import LinearViewField from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField'; import LinearViewFieldInternal from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice'; import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice';
import type { FieldIdentifier } from 'features/nodes/types/field'; import type { FieldIdentifier } from 'features/nodes/types/field';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo'; import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
@ -40,16 +40,18 @@ const WorkflowLinearTab = () => {
[dispatch, fields] [dispatch, fields]
); );
const items = useMemo(() => fields.map((field) => `${field.nodeId}.${field.fieldName}`), [fields]);
return ( return (
<Box position="relative" w="full" h="full"> <Box position="relative" w="full" h="full">
<ScrollableContent> <ScrollableContent>
<DndSortable onDragEnd={handleDragEnd} items={fields.map((field) => `${field.nodeId}.${field.fieldName}`)}> <DndSortable onDragEnd={handleDragEnd} items={items}>
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full"> <Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
{isLoading ? ( {isLoading ? (
<IAINoContentFallback label={t('nodes.loadingNodes')} icon={null} /> <IAINoContentFallback label={t('nodes.loadingNodes')} icon={null} />
) : fields.length ? ( ) : fields.length ? (
fields.map(({ nodeId, fieldName }) => ( fields.map(({ nodeId, fieldName }) => (
<LinearViewField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} /> <LinearViewFieldInternal key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
)) ))
) : ( ) : (
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} /> <IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />

View File

@ -7,13 +7,12 @@ import {
$isAddNodePopoverOpen, $isAddNodePopoverOpen,
$pendingConnection, $pendingConnection,
$templates, $templates,
connectionMade, edgesChanged,
edgeDeleted,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { isString } from 'lodash-es'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { useUpdateNodeInternals } from 'reactflow'; import { useUpdateNodeInternals } from 'reactflow';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -50,9 +49,9 @@ export const useConnection = () => {
const onConnect = useCallback<OnConnect>( const onConnect = useCallback<OnConnect>(
(connection) => { (connection) => {
const { dispatch } = store; const { dispatch } = store;
dispatch(connectionMade(connection)); const newEdge = connectionToEdge(connection);
const nodesToUpdate = [connection.source, connection.target].filter(isString); dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
updateNodeInternals(nodesToUpdate); updateNodeInternals([newEdge.source, newEdge.target]);
$pendingConnection.set(null); $pendingConnection.set(null);
}, },
[store, updateNodeInternals] [store, updateNodeInternals]
@ -92,13 +91,17 @@ export const useConnection = () => {
edgePendingUpdate edgePendingUpdate
); );
if (connection) { if (connection) {
dispatch(connectionMade(connection)); const newEdge = connectionToEdge(connection);
const nodesToUpdate = [connection.source, connection.target].filter(isString); const changes: EdgeChange[] = [{ type: 'add', item: newEdge }];
updateNodeInternals(nodesToUpdate);
const nodesToUpdate = [newEdge.source, newEdge.target];
if (edgePendingUpdate) { if (edgePendingUpdate) {
dispatch(edgeDeleted(edgePendingUpdate.id));
$didUpdateEdge.set(true); $didUpdateEdge.set(true);
changes.push({ type: 'remove', id: edgePendingUpdate.id });
nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target);
} }
dispatch(edgesChanged(changes));
updateNodeInternals(nodesToUpdate);
} }
$pendingConnection.set(null); $pendingConnection.set(null);
} else { } else {

View File

@ -0,0 +1,20 @@
import { useAppSelector } from 'app/store/storeHooks';
import { isInvocationNode } from 'features/nodes/types/invocation';
export const useDoesFieldExist = (nodeId: string, fieldName?: string) => {
const doesFieldExist = useAppSelector((s) => {
const node = s.nodes.present.nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
if (fieldName === undefined) {
return true;
}
if (!node.data.inputs[fieldName]) {
return false;
}
return true;
});
return doesFieldExist;
};

View File

@ -1,6 +1,7 @@
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit'; import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions'; import { workflowLoaded } from 'features/nodes/store/actions';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type { import type {
@ -48,8 +49,8 @@ import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocatio
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import type { MouseEvent } from 'react'; import type { MouseEvent } from 'react';
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; import type { Edge, EdgeChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { UndoableOptions } from 'redux-undo'; import type { UndoableOptions } from 'redux-undo';
import type { z } from 'zod'; import type { z } from 'zod';
@ -124,10 +125,27 @@ export const nodesSlice = createSlice({
state.nodes.push(node); state.nodes.push(node);
}, },
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => { edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges); const changes = deepClone(action.payload);
}, action.payload.forEach((change) => {
connectionMade: (state, action: PayloadAction<Connection>) => { if (change.type === 'remove' || change.type === 'select') {
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges); const edge = state.edges.find((e) => e.id === change.id);
// If we deleted or selected a collapsed edge, we need to find its "hidden" edges and do the same to them
if (edge && edge.type === 'collapsed') {
const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target);
if (change.type === 'remove') {
hiddenEdges.forEach((e) => {
changes.push({ type: 'remove', id: e.id });
});
}
if (change.type === 'select') {
hiddenEdges.forEach((e) => {
changes.push({ type: 'select', id: e.id, selected: change.selected });
});
}
}
}
});
state.edges = applyEdgeChanges(changes, state.edges);
}, },
fieldLabelChanged: ( fieldLabelChanged: (
state, state,
@ -264,33 +282,6 @@ export const nodesSlice = createSlice({
} }
} }
}, },
edgeDeleted: (state, action: PayloadAction<string>) => {
state.edges = state.edges.filter((e) => e.id !== action.payload);
},
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
const edges = action.payload;
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
// if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes
if (collapsedEdges.length) {
const edgeChanges: EdgeRemoveChange[] = [];
collapsedEdges.forEach((collapsedEdge) => {
state.edges.forEach((edge) => {
if (edge.source === collapsedEdge.source && edge.target === collapsedEdge.target) {
edgeChanges.push({ id: edge.id, type: 'remove' });
}
});
});
state.edges = applyEdgeChanges(edgeChanges, state.edges);
}
},
nodesDeleted: (state, action: PayloadAction<AnyNode[]>) => {
action.payload.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
});
},
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => { nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
const { nodeId, label } = action.payload; const { nodeId, label } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
@ -435,6 +426,23 @@ export const nodesSlice = createSlice({
state.nodes = applyNodeChanges(nodeChanges, state.nodes); state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges); state.edges = applyEdgeChanges(edgeChanges, state.edges);
}, },
selectionDeleted: (state) => {
const selectedNodes = state.nodes.filter((n) => n.selected);
const selectedEdges = state.edges.filter((e) => e.selected);
const nodeChanges: NodeChange[] = selectedNodes.map((n) => ({
id: n.id,
type: 'remove',
}));
const edgeChanges: EdgeChange[] = selectedEdges.map((e) => ({
id: e.id,
type: 'remove',
}));
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges);
},
undo: (state) => state, undo: (state) => state,
redo: (state) => state, redo: (state) => state,
}, },
@ -457,10 +465,7 @@ export const nodesSlice = createSlice({
}); });
export const { export const {
connectionMade,
edgeDeleted,
edgesChanged, edgesChanged,
edgesDeleted,
fieldValueReset, fieldValueReset,
fieldBoardValueChanged, fieldBoardValueChanged,
fieldBooleanValueChanged, fieldBooleanValueChanged,
@ -488,11 +493,11 @@ export const {
nodeLabelChanged, nodeLabelChanged,
nodeNotesChanged, nodeNotesChanged,
nodesChanged, nodesChanged,
nodesDeleted,
nodeUseCacheChanged, nodeUseCacheChanged,
notesNodeValueChanged, notesNodeValueChanged,
selectedAll, selectedAll,
selectionPasted, selectionPasted,
selectionDeleted,
undo, undo,
redo, redo,
} = nodesSlice.actions; } = nodesSlice.actions;
@ -580,10 +585,7 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
// This is used for tracking `state.workflow.isTouched` // This is used for tracking `state.workflow.isTouched`
export const isAnyNodeOrEdgeMutation = isAnyOf( export const isAnyNodeOrEdgeMutation = isAnyOf(
connectionMade,
edgeDeleted,
edgesChanged, edgesChanged,
edgesDeleted,
fieldBoardValueChanged, fieldBoardValueChanged,
fieldBooleanValueChanged, fieldBooleanValueChanged,
fieldColorValueChanged, fieldColorValueChanged,
@ -601,13 +603,14 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldStringValueChanged, fieldStringValueChanged,
fieldVaeModelValueChanged, fieldVaeModelValueChanged,
nodeAdded, nodeAdded,
nodesChanged,
nodeReplaced, nodeReplaced,
nodeIsIntermediateChanged, nodeIsIntermediateChanged,
nodeIsOpenChanged, nodeIsOpenChanged,
nodeLabelChanged, nodeLabelChanged,
nodeNotesChanged, nodeNotesChanged,
nodesDeleted,
nodeUseCacheChanged, nodeUseCacheChanged,
notesNodeValueChanged, notesNodeValueChanged,
selectionPasted selectionPasted,
selectionDeleted
); );

View File

@ -0,0 +1,32 @@
import type { Connection, Edge } from 'reactflow';
import { assert } from 'tsafe';
/**
* Gets the edge id for a connection
* Copied from: https://github.com/xyflow/xyflow/blob/v11/packages/core/src/utils/graph.ts#L44-L45
* Requested for this to be exported in: https://github.com/xyflow/xyflow/issues/4290
* @param connection The connection to get the id for
* @returns The edge id
*/
const getEdgeId = (connection: Connection): string => {
const { source, sourceHandle, target, targetHandle } = connection;
return `reactflow__edge-${source}${sourceHandle || ''}-${target}${targetHandle || ''}`;
};
/**
* Converts a connection to an edge
* @param connection The connection to convert to an edge
* @returns The edge
* @throws If the connection is invalid (e.g. missing source, sourcehandle, target, or targetHandle)
*/
export const connectionToEdge = (connection: Connection): Edge => {
const { source, sourceHandle, target, targetHandle } = connection;
assert(source && sourceHandle && target && targetHandle, 'Invalid connection');
return {
source,
sourceHandle,
target,
targetHandle,
id: getEdgeId({ source, sourceHandle, target, targetHandle }),
};
};

View File

@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions'; import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice'; import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
import type { import type {
FieldIdentifierWithValue, FieldIdentifierWithValue,
WorkflowMode, WorkflowMode,
@ -139,16 +139,16 @@ export const workflowSlice = createSlice({
}; };
}); });
builder.addCase(nodesDeleted, (state, action) => {
action.payload.forEach((node) => {
state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== node.id);
});
});
builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState)); builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));
builder.addCase(nodesChanged, (state, action) => { builder.addCase(nodesChanged, (state, action) => {
// Not all changes to nodes should result in the workflow being marked touched // Not all changes to nodes should result in the workflow being marked touched
action.payload.forEach((change) => {
if (change.type === 'remove') {
state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== change.id);
}
});
const filteredChanges = action.payload.filter((change) => { const filteredChanges = action.payload.filter((change) => {
// We always want to mark the workflow as touched if a node is added, removed, or reset // We always want to mark the workflow as touched if a node is added, removed, or reset
if (['add', 'remove', 'reset'].includes(change.type)) { if (['add', 'remove', 'reset'].includes(change.type)) {