mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): make nodesSlice undoable
This commit is contained in:
parent
31d8b50276
commit
27826369f0
@ -21,7 +21,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
|
|||||||
|
|
||||||
const { canvas, nodes, controlAdapters, controlLayers } = getState();
|
const { canvas, nodes, controlAdapters, controlLayers } = getState();
|
||||||
deleted_images.forEach((image_name) => {
|
deleted_images.forEach((image_name) => {
|
||||||
const imageUsage = getImageUsage(canvas, nodes, controlAdapters, controlLayers.present, image_name);
|
const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
|
||||||
|
|
||||||
if (imageUsage.isCanvasImage && !wasCanvasReset) {
|
if (imageUsage.isCanvasImage && !wasCanvasReset) {
|
||||||
dispatch(resetCanvas());
|
dispatch(resetCanvas());
|
||||||
|
@ -11,9 +11,9 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
|||||||
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
|
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const { nodes, edges } = state.nodes;
|
const { nodes, edges } = state.nodes.present;
|
||||||
const workflow = state.workflow;
|
const workflow = state.workflow;
|
||||||
const graph = buildNodesGraph(state.nodes);
|
const graph = buildNodesGraph(state.nodes.present);
|
||||||
const builtWorkflow = buildWorkflowWithValidation({
|
const builtWorkflow = buildWorkflowWithValidation({
|
||||||
nodes,
|
nodes,
|
||||||
edges,
|
edges,
|
||||||
|
@ -29,7 +29,7 @@ import type { ImageDTO } from 'services/api/types';
|
|||||||
import { imagesSelectors } from 'services/api/util';
|
import { imagesSelectors } from 'services/api/util';
|
||||||
|
|
||||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||||
state.nodes.nodes.forEach((node) => {
|
state.nodes.present.nodes.forEach((node) => {
|
||||||
if (!isInvocationNode(node)) {
|
if (!isInvocationNode(node)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
|
|||||||
actionCreator: updateAllNodesRequested,
|
actionCreator: updateAllNodesRequested,
|
||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const { nodes, templates } = getState().nodes;
|
const { nodes, templates } = getState().nodes.present;
|
||||||
|
|
||||||
let unableToUpdateCount = 0;
|
let unableToUpdateCount = 0;
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
|
|||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const { workflow, asCopy } = action.payload;
|
const { workflow, asCopy } = action.payload;
|
||||||
const nodeTemplates = getState().nodes.templates;
|
const nodeTemplates = getState().nodes.present.templates;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
||||||
|
@ -21,7 +21,7 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
|
|||||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||||
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
|
||||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||||
@ -50,7 +50,7 @@ const allReducers = {
|
|||||||
[canvasSlice.name]: canvasSlice.reducer,
|
[canvasSlice.name]: canvasSlice.reducer,
|
||||||
[gallerySlice.name]: gallerySlice.reducer,
|
[gallerySlice.name]: gallerySlice.reducer,
|
||||||
[generationSlice.name]: generationSlice.reducer,
|
[generationSlice.name]: generationSlice.reducer,
|
||||||
[nodesSlice.name]: nodesSlice.reducer,
|
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
|
||||||
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
||||||
[systemSlice.name]: systemSlice.reducer,
|
[systemSlice.name]: systemSlice.reducer,
|
||||||
[configSlice.name]: configSlice.reducer,
|
[configSlice.name]: configSlice.reducer,
|
||||||
|
@ -55,8 +55,8 @@ const AddNodePopover = () => {
|
|||||||
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
|
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
|
||||||
const inputRef = useRef<HTMLInputElement>(null);
|
const inputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType);
|
const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType);
|
||||||
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType);
|
const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType);
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
// If we have a connection in progress, we need to filter the node choices
|
// If we have a connection in progress, we need to filter the node choices
|
||||||
@ -105,7 +105,7 @@ const AddNodePopover = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const { options } = useAppSelector(selector);
|
const { options } = useAppSelector(selector);
|
||||||
const isOpen = useAppSelector((s) => s.nodes.isAddNodePopoverOpen);
|
const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen);
|
||||||
|
|
||||||
const addNode = useCallback(
|
const addNode = useCallback(
|
||||||
(nodeType: string) => {
|
(nodeType: string) => {
|
||||||
|
@ -14,11 +14,13 @@ import {
|
|||||||
edgesDeleted,
|
edgesDeleted,
|
||||||
nodesChanged,
|
nodesChanged,
|
||||||
nodesDeleted,
|
nodesDeleted,
|
||||||
|
redo,
|
||||||
selectedAll,
|
selectedAll,
|
||||||
selectedEdgesChanged,
|
selectedEdgesChanged,
|
||||||
selectedNodesChanged,
|
selectedNodesChanged,
|
||||||
selectionCopied,
|
selectionCopied,
|
||||||
selectionPasted,
|
selectionPasted,
|
||||||
|
undo,
|
||||||
viewportChanged,
|
viewportChanged,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||||
@ -70,11 +72,11 @@ const snapGrid: [number, number] = [25, 25];
|
|||||||
|
|
||||||
export const Flow = memo(() => {
|
export const Flow = memo(() => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const nodes = useAppSelector((s) => s.nodes.nodes);
|
const nodes = useAppSelector((s) => s.nodes.present.nodes);
|
||||||
const edges = useAppSelector((s) => s.nodes.edges);
|
const edges = useAppSelector((s) => s.nodes.present.edges);
|
||||||
const viewport = useAppSelector((s) => s.nodes.viewport);
|
const viewport = useAppSelector((s) => s.nodes.present.viewport);
|
||||||
const shouldSnapToGrid = useAppSelector((s) => s.nodes.shouldSnapToGrid);
|
const shouldSnapToGrid = useAppSelector((s) => s.nodes.present.shouldSnapToGrid);
|
||||||
const selectionMode = useAppSelector((s) => s.nodes.selectionMode);
|
const selectionMode = useAppSelector((s) => s.nodes.present.selectionMode);
|
||||||
const flowWrapper = useRef<HTMLDivElement>(null);
|
const flowWrapper = useRef<HTMLDivElement>(null);
|
||||||
const cursorPosition = useRef<XYPosition | null>(null);
|
const cursorPosition = useRef<XYPosition | null>(null);
|
||||||
const isValidConnection = useIsValidConnection();
|
const isValidConnection = useIsValidConnection();
|
||||||
@ -251,6 +253,22 @@ export const Flow = memo(() => {
|
|||||||
dispatch(selectionPasted({ cursorPosition: cursorPosition.current }));
|
dispatch(selectionPasted({ cursorPosition: cursorPosition.current }));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
useHotkeys(
|
||||||
|
['meta+z', 'ctrl+z'],
|
||||||
|
() => {
|
||||||
|
dispatch(undo());
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
useHotkeys(
|
||||||
|
['meta+shift+z', 'ctrl+shift+z'],
|
||||||
|
() => {
|
||||||
|
dispatch(redo());
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ReactFlow
|
<ReactFlow
|
||||||
id="workflow-editor"
|
id="workflow-editor"
|
||||||
|
@ -27,7 +27,7 @@ const InvocationDefaultEdge = ({
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
|
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
|
||||||
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
|
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.present.shouldShowEdgeLabels);
|
||||||
|
|
||||||
const [edgePath, labelX, labelY] = getBezierPath({
|
const [edgePath, labelX, labelY] = getBezierPath({
|
||||||
sourceX,
|
sourceX,
|
||||||
|
@ -39,7 +39,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
|||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const opacity = useAppSelector((s) => s.nodes.nodeOpacity);
|
const opacity = useAppSelector((s) => s.nodes.present.nodeOpacity);
|
||||||
const { onCloseGlobal } = useGlobalMenuClose();
|
const { onCloseGlobal } = useGlobalMenuClose();
|
||||||
|
|
||||||
const handleClick = useCallback(
|
const handleClick = useCallback(
|
||||||
|
@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
|
|
||||||
const NodeOpacitySlider = () => {
|
const NodeOpacitySlider = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const nodeOpacity = useAppSelector((s) => s.nodes.nodeOpacity);
|
const nodeOpacity = useAppSelector((s) => s.nodes.present.nodeOpacity);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
|
@ -19,9 +19,9 @@ const ViewportControls = () => {
|
|||||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
// const shouldShowFieldTypeLegend = useAppSelector(
|
// const shouldShowFieldTypeLegend = useAppSelector(
|
||||||
// (s) => s.nodes.shouldShowFieldTypeLegend
|
// (s) => s.nodes.present.shouldShowFieldTypeLegend
|
||||||
// );
|
// );
|
||||||
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
|
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.present.shouldShowMinimapPanel);
|
||||||
|
|
||||||
const handleClickedZoomIn = useCallback(() => {
|
const handleClickedZoomIn = useCallback(() => {
|
||||||
zoomIn();
|
zoomIn();
|
||||||
|
@ -16,7 +16,7 @@ const minimapStyles: SystemStyleObject = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const MinimapPanel = () => {
|
const MinimapPanel = () => {
|
||||||
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
|
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.present.shouldShowMinimapPanel);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
|
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
|
||||||
|
@ -8,7 +8,7 @@ import { useCallback } from 'react';
|
|||||||
import { useReactFlow } from 'reactflow';
|
import { useReactFlow } from 'reactflow';
|
||||||
|
|
||||||
export const useBuildNode = () => {
|
export const useBuildNode = () => {
|
||||||
const nodeTemplates = useAppSelector((s) => s.nodes.templates);
|
const nodeTemplates = useAppSelector((s) => s.nodes.present.templates);
|
||||||
|
|
||||||
const flow = useReactFlow();
|
const flow = useReactFlow();
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ import type { Connection, Node } from 'reactflow';
|
|||||||
|
|
||||||
export const useIsValidConnection = () => {
|
export const useIsValidConnection = () => {
|
||||||
const store = useAppStore();
|
const store = useAppStore();
|
||||||
const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph);
|
const shouldValidateGraph = useAppSelector((s) => s.nodes.present.shouldValidateGraph);
|
||||||
const isValidConnection = useCallback(
|
const isValidConnection = useCallback(
|
||||||
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
|
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
|
||||||
// Connection must have valid targets
|
// Connection must have valid targets
|
||||||
@ -27,7 +27,7 @@ export const useIsValidConnection = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const state = store.getState();
|
const state = store.getState();
|
||||||
const { nodes, edges, templates } = state.nodes;
|
const { nodes, edges, templates } = state.nodes.present;
|
||||||
|
|
||||||
// Find the source and target nodes
|
// Find the source and target nodes
|
||||||
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
|
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { deepClone } from 'common/util/deepClone';
|
import { deepClone } from 'common/util/deepClone';
|
||||||
@ -66,6 +66,7 @@ import {
|
|||||||
getOutgoers,
|
getOutgoers,
|
||||||
SelectionMode,
|
SelectionMode,
|
||||||
} from 'reactflow';
|
} from 'reactflow';
|
||||||
|
import type { UndoableOptions } from 'redux-undo';
|
||||||
import {
|
import {
|
||||||
socketGeneratorProgress,
|
socketGeneratorProgress,
|
||||||
socketInvocationComplete,
|
socketInvocationComplete,
|
||||||
@ -705,6 +706,8 @@ export const nodesSlice = createSlice({
|
|||||||
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
|
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
|
||||||
state.templates = action.payload;
|
state.templates = action.payload;
|
||||||
},
|
},
|
||||||
|
undo: (state) => state,
|
||||||
|
redo: (state) => state,
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(workflowLoaded, (state, action) => {
|
builder.addCase(workflowLoaded, (state, action) => {
|
||||||
@ -836,6 +839,8 @@ export const {
|
|||||||
edgeAdded,
|
edgeAdded,
|
||||||
nodeTemplatesBuilt,
|
nodeTemplatesBuilt,
|
||||||
shouldShowEdgeLabelsChanged,
|
shouldShowEdgeLabelsChanged,
|
||||||
|
undo,
|
||||||
|
redo,
|
||||||
} = nodesSlice.actions;
|
} = nodesSlice.actions;
|
||||||
|
|
||||||
// This is used for tracking `state.workflow.isTouched`
|
// This is used for tracking `state.workflow.isTouched`
|
||||||
@ -874,7 +879,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
|||||||
edgeAdded
|
edgeAdded
|
||||||
);
|
);
|
||||||
|
|
||||||
export const selectNodesSlice = (state: RootState) => state.nodes;
|
export const selectNodesSlice = (state: RootState) => state.nodes.present;
|
||||||
|
|
||||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||||
const migrateNodesState = (state: any): any => {
|
const migrateNodesState = (state: any): any => {
|
||||||
@ -900,3 +905,15 @@ export const nodesPersistConfig: PersistConfig<NodesState> = {
|
|||||||
'addNewNodePosition',
|
'addNewNodePosition',
|
||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
|
||||||
|
limit: 64,
|
||||||
|
undoType: nodesSlice.actions.undo.type,
|
||||||
|
redoType: nodesSlice.actions.redo.type,
|
||||||
|
groupBy: (action, state, history) => {
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
filter: (action, _state, _history) => {
|
||||||
|
return true;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
@ -18,7 +18,7 @@ import { v4 as uuidv4 } from 'uuid';
|
|||||||
* @returns The workflow.
|
* @returns The workflow.
|
||||||
*/
|
*/
|
||||||
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => {
|
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => {
|
||||||
const invocationTemplates = getStore().getState().nodes.templates;
|
const invocationTemplates = getStore().getState().nodes.present.templates;
|
||||||
|
|
||||||
if (!invocationTemplates) {
|
if (!invocationTemplates) {
|
||||||
throw new Error(t('app.storeNotInitialized'));
|
throw new Error(t('app.storeNotInitialized'));
|
||||||
|
@ -33,7 +33,7 @@ const zWorkflowMetaVersion = z.object({
|
|||||||
* - Workflow schema version bumped to 2.0.0
|
* - Workflow schema version bumped to 2.0.0
|
||||||
*/
|
*/
|
||||||
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
|
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
|
||||||
const invocationTemplates = $store.get()?.getState().nodes.templates;
|
const invocationTemplates = $store.get()?.getState().nodes.present.templates;
|
||||||
|
|
||||||
if (!invocationTemplates) {
|
if (!invocationTemplates) {
|
||||||
throw new Error(t('app.storeNotInitialized'));
|
throw new Error(t('app.storeNotInitialized'));
|
||||||
|
Loading…
Reference in New Issue
Block a user