mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): make nodesSlice undoable
This commit is contained in:
parent
31d8b50276
commit
27826369f0
@ -21,7 +21,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
|
||||
|
||||
const { canvas, nodes, controlAdapters, controlLayers } = getState();
|
||||
deleted_images.forEach((image_name) => {
|
||||
const imageUsage = getImageUsage(canvas, nodes, controlAdapters, controlLayers.present, image_name);
|
||||
const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
|
||||
|
||||
if (imageUsage.isCanvasImage && !wasCanvasReset) {
|
||||
dispatch(resetCanvas());
|
||||
|
@ -11,9 +11,9 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { nodes, edges } = state.nodes;
|
||||
const { nodes, edges } = state.nodes.present;
|
||||
const workflow = state.workflow;
|
||||
const graph = buildNodesGraph(state.nodes);
|
||||
const graph = buildNodesGraph(state.nodes.present);
|
||||
const builtWorkflow = buildWorkflowWithValidation({
|
||||
nodes,
|
||||
edges,
|
||||
|
@ -29,7 +29,7 @@ import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
state.nodes.nodes.forEach((node) => {
|
||||
state.nodes.present.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
|
||||
actionCreator: updateAllNodesRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const { nodes, templates } = getState().nodes;
|
||||
const { nodes, templates } = getState().nodes.present;
|
||||
|
||||
let unableToUpdateCount = 0;
|
||||
|
||||
|
@ -17,7 +17,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const { workflow, asCopy } = action.payload;
|
||||
const nodeTemplates = getState().nodes.templates;
|
||||
const nodeTemplates = getState().nodes.present.templates;
|
||||
|
||||
try {
|
||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
||||
|
@ -21,7 +21,7 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
|
||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||
@ -50,7 +50,7 @@ const allReducers = {
|
||||
[canvasSlice.name]: canvasSlice.reducer,
|
||||
[gallerySlice.name]: gallerySlice.reducer,
|
||||
[generationSlice.name]: generationSlice.reducer,
|
||||
[nodesSlice.name]: nodesSlice.reducer,
|
||||
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
|
||||
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
||||
[systemSlice.name]: systemSlice.reducer,
|
||||
[configSlice.name]: configSlice.reducer,
|
||||
|
@ -55,8 +55,8 @@ const AddNodePopover = () => {
|
||||
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType);
|
||||
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType);
|
||||
const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType);
|
||||
const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType);
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
@ -105,7 +105,7 @@ const AddNodePopover = () => {
|
||||
});
|
||||
|
||||
const { options } = useAppSelector(selector);
|
||||
const isOpen = useAppSelector((s) => s.nodes.isAddNodePopoverOpen);
|
||||
const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string) => {
|
||||
|
@ -14,11 +14,13 @@ import {
|
||||
edgesDeleted,
|
||||
nodesChanged,
|
||||
nodesDeleted,
|
||||
redo,
|
||||
selectedAll,
|
||||
selectedEdgesChanged,
|
||||
selectedNodesChanged,
|
||||
selectionCopied,
|
||||
selectionPasted,
|
||||
undo,
|
||||
viewportChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
@ -70,11 +72,11 @@ const snapGrid: [number, number] = [25, 25];
|
||||
|
||||
export const Flow = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const nodes = useAppSelector((s) => s.nodes.nodes);
|
||||
const edges = useAppSelector((s) => s.nodes.edges);
|
||||
const viewport = useAppSelector((s) => s.nodes.viewport);
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.nodes.shouldSnapToGrid);
|
||||
const selectionMode = useAppSelector((s) => s.nodes.selectionMode);
|
||||
const nodes = useAppSelector((s) => s.nodes.present.nodes);
|
||||
const edges = useAppSelector((s) => s.nodes.present.edges);
|
||||
const viewport = useAppSelector((s) => s.nodes.present.viewport);
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.nodes.present.shouldSnapToGrid);
|
||||
const selectionMode = useAppSelector((s) => s.nodes.present.selectionMode);
|
||||
const flowWrapper = useRef<HTMLDivElement>(null);
|
||||
const cursorPosition = useRef<XYPosition | null>(null);
|
||||
const isValidConnection = useIsValidConnection();
|
||||
@ -251,6 +253,22 @@ export const Flow = memo(() => {
|
||||
dispatch(selectionPasted({ cursorPosition: cursorPosition.current }));
|
||||
});
|
||||
|
||||
useHotkeys(
|
||||
['meta+z', 'ctrl+z'],
|
||||
() => {
|
||||
dispatch(undo());
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
['meta+shift+z', 'ctrl+shift+z'],
|
||||
() => {
|
||||
dispatch(redo());
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
id="workflow-editor"
|
||||
|
@ -27,7 +27,7 @@ const InvocationDefaultEdge = ({
|
||||
);
|
||||
|
||||
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
|
||||
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
|
||||
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.present.shouldShowEdgeLabels);
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
sourceX,
|
||||
|
@ -39,7 +39,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const opacity = useAppSelector((s) => s.nodes.nodeOpacity);
|
||||
const opacity = useAppSelector((s) => s.nodes.present.nodeOpacity);
|
||||
const { onCloseGlobal } = useGlobalMenuClose();
|
||||
|
||||
const handleClick = useCallback(
|
||||
|
@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
const NodeOpacitySlider = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const nodeOpacity = useAppSelector((s) => s.nodes.nodeOpacity);
|
||||
const nodeOpacity = useAppSelector((s) => s.nodes.present.nodeOpacity);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
|
@ -19,9 +19,9 @@ const ViewportControls = () => {
|
||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||
const dispatch = useAppDispatch();
|
||||
// const shouldShowFieldTypeLegend = useAppSelector(
|
||||
// (s) => s.nodes.shouldShowFieldTypeLegend
|
||||
// (s) => s.nodes.present.shouldShowFieldTypeLegend
|
||||
// );
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.present.shouldShowMinimapPanel);
|
||||
|
||||
const handleClickedZoomIn = useCallback(() => {
|
||||
zoomIn();
|
||||
|
@ -16,7 +16,7 @@ const minimapStyles: SystemStyleObject = {
|
||||
};
|
||||
|
||||
const MinimapPanel = () => {
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.present.shouldShowMinimapPanel);
|
||||
|
||||
return (
|
||||
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
|
||||
|
@ -8,7 +8,7 @@ import { useCallback } from 'react';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
|
||||
export const useBuildNode = () => {
|
||||
const nodeTemplates = useAppSelector((s) => s.nodes.templates);
|
||||
const nodeTemplates = useAppSelector((s) => s.nodes.present.templates);
|
||||
|
||||
const flow = useReactFlow();
|
||||
|
||||
|
@ -13,7 +13,7 @@ import type { Connection, Node } from 'reactflow';
|
||||
|
||||
export const useIsValidConnection = () => {
|
||||
const store = useAppStore();
|
||||
const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph);
|
||||
const shouldValidateGraph = useAppSelector((s) => s.nodes.present.shouldValidateGraph);
|
||||
const isValidConnection = useCallback(
|
||||
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
|
||||
// Connection must have valid targets
|
||||
@ -27,7 +27,7 @@ export const useIsValidConnection = () => {
|
||||
}
|
||||
|
||||
const state = store.getState();
|
||||
const { nodes, edges, templates } = state.nodes;
|
||||
const { nodes, edges, templates } = state.nodes.present;
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
|
||||
|
@ -1,4 +1,4 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
@ -66,6 +66,7 @@ import {
|
||||
getOutgoers,
|
||||
SelectionMode,
|
||||
} from 'reactflow';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import {
|
||||
socketGeneratorProgress,
|
||||
socketInvocationComplete,
|
||||
@ -705,6 +706,8 @@ export const nodesSlice = createSlice({
|
||||
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
|
||||
state.templates = action.payload;
|
||||
},
|
||||
undo: (state) => state,
|
||||
redo: (state) => state,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(workflowLoaded, (state, action) => {
|
||||
@ -836,6 +839,8 @@ export const {
|
||||
edgeAdded,
|
||||
nodeTemplatesBuilt,
|
||||
shouldShowEdgeLabelsChanged,
|
||||
undo,
|
||||
redo,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
// This is used for tracking `state.workflow.isTouched`
|
||||
@ -874,7 +879,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
edgeAdded
|
||||
);
|
||||
|
||||
export const selectNodesSlice = (state: RootState) => state.nodes;
|
||||
export const selectNodesSlice = (state: RootState) => state.nodes.present;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateNodesState = (state: any): any => {
|
||||
@ -900,3 +905,15 @@ export const nodesPersistConfig: PersistConfig<NodesState> = {
|
||||
'addNewNodePosition',
|
||||
],
|
||||
};
|
||||
|
||||
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
|
||||
limit: 64,
|
||||
undoType: nodesSlice.actions.undo.type,
|
||||
redoType: nodesSlice.actions.redo.type,
|
||||
groupBy: (action, state, history) => {
|
||||
return null;
|
||||
},
|
||||
filter: (action, _state, _history) => {
|
||||
return true;
|
||||
},
|
||||
};
|
||||
|
@ -18,7 +18,7 @@ import { v4 as uuidv4 } from 'uuid';
|
||||
* @returns The workflow.
|
||||
*/
|
||||
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => {
|
||||
const invocationTemplates = getStore().getState().nodes.templates;
|
||||
const invocationTemplates = getStore().getState().nodes.present.templates;
|
||||
|
||||
if (!invocationTemplates) {
|
||||
throw new Error(t('app.storeNotInitialized'));
|
||||
|
@ -33,7 +33,7 @@ const zWorkflowMetaVersion = z.object({
|
||||
* - Workflow schema version bumped to 2.0.0
|
||||
*/
|
||||
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
|
||||
const invocationTemplates = $store.get()?.getState().nodes.templates;
|
||||
const invocationTemplates = $store.get()?.getState().nodes.present.templates;
|
||||
|
||||
if (!invocationTemplates) {
|
||||
throw new Error(t('app.storeNotInitialized'));
|
||||
|
Loading…
Reference in New Issue
Block a user