mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): remove remaining extraneous state from nodes slice
This commit is contained in:
parent
4d68cd8dbb
commit
b0c7c7cb47
@ -8,6 +8,7 @@ import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
|
||||
import type { SelectInstance } from 'chakra-react-select';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import {
|
||||
$cursorPos,
|
||||
$isAddNodePopoverOpen,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
@ -117,8 +118,8 @@ const AddNodePopover = () => {
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string): AnyNode | null => {
|
||||
const invocation = buildInvocation(nodeType);
|
||||
if (!invocation) {
|
||||
const node = buildInvocation(nodeType);
|
||||
if (!node) {
|
||||
const errorMessage = t('nodes.unknownNode', {
|
||||
nodeType: nodeType,
|
||||
});
|
||||
@ -128,9 +129,9 @@ const AddNodePopover = () => {
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
dispatch(nodeAdded(invocation));
|
||||
return invocation;
|
||||
const cursorPos = $cursorPos.get();
|
||||
dispatch(nodeAdded({ node, cursorPos }));
|
||||
return node;
|
||||
},
|
||||
[dispatch, buildInvocation, toaster, t]
|
||||
);
|
||||
|
@ -8,7 +8,6 @@ import {
|
||||
$cursorPos,
|
||||
connectionMade,
|
||||
edgeAdded,
|
||||
edgeChangeStarted,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
@ -170,7 +169,6 @@ export const Flow = memo(() => {
|
||||
edgeUpdateMouseEvent.current = e;
|
||||
// always delete the edge when starting an updated
|
||||
dispatch(edgeDeleted(edge.id));
|
||||
dispatch(edgeChangeStarted());
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
@ -1,29 +1,33 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { $pendingConnection } from 'features/nodes/store/nodesSlice';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { ConnectionLineComponentProps } from 'reactflow';
|
||||
import { getBezierPath } from 'reactflow';
|
||||
|
||||
const selectStroke = createSelector([selectNodesSlice, selectWorkflowSettingsSlice], (nodes, workflowSettings) =>
|
||||
workflowSettings.shouldColorEdges ? getFieldColor(nodes.connectionStartFieldType) : colorTokenToCssVar('base.500')
|
||||
);
|
||||
|
||||
const selectClassName = createSelector(selectWorkflowSettingsSlice, (workflowSettings) =>
|
||||
workflowSettings.shouldAnimateEdges
|
||||
? 'react-flow__custom_connection-path animated'
|
||||
: 'react-flow__custom_connection-path'
|
||||
);
|
||||
|
||||
const pathStyles: CSSProperties = { opacity: 0.8 };
|
||||
|
||||
const CustomConnectionLine = ({ fromX, fromY, fromPosition, toX, toY, toPosition }: ConnectionLineComponentProps) => {
|
||||
const stroke = useAppSelector(selectStroke);
|
||||
const className = useAppSelector(selectClassName);
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const shouldColorEdges = useAppSelector((state) => state.workflowSettings.shouldColorEdges);
|
||||
const shouldAnimateEdges = useAppSelector((state) => state.workflowSettings.shouldAnimateEdges);
|
||||
const stroke = useMemo(() => {
|
||||
if (shouldColorEdges && pendingConnection) {
|
||||
return getFieldColor(pendingConnection.fieldTemplate.type);
|
||||
} else {
|
||||
return colorTokenToCssVar('base.500');
|
||||
}
|
||||
}, [pendingConnection, shouldColorEdges]);
|
||||
const className = useMemo(() => {
|
||||
if (shouldAnimateEdges) {
|
||||
return 'react-flow__custom_connection-path animated';
|
||||
} else {
|
||||
return 'react-flow__custom_connection-path';
|
||||
}
|
||||
}, [shouldAnimateEdges]);
|
||||
|
||||
const pathParams = {
|
||||
sourceX: fromX,
|
||||
|
@ -1,23 +1,18 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addNodePopoverOpened } from 'features/nodes/store/nodesSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { openAddNodePopover } from 'features/nodes/store/nodesSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
const AddNodeButton = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const handleOpenAddNodePopover = useCallback(() => {
|
||||
dispatch(addNodePopoverOpened());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
tooltip={t('nodes.addNodeToolTip')}
|
||||
aria-label={t('nodes.addNode')}
|
||||
icon={<PiPlusBold />}
|
||||
onClick={handleOpenAddNodePopover}
|
||||
onClick={openAddNodePopover}
|
||||
pointerEvents="auto"
|
||||
/>
|
||||
);
|
||||
|
@ -1,16 +1,12 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $pendingConnection, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useFieldType } from './useFieldType.ts';
|
||||
|
||||
const selectIsConnectionInProgress = createSelector(
|
||||
selectNodesSlice,
|
||||
(nodes) => nodes.connectionStartFieldType !== null && nodes.connectionStartParams !== null
|
||||
);
|
||||
|
||||
type UseConnectionStateProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
@ -18,6 +14,7 @@ type UseConnectionStateProps = {
|
||||
};
|
||||
|
||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const fieldType = useFieldType(nodeId, fieldName, kind);
|
||||
|
||||
const selectIsConnected = useMemo(
|
||||
@ -36,25 +33,29 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
);
|
||||
|
||||
const selectConnectionError = useMemo(
|
||||
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
|
||||
[nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
const selectIsConnectionStartField = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) =>
|
||||
Boolean(
|
||||
nodes.connectionStartParams?.nodeId === nodeId &&
|
||||
nodes.connectionStartParams?.handleId === fieldName &&
|
||||
nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind]
|
||||
)
|
||||
makeConnectionErrorSelector(
|
||||
pendingConnection,
|
||||
nodeId,
|
||||
fieldName,
|
||||
kind === 'inputs' ? 'target' : 'source',
|
||||
fieldType
|
||||
),
|
||||
[fieldName, kind, nodeId]
|
||||
[pendingConnection, nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
const isConnected = useAppSelector(selectIsConnected);
|
||||
const isConnectionInProgress = useAppSelector(selectIsConnectionInProgress);
|
||||
const isConnectionStartField = useAppSelector(selectIsConnectionStartField);
|
||||
const isConnectionInProgress = useMemo(() => Boolean(pendingConnection), [pendingConnection]);
|
||||
const isConnectionStartField = useMemo(() => {
|
||||
if (!pendingConnection) {
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
pendingConnection.node.id === nodeId &&
|
||||
pendingConnection.fieldTemplate.name === fieldName &&
|
||||
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
||||
);
|
||||
}, [fieldName, kind, nodeId, pendingConnection]);
|
||||
const connectionError = useAppSelector(selectConnectionError);
|
||||
|
||||
const shouldDim = useMemo(
|
||||
|
@ -47,17 +47,7 @@ import type { AnyNode, InvocationNodeEdge, NodeExecutionState } from 'features/n
|
||||
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
import type {
|
||||
Connection,
|
||||
Edge,
|
||||
EdgeChange,
|
||||
EdgeRemoveChange,
|
||||
Node,
|
||||
NodeChange,
|
||||
OnConnectStartParams,
|
||||
Viewport,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
|
||||
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import {
|
||||
@ -70,7 +60,6 @@ import {
|
||||
import type { z } from 'zod';
|
||||
|
||||
import type { NodesState, PendingConnection, Templates } from './types';
|
||||
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
|
||||
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
|
||||
@ -85,13 +74,6 @@ const initialNodesState: NodesState = {
|
||||
_version: 1,
|
||||
nodes: [],
|
||||
edges: [],
|
||||
templates: {},
|
||||
connectionStartParams: null,
|
||||
connectionStartFieldType: null,
|
||||
connectionMade: false,
|
||||
modifyingEdge: false,
|
||||
addNewNodePosition: null,
|
||||
isAddNodePopoverOpen: false,
|
||||
selectedNodes: [],
|
||||
selectedEdges: [],
|
||||
nodeExecutionStates: {},
|
||||
@ -137,12 +119,12 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
state.nodes[nodeIndex] = action.payload.node;
|
||||
},
|
||||
nodeAdded: (state, action: PayloadAction<AnyNode>) => {
|
||||
const node = action.payload;
|
||||
nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => {
|
||||
const { node, cursorPos } = action.payload;
|
||||
const position = findUnoccupiedPosition(
|
||||
state.nodes,
|
||||
state.addNewNodePosition?.x ?? node.position.x,
|
||||
state.addNewNodePosition?.y ?? node.position.y
|
||||
cursorPos?.x ?? node.position.x,
|
||||
cursorPos?.y ?? node.position.y
|
||||
);
|
||||
node.position = position;
|
||||
node.selected = true;
|
||||
@ -167,31 +149,6 @@ export const nodesSlice = createSlice({
|
||||
nodeId: node.id,
|
||||
...initialNodeExecutionState,
|
||||
};
|
||||
|
||||
if (state.connectionStartParams) {
|
||||
const { nodeId, handleId, handleType } = state.connectionStartParams;
|
||||
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
node,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
state.templates,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
state.connectionStartParams = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
edgeChangeStarted: (state) => {
|
||||
state.modifyingEdge = true;
|
||||
},
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
@ -199,66 +156,9 @@ export const nodesSlice = createSlice({
|
||||
edgeAdded: (state, action: PayloadAction<Edge>) => {
|
||||
state.edges = addEdge(action.payload, state.edges);
|
||||
},
|
||||
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
||||
state.connectionStartParams = action.payload;
|
||||
state.connectionMade = state.modifyingEdge;
|
||||
const { nodeId, handleId, handleType } = action.payload;
|
||||
if (!nodeId || !handleId) {
|
||||
return;
|
||||
}
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const template = state.templates[node.data.type];
|
||||
const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId];
|
||||
state.connectionStartFieldType = field?.type ?? null;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
|
||||
},
|
||||
connectionEnded: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
cursorPosition: XYPosition;
|
||||
mouseOverNodeId: string | null;
|
||||
}>
|
||||
) => {
|
||||
const { cursorPosition, mouseOverNodeId } = action.payload;
|
||||
if (!state.connectionMade) {
|
||||
if (mouseOverNodeId) {
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === mouseOverNodeId);
|
||||
const mouseOverNode = state.nodes?.[nodeIndex];
|
||||
if (mouseOverNode && state.connectionStartParams) {
|
||||
const { nodeId, handleId, handleType } = state.connectionStartParams;
|
||||
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
mouseOverNode,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
state.templates,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
|
||||
}
|
||||
}
|
||||
}
|
||||
state.connectionStartParams = null;
|
||||
state.connectionStartFieldType = null;
|
||||
} else {
|
||||
state.addNewNodePosition = cursorPosition;
|
||||
state.isAddNodePopoverOpen = true;
|
||||
}
|
||||
} else {
|
||||
state.connectionStartParams = null;
|
||||
state.connectionStartFieldType = null;
|
||||
}
|
||||
state.modifyingEdge = false;
|
||||
},
|
||||
fieldLabelChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@ -580,17 +480,6 @@ export const nodesSlice = createSlice({
|
||||
};
|
||||
});
|
||||
},
|
||||
addNodePopoverOpened: (state) => {
|
||||
state.addNewNodePosition = null; //Create the node in viewport center by default
|
||||
state.isAddNodePopoverOpen = true;
|
||||
},
|
||||
addNodePopoverClosed: (state) => {
|
||||
state.isAddNodePopoverOpen = false;
|
||||
|
||||
//Make sure these get reset if we close the popover and haven't selected a node
|
||||
state.connectionStartParams = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
undo: (state) => state,
|
||||
redo: (state) => state,
|
||||
},
|
||||
@ -670,13 +559,8 @@ export const nodesSlice = createSlice({
|
||||
});
|
||||
|
||||
export const {
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverOpened,
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgeDeleted,
|
||||
edgeChangeStarted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
fieldValueReset,
|
||||
@ -720,7 +604,6 @@ export const {
|
||||
|
||||
// This is used for tracking `state.workflow.isTouched`
|
||||
export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
@ -783,15 +666,7 @@ export const nodesPersistConfig: PersistConfig<NodesState> = {
|
||||
name: nodesSlice.name,
|
||||
initialState: initialNodesState,
|
||||
migrate: migrateNodesState,
|
||||
persistDenylist: [
|
||||
'connectionStartParams',
|
||||
'connectionStartFieldType',
|
||||
'selectedNodes',
|
||||
'selectedEdges',
|
||||
'connectionMade',
|
||||
'modifyingEdge',
|
||||
'addNewNodePosition',
|
||||
],
|
||||
persistDenylist: ['selectedNodes', 'selectedEdges'],
|
||||
};
|
||||
|
||||
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
|
||||
|
@ -2,7 +2,6 @@ import type {
|
||||
FieldIdentifier,
|
||||
FieldInputTemplate,
|
||||
FieldOutputTemplate,
|
||||
FieldType,
|
||||
StatefulFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type {
|
||||
@ -13,7 +12,7 @@ import type {
|
||||
NodeExecutionState,
|
||||
} from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { OnConnectStartParams, Viewport, XYPosition } from 'reactflow';
|
||||
import type { Viewport } from 'reactflow';
|
||||
|
||||
export type Templates = Record<string, InvocationTemplate>;
|
||||
|
||||
@ -27,16 +26,10 @@ export type NodesState = {
|
||||
_version: 1;
|
||||
nodes: AnyNode[];
|
||||
edges: InvocationNodeEdge[];
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
connectionStartFieldType: FieldType | null;
|
||||
connectionMade: boolean;
|
||||
modifyingEdge: boolean;
|
||||
selectedNodes: string[];
|
||||
selectedEdges: string[];
|
||||
nodeExecutionStates: Record<string, NodeExecutionState>;
|
||||
viewport: Viewport;
|
||||
isAddNodePopoverOpen: boolean;
|
||||
addNewNodePosition: XYPosition | null;
|
||||
};
|
||||
|
||||
export type WorkflowMode = 'edit' | 'view';
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { PendingConnection } from 'features/nodes/store/types';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import i18n from 'i18next';
|
||||
import type { HandleType } from 'reactflow';
|
||||
@ -13,27 +14,27 @@ import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
*/
|
||||
|
||||
export const makeConnectionErrorSelector = (
|
||||
pendingConnection: PendingConnection | null,
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
handleType: HandleType,
|
||||
fieldType?: FieldType | null
|
||||
) => {
|
||||
return createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
const { nodes, edges } = nodesSlice;
|
||||
|
||||
if (!fieldType) {
|
||||
return i18n.t('nodes.noFieldType');
|
||||
}
|
||||
|
||||
const { connectionStartFieldType, connectionStartParams, nodes, edges } = nodesSlice;
|
||||
|
||||
if (!connectionStartParams || !connectionStartFieldType) {
|
||||
if (!pendingConnection) {
|
||||
return i18n.t('nodes.noConnectionInProgress');
|
||||
}
|
||||
|
||||
const {
|
||||
handleType: connectionHandleType,
|
||||
nodeId: connectionNodeId,
|
||||
handleId: connectionFieldName,
|
||||
} = connectionStartParams;
|
||||
const connectionNodeId = pendingConnection.node.id;
|
||||
const connectionFieldName = pendingConnection.fieldTemplate.name;
|
||||
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||
const connectionStartFieldType = pendingConnection.fieldTemplate.type;
|
||||
|
||||
if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
|
||||
return i18n.t('nodes.noConnectionData');
|
||||
|
Loading…
Reference in New Issue
Block a user