feat(ui): make nodesSlice undoable

This commit is contained in:
psychedelicious 2024-05-15 16:37:13 +10:00
parent 31d8b50276
commit 27826369f0
18 changed files with 64 additions and 29 deletions

View File

@ -21,7 +21,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
const { canvas, nodes, controlAdapters, controlLayers } = getState(); const { canvas, nodes, controlAdapters, controlLayers } = getState();
deleted_images.forEach((image_name) => { deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(canvas, nodes, controlAdapters, controlLayers.present, image_name); const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
if (imageUsage.isCanvasImage && !wasCanvasReset) { if (imageUsage.isCanvasImage && !wasCanvasReset) {
dispatch(resetCanvas()); dispatch(resetCanvas());

View File

@ -11,9 +11,9 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
enqueueRequested.match(action) && action.payload.tabName === 'workflows', enqueueRequested.match(action) && action.payload.tabName === 'workflows',
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
const state = getState(); const state = getState();
const { nodes, edges } = state.nodes; const { nodes, edges } = state.nodes.present;
const workflow = state.workflow; const workflow = state.workflow;
const graph = buildNodesGraph(state.nodes); const graph = buildNodesGraph(state.nodes.present);
const builtWorkflow = buildWorkflowWithValidation({ const builtWorkflow = buildWorkflowWithValidation({
nodes, nodes,
edges, edges,

View File

@ -29,7 +29,7 @@ import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util'; import { imagesSelectors } from 'services/api/util';
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => { const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.nodes.nodes.forEach((node) => { state.nodes.present.nodes.forEach((node) => {
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return; return;
} }

View File

@ -14,7 +14,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
actionCreator: updateAllNodesRequested, actionCreator: updateAllNodesRequested,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const log = logger('nodes'); const log = logger('nodes');
const { nodes, templates } = getState().nodes; const { nodes, templates } = getState().nodes.present;
let unableToUpdateCount = 0; let unableToUpdateCount = 0;

View File

@ -17,7 +17,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const log = logger('nodes'); const log = logger('nodes');
const { workflow, asCopy } = action.payload; const { workflow, asCopy } = action.payload;
const nodeTemplates = getState().nodes.templates; const nodeTemplates = getState().nodes.present.templates;
try { try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates); const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);

View File

@ -21,7 +21,7 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice'; import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice'; import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice'; import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice'; import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice'; import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice'; import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
@ -50,7 +50,7 @@ const allReducers = {
[canvasSlice.name]: canvasSlice.reducer, [canvasSlice.name]: canvasSlice.reducer,
[gallerySlice.name]: gallerySlice.reducer, [gallerySlice.name]: gallerySlice.reducer,
[generationSlice.name]: generationSlice.reducer, [generationSlice.name]: generationSlice.reducer,
[nodesSlice.name]: nodesSlice.reducer, [nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[postprocessingSlice.name]: postprocessingSlice.reducer, [postprocessingSlice.name]: postprocessingSlice.reducer,
[systemSlice.name]: systemSlice.reducer, [systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer, [configSlice.name]: configSlice.reducer,

View File

@ -55,8 +55,8 @@ const AddNodePopover = () => {
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null); const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
const inputRef = useRef<HTMLInputElement>(null); const inputRef = useRef<HTMLInputElement>(null);
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType); const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType);
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType); const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType);
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
// If we have a connection in progress, we need to filter the node choices // If we have a connection in progress, we need to filter the node choices
@ -105,7 +105,7 @@ const AddNodePopover = () => {
}); });
const { options } = useAppSelector(selector); const { options } = useAppSelector(selector);
const isOpen = useAppSelector((s) => s.nodes.isAddNodePopoverOpen); const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen);
const addNode = useCallback( const addNode = useCallback(
(nodeType: string) => { (nodeType: string) => {

View File

@ -14,11 +14,13 @@ import {
edgesDeleted, edgesDeleted,
nodesChanged, nodesChanged,
nodesDeleted, nodesDeleted,
redo,
selectedAll, selectedAll,
selectedEdgesChanged, selectedEdgesChanged,
selectedNodesChanged, selectedNodesChanged,
selectionCopied, selectionCopied,
selectionPasted, selectionPasted,
undo,
viewportChanged, viewportChanged,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance'; import { $flow } from 'features/nodes/store/reactFlowInstance';
@ -70,11 +72,11 @@ const snapGrid: [number, number] = [25, 25];
export const Flow = memo(() => { export const Flow = memo(() => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const nodes = useAppSelector((s) => s.nodes.nodes); const nodes = useAppSelector((s) => s.nodes.present.nodes);
const edges = useAppSelector((s) => s.nodes.edges); const edges = useAppSelector((s) => s.nodes.present.edges);
const viewport = useAppSelector((s) => s.nodes.viewport); const viewport = useAppSelector((s) => s.nodes.present.viewport);
const shouldSnapToGrid = useAppSelector((s) => s.nodes.shouldSnapToGrid); const shouldSnapToGrid = useAppSelector((s) => s.nodes.present.shouldSnapToGrid);
const selectionMode = useAppSelector((s) => s.nodes.selectionMode); const selectionMode = useAppSelector((s) => s.nodes.present.selectionMode);
const flowWrapper = useRef<HTMLDivElement>(null); const flowWrapper = useRef<HTMLDivElement>(null);
const cursorPosition = useRef<XYPosition | null>(null); const cursorPosition = useRef<XYPosition | null>(null);
const isValidConnection = useIsValidConnection(); const isValidConnection = useIsValidConnection();
@ -251,6 +253,22 @@ export const Flow = memo(() => {
dispatch(selectionPasted({ cursorPosition: cursorPosition.current })); dispatch(selectionPasted({ cursorPosition: cursorPosition.current }));
}); });
useHotkeys(
['meta+z', 'ctrl+z'],
() => {
dispatch(undo());
},
[dispatch]
);
useHotkeys(
['meta+shift+z', 'ctrl+shift+z'],
() => {
dispatch(redo());
},
[dispatch]
);
return ( return (
<ReactFlow <ReactFlow
id="workflow-editor" id="workflow-editor"

View File

@ -27,7 +27,7 @@ const InvocationDefaultEdge = ({
); );
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector); const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels); const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.present.shouldShowEdgeLabels);
const [edgePath, labelX, labelY] = getBezierPath({ const [edgePath, labelX, labelY] = getBezierPath({
sourceX, sourceX,

View File

@ -39,7 +39,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const opacity = useAppSelector((s) => s.nodes.nodeOpacity); const opacity = useAppSelector((s) => s.nodes.present.nodeOpacity);
const { onCloseGlobal } = useGlobalMenuClose(); const { onCloseGlobal } = useGlobalMenuClose();
const handleClick = useCallback( const handleClick = useCallback(

View File

@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
const NodeOpacitySlider = () => { const NodeOpacitySlider = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const nodeOpacity = useAppSelector((s) => s.nodes.nodeOpacity); const nodeOpacity = useAppSelector((s) => s.nodes.present.nodeOpacity);
const { t } = useTranslation(); const { t } = useTranslation();
const handleChange = useCallback( const handleChange = useCallback(

View File

@ -19,9 +19,9 @@ const ViewportControls = () => {
const { zoomIn, zoomOut, fitView } = useReactFlow(); const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
// const shouldShowFieldTypeLegend = useAppSelector( // const shouldShowFieldTypeLegend = useAppSelector(
// (s) => s.nodes.shouldShowFieldTypeLegend // (s) => s.nodes.present.shouldShowFieldTypeLegend
// ); // );
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel); const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.present.shouldShowMinimapPanel);
const handleClickedZoomIn = useCallback(() => { const handleClickedZoomIn = useCallback(() => {
zoomIn(); zoomIn();

View File

@ -16,7 +16,7 @@ const minimapStyles: SystemStyleObject = {
}; };
const MinimapPanel = () => { const MinimapPanel = () => {
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel); const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.present.shouldShowMinimapPanel);
return ( return (
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}> <Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>

View File

@ -8,7 +8,7 @@ import { useCallback } from 'react';
import { useReactFlow } from 'reactflow'; import { useReactFlow } from 'reactflow';
export const useBuildNode = () => { export const useBuildNode = () => {
const nodeTemplates = useAppSelector((s) => s.nodes.templates); const nodeTemplates = useAppSelector((s) => s.nodes.present.templates);
const flow = useReactFlow(); const flow = useReactFlow();

View File

@ -13,7 +13,7 @@ import type { Connection, Node } from 'reactflow';
export const useIsValidConnection = () => { export const useIsValidConnection = () => {
const store = useAppStore(); const store = useAppStore();
const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph); const shouldValidateGraph = useAppSelector((s) => s.nodes.present.shouldValidateGraph);
const isValidConnection = useCallback( const isValidConnection = useCallback(
({ source, sourceHandle, target, targetHandle }: Connection): boolean => { ({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
// Connection must have valid targets // Connection must have valid targets
@ -27,7 +27,7 @@ export const useIsValidConnection = () => {
} }
const state = store.getState(); const state = store.getState();
const { nodes, edges, templates } = state.nodes; const { nodes, edges, templates } = state.nodes.present;
// Find the source and target nodes // Find the source and target nodes
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>; const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;

View File

@ -1,4 +1,4 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
@ -66,6 +66,7 @@ import {
getOutgoers, getOutgoers,
SelectionMode, SelectionMode,
} from 'reactflow'; } from 'reactflow';
import type { UndoableOptions } from 'redux-undo';
import { import {
socketGeneratorProgress, socketGeneratorProgress,
socketInvocationComplete, socketInvocationComplete,
@ -705,6 +706,8 @@ export const nodesSlice = createSlice({
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => { nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
state.templates = action.payload; state.templates = action.payload;
}, },
undo: (state) => state,
redo: (state) => state,
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(workflowLoaded, (state, action) => { builder.addCase(workflowLoaded, (state, action) => {
@ -836,6 +839,8 @@ export const {
edgeAdded, edgeAdded,
nodeTemplatesBuilt, nodeTemplatesBuilt,
shouldShowEdgeLabelsChanged, shouldShowEdgeLabelsChanged,
undo,
redo,
} = nodesSlice.actions; } = nodesSlice.actions;
// This is used for tracking `state.workflow.isTouched` // This is used for tracking `state.workflow.isTouched`
@ -874,7 +879,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
edgeAdded edgeAdded
); );
export const selectNodesSlice = (state: RootState) => state.nodes; export const selectNodesSlice = (state: RootState) => state.nodes.present;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ /* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateNodesState = (state: any): any => { const migrateNodesState = (state: any): any => {
@ -900,3 +905,15 @@ export const nodesPersistConfig: PersistConfig<NodesState> = {
'addNewNodePosition', 'addNewNodePosition',
], ],
}; };
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
limit: 64,
undoType: nodesSlice.actions.undo.type,
redoType: nodesSlice.actions.redo.type,
groupBy: (action, state, history) => {
return null;
},
filter: (action, _state, _history) => {
return true;
},
};

View File

@ -18,7 +18,7 @@ import { v4 as uuidv4 } from 'uuid';
* @returns The workflow. * @returns The workflow.
*/ */
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => { export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => {
const invocationTemplates = getStore().getState().nodes.templates; const invocationTemplates = getStore().getState().nodes.present.templates;
if (!invocationTemplates) { if (!invocationTemplates) {
throw new Error(t('app.storeNotInitialized')); throw new Error(t('app.storeNotInitialized'));

View File

@ -33,7 +33,7 @@ const zWorkflowMetaVersion = z.object({
* - Workflow schema version bumped to 2.0.0 * - Workflow schema version bumped to 2.0.0
*/ */
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
const invocationTemplates = $store.get()?.getState().nodes.templates; const invocationTemplates = $store.get()?.getState().nodes.present.templates;
if (!invocationTemplates) { if (!invocationTemplates) {
throw new Error(t('app.storeNotInitialized')); throw new Error(t('app.storeNotInitialized'));