feat(ui): rework node and edge mutation logic

Remove our DIY'd reducers, consolidating all node and edge mutations to use `edgesChanged` and `nodesChanged`, which are called by reactflow. This makes the API for manipulating nodes and edges less tangly and error-prone.
This commit is contained in:
psychedelicious 2024-05-19 13:54:32 +10:00
parent 504ac82077
commit 26029108f7
11 changed files with 186 additions and 100 deletions

View File

@ -14,11 +14,12 @@ import {
$pendingConnection,
$templates,
closeAddNodePopover,
connectionMade,
edgesChanged,
nodeAdded,
openAddNodePopover,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
@ -166,7 +167,8 @@ const AddNodePopover = () => {
edgePendingUpdate
);
if (connection) {
dispatch(connectionMade(connection));
const newEdge = connectionToEdge(connection);
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
}
}

View File

@ -14,29 +14,24 @@ import {
$lastEdgeUpdateMouseEvent,
$pendingConnection,
$viewport,
connectionMade,
edgeDeleted,
edgesChanged,
edgesDeleted,
nodesChanged,
nodesDeleted,
redo,
selectedAll,
selectionDeleted,
undo,
} from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { isString } from 'lodash-es';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import type {
OnEdgesChange,
OnEdgesDelete,
OnEdgeUpdateFunc,
OnInit,
OnMoveEnd,
OnNodesChange,
OnNodesDelete,
ProOptions,
ReactFlowProps,
ReactFlowState,
@ -50,8 +45,6 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode';
import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper';
import NotesNode from './nodes/Notes/NotesNode';
const DELETE_KEYS = ['Delete', 'Backspace'];
const edgeTypes = {
collapsed: InvocationCollapsedEdge,
default: InvocationDefaultEdge,
@ -109,20 +102,6 @@ export const Flow = memo(() => {
[dispatch]
);
const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => {
dispatch(edgesDeleted(edges));
},
[dispatch]
);
const onNodesDelete: OnNodesDelete = useCallback(
(nodes) => {
dispatch(nodesDeleted(nodes));
},
[dispatch]
);
const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => {
$viewport.set(viewport);
}, []);
@ -167,16 +146,20 @@ export const Flow = memo(() => {
}, []);
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
(edge, newConnection) => {
(oldEdge, newConnection) => {
// This event is fired when an edge update is successful
$didUpdateEdge.set(true);
// When an edge update is successful, we need to delete the old edge and create a new one
dispatch(edgeDeleted(edge.id));
dispatch(connectionMade(newConnection));
const newEdge = connectionToEdge(newConnection);
dispatch(
edgesChanged([
{ type: 'remove', id: oldEdge.id },
{ type: 'add', item: newEdge },
])
);
// Because we shift the position of handles depending on whether a field is connected or not, we must use
// updateNodeInternals to tell reactflow to recalculate the positions of the handles
const nodesToUpdate = [edge.source, edge.target, newConnection.source, newConnection.target].filter(isString);
updateNodeInternals(nodesToUpdate);
updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]);
},
[dispatch, updateNodeInternals]
);
@ -193,7 +176,7 @@ export const Flow = memo(() => {
// If we got this far and did not successfully update an edge, and the mouse moved away from the handle,
// the user probably intended to delete the edge
if (!didUpdateEdge && didMouseMove) {
dispatch(edgeDeleted(edge.id));
dispatch(edgesChanged([{ type: 'remove', id: edge.id }]));
}
$edgePendingUpdate.set(null);
@ -267,6 +250,11 @@ export const Flow = memo(() => {
}, [cancelConnection]);
useHotkeys('esc', onEscapeHotkey);
const onDeleteHotkey = useCallback(() => {
dispatch(selectionDeleted());
}, [dispatch]);
useHotkeys(['delete', 'backspace'], onDeleteHotkey);
return (
<ReactFlow
id="workflow-editor"
@ -280,11 +268,9 @@ export const Flow = memo(() => {
onMouseMove={onMouseMove}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}
onEdgeUpdate={onEdgeUpdate}
onEdgeUpdateStart={onEdgeUpdateStart}
onEdgeUpdateEnd={onEdgeUpdateEnd}
onNodesDelete={onNodesDelete}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
@ -298,7 +284,7 @@ export const Flow = memo(() => {
proOptions={proOptions}
style={flowStyles}
onPaneClick={handlePaneClick}
deleteKeyCode={DELETE_KEYS}
deleteKeyCode={null}
selectionMode={selectionMode}
elevateEdgesOnSelect
>

View File

@ -0,0 +1,20 @@
import { useDoesFieldExist } from 'features/nodes/hooks/useDoesFieldExist';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
type Props = PropsWithChildren<{
nodeId: string;
fieldName?: string;
}>;
export const MissingFallback = memo((props: Props) => {
// We must be careful here to avoid race conditions where a deleted node is still referenced as an exposed field
const exists = useDoesFieldExist(props.nodeId, props.fieldName);
if (!exists) {
return null;
}
return props.children;
});
MissingFallback.displayName = 'MissingFallback';

View File

@ -3,6 +3,7 @@ import { CSS } from '@dnd-kit/utilities';
import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice';
@ -20,7 +21,7 @@ type Props = {
fieldName: string;
};
const LinearViewField = ({ nodeId, fieldName }: Props) => {
const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
@ -99,4 +100,12 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
);
};
const LinearViewField = ({ nodeId, fieldName }: Props) => {
return (
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
<LinearViewFieldInternal nodeId={nodeId} fieldName={fieldName} />
</MissingFallback>
);
};
export default memo(LinearViewField);

View File

@ -1,6 +1,7 @@
import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent';
import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
@ -14,7 +15,7 @@ type Props = {
fieldName: string;
};
const WorkflowField = ({ nodeId, fieldName }: Props) => {
const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs');
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
@ -50,4 +51,12 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => {
);
};
const WorkflowField = ({ nodeId, fieldName }: Props) => {
return (
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
<WorkflowFieldInternal nodeId={nodeId} fieldName={fieldName} />
</MissingFallback>
);
};
export default memo(WorkflowField);

View File

@ -6,10 +6,10 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import DndSortable from 'features/dnd/components/DndSortable';
import type { DragEndEvent } from 'features/dnd/types';
import LinearViewField from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
import LinearViewFieldInternal from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
@ -40,16 +40,18 @@ const WorkflowLinearTab = () => {
[dispatch, fields]
);
const items = useMemo(() => fields.map((field) => `${field.nodeId}.${field.fieldName}`), [fields]);
return (
<Box position="relative" w="full" h="full">
<ScrollableContent>
<DndSortable onDragEnd={handleDragEnd} items={fields.map((field) => `${field.nodeId}.${field.fieldName}`)}>
<DndSortable onDragEnd={handleDragEnd} items={items}>
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
{isLoading ? (
<IAINoContentFallback label={t('nodes.loadingNodes')} icon={null} />
) : fields.length ? (
fields.map(({ nodeId, fieldName }) => (
<LinearViewField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
<LinearViewFieldInternal key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
))
) : (
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />

View File

@ -7,13 +7,12 @@ import {
$isAddNodePopoverOpen,
$pendingConnection,
$templates,
connectionMade,
edgeDeleted,
edgesChanged,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { isString } from 'lodash-es';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { useCallback, useMemo } from 'react';
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { useUpdateNodeInternals } from 'reactflow';
import { assert } from 'tsafe';
@ -50,9 +49,9 @@ export const useConnection = () => {
const onConnect = useCallback<OnConnect>(
(connection) => {
const { dispatch } = store;
dispatch(connectionMade(connection));
const nodesToUpdate = [connection.source, connection.target].filter(isString);
updateNodeInternals(nodesToUpdate);
const newEdge = connectionToEdge(connection);
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
updateNodeInternals([newEdge.source, newEdge.target]);
$pendingConnection.set(null);
},
[store, updateNodeInternals]
@ -92,13 +91,17 @@ export const useConnection = () => {
edgePendingUpdate
);
if (connection) {
dispatch(connectionMade(connection));
const nodesToUpdate = [connection.source, connection.target].filter(isString);
updateNodeInternals(nodesToUpdate);
const newEdge = connectionToEdge(connection);
const changes: EdgeChange[] = [{ type: 'add', item: newEdge }];
const nodesToUpdate = [newEdge.source, newEdge.target];
if (edgePendingUpdate) {
dispatch(edgeDeleted(edgePendingUpdate.id));
$didUpdateEdge.set(true);
changes.push({ type: 'remove', id: edgePendingUpdate.id });
nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target);
}
dispatch(edgesChanged(changes));
updateNodeInternals(nodesToUpdate);
}
$pendingConnection.set(null);
} else {

View File

@ -0,0 +1,20 @@
import { useAppSelector } from 'app/store/storeHooks';
import { isInvocationNode } from 'features/nodes/types/invocation';
export const useDoesFieldExist = (nodeId: string, fieldName?: string) => {
const doesFieldExist = useAppSelector((s) => {
const node = s.nodes.present.nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
if (fieldName === undefined) {
return true;
}
if (!node.data.inputs[fieldName]) {
return false;
}
return true;
});
return doesFieldExist;
};

View File

@ -1,6 +1,7 @@
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
@ -48,8 +49,8 @@ import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocatio
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom } from 'nanostores';
import type { MouseEvent } from 'react';
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { Edge, EdgeChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { UndoableOptions } from 'redux-undo';
import type { z } from 'zod';
@ -124,10 +125,27 @@ export const nodesSlice = createSlice({
state.nodes.push(node);
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
},
connectionMade: (state, action: PayloadAction<Connection>) => {
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
const changes = deepClone(action.payload);
action.payload.forEach((change) => {
if (change.type === 'remove' || change.type === 'select') {
const edge = state.edges.find((e) => e.id === change.id);
// If we deleted or selected a collapsed edge, we need to find its "hidden" edges and do the same to them
if (edge && edge.type === 'collapsed') {
const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target);
if (change.type === 'remove') {
hiddenEdges.forEach((e) => {
changes.push({ type: 'remove', id: e.id });
});
}
if (change.type === 'select') {
hiddenEdges.forEach((e) => {
changes.push({ type: 'select', id: e.id, selected: change.selected });
});
}
}
}
});
state.edges = applyEdgeChanges(changes, state.edges);
},
fieldLabelChanged: (
state,
@ -264,33 +282,6 @@ export const nodesSlice = createSlice({
}
}
},
edgeDeleted: (state, action: PayloadAction<string>) => {
state.edges = state.edges.filter((e) => e.id !== action.payload);
},
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
const edges = action.payload;
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
// if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes
if (collapsedEdges.length) {
const edgeChanges: EdgeRemoveChange[] = [];
collapsedEdges.forEach((collapsedEdge) => {
state.edges.forEach((edge) => {
if (edge.source === collapsedEdge.source && edge.target === collapsedEdge.target) {
edgeChanges.push({ id: edge.id, type: 'remove' });
}
});
});
state.edges = applyEdgeChanges(edgeChanges, state.edges);
}
},
nodesDeleted: (state, action: PayloadAction<AnyNode[]>) => {
action.payload.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
});
},
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
const { nodeId, label } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
@ -435,6 +426,23 @@ export const nodesSlice = createSlice({
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges);
},
selectionDeleted: (state) => {
const selectedNodes = state.nodes.filter((n) => n.selected);
const selectedEdges = state.edges.filter((e) => e.selected);
const nodeChanges: NodeChange[] = selectedNodes.map((n) => ({
id: n.id,
type: 'remove',
}));
const edgeChanges: EdgeChange[] = selectedEdges.map((e) => ({
id: e.id,
type: 'remove',
}));
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges);
},
undo: (state) => state,
redo: (state) => state,
},
@ -457,10 +465,7 @@ export const nodesSlice = createSlice({
});
export const {
connectionMade,
edgeDeleted,
edgesChanged,
edgesDeleted,
fieldValueReset,
fieldBoardValueChanged,
fieldBooleanValueChanged,
@ -488,11 +493,11 @@ export const {
nodeLabelChanged,
nodeNotesChanged,
nodesChanged,
nodesDeleted,
nodeUseCacheChanged,
notesNodeValueChanged,
selectedAll,
selectionPasted,
selectionDeleted,
undo,
redo,
} = nodesSlice.actions;
@ -580,10 +585,7 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
// This is used for tracking `state.workflow.isTouched`
export const isAnyNodeOrEdgeMutation = isAnyOf(
connectionMade,
edgeDeleted,
edgesChanged,
edgesDeleted,
fieldBoardValueChanged,
fieldBooleanValueChanged,
fieldColorValueChanged,
@ -601,13 +603,14 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldStringValueChanged,
fieldVaeModelValueChanged,
nodeAdded,
nodesChanged,
nodeReplaced,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
nodeLabelChanged,
nodeNotesChanged,
nodesDeleted,
nodeUseCacheChanged,
notesNodeValueChanged,
selectionPasted
selectionPasted,
selectionDeleted
);

View File

@ -0,0 +1,32 @@
import type { Connection, Edge } from 'reactflow';
import { assert } from 'tsafe';
/**
* Gets the edge id for a connection
* Copied from: https://github.com/xyflow/xyflow/blob/v11/packages/core/src/utils/graph.ts#L44-L45
* Requested for this to be exported in: https://github.com/xyflow/xyflow/issues/4290
* @param connection The connection to get the id for
* @returns The edge id
*/
const getEdgeId = (connection: Connection): string => {
const { source, sourceHandle, target, targetHandle } = connection;
return `reactflow__edge-${source}${sourceHandle || ''}-${target}${targetHandle || ''}`;
};
/**
* Converts a connection to an edge
* @param connection The connection to convert to an edge
* @returns The edge
* @throws If the connection is invalid (e.g. missing source, sourcehandle, target, or targetHandle)
*/
export const connectionToEdge = (connection: Connection): Edge => {
const { source, sourceHandle, target, targetHandle } = connection;
assert(source && sourceHandle && target && targetHandle, 'Invalid connection');
return {
source,
sourceHandle,
target,
targetHandle,
id: getEdgeId({ source, sourceHandle, target, targetHandle }),
};
};

View File

@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
import type {
FieldIdentifierWithValue,
WorkflowMode,
@ -139,16 +139,16 @@ export const workflowSlice = createSlice({
};
});
builder.addCase(nodesDeleted, (state, action) => {
action.payload.forEach((node) => {
state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== node.id);
});
});
builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));
builder.addCase(nodesChanged, (state, action) => {
// Not all changes to nodes should result in the workflow being marked touched
action.payload.forEach((change) => {
if (change.type === 'remove') {
state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== change.id);
}
});
const filteredChanges = action.payload.filter((change) => {
// We always want to mark the workflow as touched if a node is added, removed, or reset
if (['add', 'remove', 'reset'].includes(change.type)) {