feat(ui): remove remaining extraneous state from nodes slice

This commit is contained in:
psychedelicious 2024-05-16 19:36:21 +10:00
parent 4d68cd8dbb
commit b0c7c7cb47
8 changed files with 66 additions and 198 deletions

View File

@ -8,6 +8,7 @@ import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import {
$cursorPos,
$isAddNodePopoverOpen,
$pendingConnection,
$templates,
@ -117,8 +118,8 @@ const AddNodePopover = () => {
const addNode = useCallback(
(nodeType: string): AnyNode | null => {
const invocation = buildInvocation(nodeType);
if (!invocation) {
const node = buildInvocation(nodeType);
if (!node) {
const errorMessage = t('nodes.unknownNode', {
nodeType: nodeType,
});
@ -128,9 +129,9 @@ const AddNodePopover = () => {
});
return null;
}
dispatch(nodeAdded(invocation));
return invocation;
const cursorPos = $cursorPos.get();
dispatch(nodeAdded({ node, cursorPos }));
return node;
},
[dispatch, buildInvocation, toaster, t]
);

View File

@ -8,7 +8,6 @@ import {
$cursorPos,
connectionMade,
edgeAdded,
edgeChangeStarted,
edgeDeleted,
edgesChanged,
edgesDeleted,
@ -170,7 +169,6 @@ export const Flow = memo(() => {
edgeUpdateMouseEvent.current = e;
// always delete the edge when starting an updated
dispatch(edgeDeleted(edge.id));
dispatch(edgeChangeStarted());
},
[dispatch]
);

View File

@ -1,29 +1,33 @@
import { createSelector } from '@reduxjs/toolkit';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { $pendingConnection } from 'features/nodes/store/nodesSlice';
import type { CSSProperties } from 'react';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import type { ConnectionLineComponentProps } from 'reactflow';
import { getBezierPath } from 'reactflow';
const selectStroke = createSelector([selectNodesSlice, selectWorkflowSettingsSlice], (nodes, workflowSettings) =>
workflowSettings.shouldColorEdges ? getFieldColor(nodes.connectionStartFieldType) : colorTokenToCssVar('base.500')
);
const selectClassName = createSelector(selectWorkflowSettingsSlice, (workflowSettings) =>
workflowSettings.shouldAnimateEdges
? 'react-flow__custom_connection-path animated'
: 'react-flow__custom_connection-path'
);
const pathStyles: CSSProperties = { opacity: 0.8 };
const CustomConnectionLine = ({ fromX, fromY, fromPosition, toX, toY, toPosition }: ConnectionLineComponentProps) => {
const stroke = useAppSelector(selectStroke);
const className = useAppSelector(selectClassName);
const pendingConnection = useStore($pendingConnection);
const shouldColorEdges = useAppSelector((state) => state.workflowSettings.shouldColorEdges);
const shouldAnimateEdges = useAppSelector((state) => state.workflowSettings.shouldAnimateEdges);
const stroke = useMemo(() => {
if (shouldColorEdges && pendingConnection) {
return getFieldColor(pendingConnection.fieldTemplate.type);
} else {
return colorTokenToCssVar('base.500');
}
}, [pendingConnection, shouldColorEdges]);
const className = useMemo(() => {
if (shouldAnimateEdges) {
return 'react-flow__custom_connection-path animated';
} else {
return 'react-flow__custom_connection-path';
}
}, [shouldAnimateEdges]);
const pathParams = {
sourceX: fromX,

View File

@ -1,23 +1,18 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addNodePopoverOpened } from 'features/nodes/store/nodesSlice';
import { memo, useCallback } from 'react';
import { openAddNodePopover } from 'features/nodes/store/nodesSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
const AddNodeButton = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleOpenAddNodePopover = useCallback(() => {
dispatch(addNodePopoverOpened());
}, [dispatch]);
return (
<IconButton
tooltip={t('nodes.addNodeToolTip')}
aria-label={t('nodes.addNode')}
icon={<PiPlusBold />}
onClick={handleOpenAddNodePopover}
onClick={openAddNodePopover}
pointerEvents="auto"
/>
);

View File

@ -1,16 +1,12 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $pendingConnection, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { useMemo } from 'react';
import { useFieldType } from './useFieldType.ts';
const selectIsConnectionInProgress = createSelector(
selectNodesSlice,
(nodes) => nodes.connectionStartFieldType !== null && nodes.connectionStartParams !== null
);
type UseConnectionStateProps = {
nodeId: string;
fieldName: string;
@ -18,6 +14,7 @@ type UseConnectionStateProps = {
};
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
const pendingConnection = useStore($pendingConnection);
const fieldType = useFieldType(nodeId, fieldName, kind);
const selectIsConnected = useMemo(
@ -36,25 +33,29 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
);
const selectConnectionError = useMemo(
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
[nodeId, fieldName, kind, fieldType]
);
const selectIsConnectionStartField = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) =>
Boolean(
nodes.connectionStartParams?.nodeId === nodeId &&
nodes.connectionStartParams?.handleId === fieldName &&
nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind]
)
makeConnectionErrorSelector(
pendingConnection,
nodeId,
fieldName,
kind === 'inputs' ? 'target' : 'source',
fieldType
),
[fieldName, kind, nodeId]
[pendingConnection, nodeId, fieldName, kind, fieldType]
);
const isConnected = useAppSelector(selectIsConnected);
const isConnectionInProgress = useAppSelector(selectIsConnectionInProgress);
const isConnectionStartField = useAppSelector(selectIsConnectionStartField);
const isConnectionInProgress = useMemo(() => Boolean(pendingConnection), [pendingConnection]);
const isConnectionStartField = useMemo(() => {
if (!pendingConnection) {
return false;
}
return (
pendingConnection.node.id === nodeId &&
pendingConnection.fieldTemplate.name === fieldName &&
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
);
}, [fieldName, kind, nodeId, pendingConnection]);
const connectionError = useAppSelector(selectConnectionError);
const shouldDim = useMemo(

View File

@ -47,17 +47,7 @@ import type { AnyNode, InvocationNodeEdge, NodeExecutionState } from 'features/n
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import { atom } from 'nanostores';
import type {
Connection,
Edge,
EdgeChange,
EdgeRemoveChange,
Node,
NodeChange,
OnConnectStartParams,
Viewport,
XYPosition,
} from 'reactflow';
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { UndoableOptions } from 'redux-undo';
import {
@ -70,7 +60,6 @@ import {
import type { z } from 'zod';
import type { NodesState, PendingConnection, Templates } from './types';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
@ -85,13 +74,6 @@ const initialNodesState: NodesState = {
_version: 1,
nodes: [],
edges: [],
templates: {},
connectionStartParams: null,
connectionStartFieldType: null,
connectionMade: false,
modifyingEdge: false,
addNewNodePosition: null,
isAddNodePopoverOpen: false,
selectedNodes: [],
selectedEdges: [],
nodeExecutionStates: {},
@ -137,12 +119,12 @@ export const nodesSlice = createSlice({
}
state.nodes[nodeIndex] = action.payload.node;
},
nodeAdded: (state, action: PayloadAction<AnyNode>) => {
const node = action.payload;
nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => {
const { node, cursorPos } = action.payload;
const position = findUnoccupiedPosition(
state.nodes,
state.addNewNodePosition?.x ?? node.position.x,
state.addNewNodePosition?.y ?? node.position.y
cursorPos?.x ?? node.position.x,
cursorPos?.y ?? node.position.y
);
node.position = position;
node.selected = true;
@ -167,31 +149,6 @@ export const nodesSlice = createSlice({
nodeId: node.id,
...initialNodeExecutionState,
};
if (state.connectionStartParams) {
const { nodeId, handleId, handleType } = state.connectionStartParams;
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
const newConnection = findConnectionToValidHandle(
node,
state.nodes,
state.edges,
state.templates,
nodeId,
handleId,
handleType,
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
}
}
}
state.connectionStartParams = null;
state.connectionStartFieldType = null;
},
edgeChangeStarted: (state) => {
state.modifyingEdge = true;
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
@ -199,66 +156,9 @@ export const nodesSlice = createSlice({
edgeAdded: (state, action: PayloadAction<Edge>) => {
state.edges = addEdge(action.payload, state.edges);
},
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
state.connectionStartParams = action.payload;
state.connectionMade = state.modifyingEdge;
const { nodeId, handleId, handleType } = action.payload;
if (!nodeId || !handleId) {
return;
}
const node = state.nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
const template = state.templates[node.data.type];
const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId];
state.connectionStartFieldType = field?.type ?? null;
},
connectionMade: (state, action: PayloadAction<Connection>) => {
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
},
connectionEnded: (
state,
action: PayloadAction<{
cursorPosition: XYPosition;
mouseOverNodeId: string | null;
}>
) => {
const { cursorPosition, mouseOverNodeId } = action.payload;
if (!state.connectionMade) {
if (mouseOverNodeId) {
const nodeIndex = state.nodes.findIndex((n) => n.id === mouseOverNodeId);
const mouseOverNode = state.nodes?.[nodeIndex];
if (mouseOverNode && state.connectionStartParams) {
const { nodeId, handleId, handleType } = state.connectionStartParams;
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
const newConnection = findConnectionToValidHandle(
mouseOverNode,
state.nodes,
state.edges,
state.templates,
nodeId,
handleId,
handleType,
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
}
}
}
state.connectionStartParams = null;
state.connectionStartFieldType = null;
} else {
state.addNewNodePosition = cursorPosition;
state.isAddNodePopoverOpen = true;
}
} else {
state.connectionStartParams = null;
state.connectionStartFieldType = null;
}
state.modifyingEdge = false;
},
fieldLabelChanged: (
state,
action: PayloadAction<{
@ -580,17 +480,6 @@ export const nodesSlice = createSlice({
};
});
},
addNodePopoverOpened: (state) => {
state.addNewNodePosition = null; //Create the node in viewport center by default
state.isAddNodePopoverOpen = true;
},
addNodePopoverClosed: (state) => {
state.isAddNodePopoverOpen = false;
//Make sure these get reset if we close the popover and haven't selected a node
state.connectionStartParams = null;
state.connectionStartFieldType = null;
},
undo: (state) => state,
redo: (state) => state,
},
@ -670,13 +559,8 @@ export const nodesSlice = createSlice({
});
export const {
addNodePopoverClosed,
addNodePopoverOpened,
connectionEnded,
connectionMade,
connectionStarted,
edgeDeleted,
edgeChangeStarted,
edgesChanged,
edgesDeleted,
fieldValueReset,
@ -720,7 +604,6 @@ export const {
// This is used for tracking `state.workflow.isTouched`
export const isAnyNodeOrEdgeMutation = isAnyOf(
connectionEnded,
connectionMade,
edgeDeleted,
edgesChanged,
@ -783,15 +666,7 @@ export const nodesPersistConfig: PersistConfig<NodesState> = {
name: nodesSlice.name,
initialState: initialNodesState,
migrate: migrateNodesState,
persistDenylist: [
'connectionStartParams',
'connectionStartFieldType',
'selectedNodes',
'selectedEdges',
'connectionMade',
'modifyingEdge',
'addNewNodePosition',
],
persistDenylist: ['selectedNodes', 'selectedEdges'],
};
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {

View File

@ -2,7 +2,6 @@ import type {
FieldIdentifier,
FieldInputTemplate,
FieldOutputTemplate,
FieldType,
StatefulFieldValue,
} from 'features/nodes/types/field';
import type {
@ -13,7 +12,7 @@ import type {
NodeExecutionState,
} from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { OnConnectStartParams, Viewport, XYPosition } from 'reactflow';
import type { Viewport } from 'reactflow';
export type Templates = Record<string, InvocationTemplate>;
@ -27,16 +26,10 @@ export type NodesState = {
_version: 1;
nodes: AnyNode[];
edges: InvocationNodeEdge[];
connectionStartParams: OnConnectStartParams | null;
connectionStartFieldType: FieldType | null;
connectionMade: boolean;
modifyingEdge: boolean;
selectedNodes: string[];
selectedEdges: string[];
nodeExecutionStates: Record<string, NodeExecutionState>;
viewport: Viewport;
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;
};
export type WorkflowMode = 'edit' | 'view';

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { PendingConnection } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import i18n from 'i18next';
import type { HandleType } from 'reactflow';
@ -13,27 +14,27 @@ import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
*/
export const makeConnectionErrorSelector = (
pendingConnection: PendingConnection | null,
nodeId: string,
fieldName: string,
handleType: HandleType,
fieldType?: FieldType | null
) => {
return createSelector(selectNodesSlice, (nodesSlice) => {
const { nodes, edges } = nodesSlice;
if (!fieldType) {
return i18n.t('nodes.noFieldType');
}
const { connectionStartFieldType, connectionStartParams, nodes, edges } = nodesSlice;
if (!connectionStartParams || !connectionStartFieldType) {
if (!pendingConnection) {
return i18n.t('nodes.noConnectionInProgress');
}
const {
handleType: connectionHandleType,
nodeId: connectionNodeId,
handleId: connectionFieldName,
} = connectionStartParams;
const connectionNodeId = pendingConnection.node.id;
const connectionFieldName = pendingConnection.fieldTemplate.name;
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
const connectionStartFieldType = pendingConnection.fieldTemplate.type;
if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
return i18n.t('nodes.noConnectionData');