feat(ui): move nodes copy/paste out of slice

This commit is contained in:
psychedelicious 2024-05-16 14:20:22 +10:00
parent 9c0d44b412
commit d4df312300
5 changed files with 124 additions and 124 deletions

View File

@ -1,9 +1,11 @@
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'; import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import { import {
$cursorPos,
connectionEnded, connectionEnded,
connectionMade, connectionMade,
connectionStarted, connectionStarted,
@ -18,8 +20,6 @@ import {
selectedAll, selectedAll,
selectedEdgesChanged, selectedEdgesChanged,
selectedNodesChanged, selectedNodesChanged,
selectionCopied,
selectionPasted,
undo, undo,
viewportChanged, viewportChanged,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
@ -41,7 +41,6 @@ import type {
OnSelectionChangeFunc, OnSelectionChangeFunc,
ProOptions, ProOptions,
ReactFlowProps, ReactFlowProps,
XYPosition,
} from 'reactflow'; } from 'reactflow';
import { Background, ReactFlow } from 'reactflow'; import { Background, ReactFlow } from 'reactflow';
@ -78,7 +77,6 @@ export const Flow = memo(() => {
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid); const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode); const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode);
const flowWrapper = useRef<HTMLDivElement>(null); const flowWrapper = useRef<HTMLDivElement>(null);
const cursorPosition = useRef<XYPosition | null>(null);
const isValidConnection = useIsValidConnection(); const isValidConnection = useIsValidConnection();
useWorkflowWatcher(); useWorkflowWatcher();
const [borderRadius] = useToken('radii', ['base']); const [borderRadius] = useToken('radii', ['base']);
@ -119,12 +117,13 @@ export const Flow = memo(() => {
); );
const onConnectEnd: OnConnectEnd = useCallback(() => { const onConnectEnd: OnConnectEnd = useCallback(() => {
if (!cursorPosition.current) { const cursorPosition = $cursorPos.get();
if (!cursorPosition) {
return; return;
} }
dispatch( dispatch(
connectionEnded({ connectionEnded({
cursorPosition: cursorPosition.current, cursorPosition,
mouseOverNodeId: $mouseOverNode.get(), mouseOverNodeId: $mouseOverNode.get(),
}) })
); );
@ -171,11 +170,12 @@ export const Flow = memo(() => {
const onMouseMove = useCallback((event: MouseEvent<HTMLDivElement>) => { const onMouseMove = useCallback((event: MouseEvent<HTMLDivElement>) => {
if (flowWrapper.current?.getBoundingClientRect()) { if (flowWrapper.current?.getBoundingClientRect()) {
cursorPosition.current = $cursorPos.set(
$flow.get()?.screenToFlowPosition({ $flow.get()?.screenToFlowPosition({
x: event.clientX, x: event.clientX,
y: event.clientY, y: event.clientY,
}) ?? null; }) ?? null
);
} }
}, []); }, []);
@ -235,9 +235,11 @@ export const Flow = memo(() => {
// #endregion // #endregion
const { copySelection, pasteSelection } = useCopyPaste();
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => { useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
e.preventDefault(); e.preventDefault();
dispatch(selectionCopied()); copySelection();
}); });
useHotkeys(['Ctrl+a', 'Meta+a'], (e) => { useHotkeys(['Ctrl+a', 'Meta+a'], (e) => {
@ -246,11 +248,8 @@ export const Flow = memo(() => {
}); });
useHotkeys(['Ctrl+v', 'Meta+v'], (e) => { useHotkeys(['Ctrl+v', 'Meta+v'], (e) => {
if (!cursorPosition.current) {
return;
}
e.preventDefault(); e.preventDefault();
dispatch(selectionPasted({ cursorPosition: cursorPosition.current })); pasteSelection();
}); });
useHotkeys( useHotkeys(

View File

@ -0,0 +1,63 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { $copiedEdges,$copiedNodes,$cursorPos, selectionPasted, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
import { v4 as uuidv4 } from 'uuid';
const copySelection = () => {
// Use the imperative API here so we don't have to pass the whole slice around
const { getState } = getStore();
const { nodes, edges } = selectNodesSlice(getState());
const selectedNodes = nodes.filter((node) => node.selected);
const selectedEdges = edges.filter((edge) => edge.selected);
$copiedNodes.set(selectedNodes);
$copiedEdges.set(selectedEdges);
};
const pasteSelection = () => {
const { getState, dispatch } = getStore();
const currentNodes = selectNodesSlice(getState()).nodes;
const cursorPos = $cursorPos.get();
const copiedNodes = deepClone($copiedNodes.get());
const copiedEdges = deepClone($copiedEdges.get());
// Calculate an offset to reposition nodes to surround the cursor position, maintaining relative positioning
const xCoords = copiedNodes.map((node) => node.position.x);
const yCoords = copiedNodes.map((node) => node.position.y);
const minX = Math.min(...xCoords);
const minY = Math.min(...yCoords);
const offsetX = cursorPos ? cursorPos.x - minX : 50;
const offsetY = cursorPos ? cursorPos.y - minY : 50;
copiedNodes.forEach((node) => {
const { x, y } = findUnoccupiedPosition(currentNodes, node.position.x + offsetX, node.position.y + offsetY);
node.position.x = x;
node.position.y = y;
// Pasted nodes are selected
node.selected = true;
// Also give em a fresh id
const id = uuidv4();
// Update the edges to point to the new node id
for (const edge of copiedEdges) {
if (edge.source === node.id) {
edge.source = id;
edge.id = edge.id.replace(node.data.id, id);
}
if (edge.target === node.id) {
edge.target = id;
edge.id = edge.id.replace(node.data.id, id);
}
}
node.id = id;
node.data.id = id;
});
dispatch(selectionPasted({ nodes: copiedNodes, edges: copiedEdges }));
};
const api = { copySelection, pasteSelection };
export const useCopyPaste = () => {
return api;
};

View File

@ -43,9 +43,15 @@ import {
zT2IAdapterModelFieldValue, zT2IAdapterModelFieldValue,
zVAEModelFieldValue, zVAEModelFieldValue,
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation'; import type {
AnyNode,
InvocationNodeEdge,
InvocationTemplate,
NodeExecutionState,
} from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { atom } from 'nanostores';
import type { import type {
Connection, Connection,
Edge, Edge,
@ -66,7 +72,6 @@ import {
socketInvocationStarted, socketInvocationStarted,
socketQueueItemStatusChanged, socketQueueItemStatusChanged,
} from 'services/events/actions'; } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import type { z } from 'zod'; import type { z } from 'zod';
import type { NodesState } from './types'; import type { NodesState } from './types';
@ -96,8 +101,6 @@ const initialNodesState: NodesState = {
selectedEdges: [], selectedEdges: [],
nodeExecutionStates: {}, nodeExecutionStates: {},
viewport: { x: 0, y: 0, zoom: 1 }, viewport: { x: 0, y: 0, zoom: 1 },
nodesToCopy: [],
edgesToCopy: [],
}; };
type FieldValueAction<T extends FieldValue> = PayloadAction<{ type FieldValueAction<T extends FieldValue> = PayloadAction<{
@ -539,116 +542,52 @@ export const nodesSlice = createSlice({
state.edges state.edges
); );
}, },
selectionCopied: (state) => { selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => {
const nodesToCopy: AnyNode[] = []; const { nodes, edges } = action.payload;
const edgesToCopy: Edge[] = [];
for (const node of state.nodes) { const nodeChanges: NodeChange[] = [];
if (node.selected) {
nodesToCopy.push(deepClone(node));
}
}
for (const edge of state.edges) { // Deselect existing nodes
if (edge.selected) { state.nodes.forEach((n) => {
edgesToCopy.push(deepClone(edge)); nodeChanges.push({
}
}
state.nodesToCopy = nodesToCopy;
state.edgesToCopy = edgesToCopy;
if (state.nodesToCopy.length > 0) {
const averagePosition = { x: 0, y: 0 };
state.nodesToCopy.forEach((e) => {
const xOffset = 0.15 * (e.width ?? 0);
const yOffset = 0.5 * (e.height ?? 0);
averagePosition.x += e.position.x + xOffset;
averagePosition.y += e.position.y + yOffset;
});
averagePosition.x /= state.nodesToCopy.length;
averagePosition.y /= state.nodesToCopy.length;
state.nodesToCopy.forEach((e) => {
e.position.x -= averagePosition.x;
e.position.y -= averagePosition.y;
});
}
},
selectionPasted: (state, action: PayloadAction<{ cursorPosition?: XYPosition }>) => {
const { cursorPosition } = action.payload;
const newNodes: AnyNode[] = [];
for (const node of state.nodesToCopy) {
newNodes.push(deepClone(node));
}
const oldNodeIds = newNodes.map((n) => n.data.id);
const newEdges: Edge[] = [];
for (const edge of state.edgesToCopy) {
if (oldNodeIds.includes(edge.source) && oldNodeIds.includes(edge.target)) {
newEdges.push(deepClone(edge));
}
}
newEdges.forEach((e) => (e.selected = true));
newNodes.forEach((node) => {
const newNodeId = uuidv4();
newEdges.forEach((edge) => {
if (edge.source === node.data.id) {
edge.source = newNodeId;
edge.id = edge.id.replace(node.data.id, newNodeId);
}
if (edge.target === node.data.id) {
edge.target = newNodeId;
edge.id = edge.id.replace(node.data.id, newNodeId);
}
});
node.selected = true;
node.id = newNodeId;
node.data.id = newNodeId;
const position = findUnoccupiedPosition(
state.nodes,
node.position.x + (cursorPosition?.x ?? 0),
node.position.y + (cursorPosition?.y ?? 0)
);
node.position = position;
});
const nodeAdditions: NodeChange[] = newNodes.map((n) => ({
item: n,
type: 'add',
}));
const nodeSelectionChanges: NodeChange[] = state.nodes.map((n) => ({
id: n.data.id, id: n.data.id,
type: 'select', type: 'select',
selected: false, selected: false,
})); });
});
const edgeAdditions: EdgeChange[] = newEdges.map((e) => ({ // Add new nodes
item: e, nodes.forEach((n) => {
nodeChanges.push({
item: n,
type: 'add', type: 'add',
})); });
const edgeSelectionChanges: EdgeChange[] = state.edges.map((e) => ({ });
const edgeChanges: EdgeChange[] = [];
// Deselect existing edges
state.edges.forEach((e) => {
edgeChanges.push({
id: e.id, id: e.id,
type: 'select', type: 'select',
selected: false, selected: false,
})); });
});
// Add new edges
edges.forEach((e) => {
edgeChanges.push({
item: e,
type: 'add',
});
});
state.nodes = applyNodeChanges(nodeAdditions.concat(nodeSelectionChanges), state.nodes); state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges);
state.edges = applyEdgeChanges(edgeAdditions.concat(edgeSelectionChanges), state.edges); // Add node execution states for new nodes
nodes.forEach((node) => {
newNodes.forEach((node) => {
state.nodeExecutionStates[node.id] = { state.nodeExecutionStates[node.id] = {
nodeId: node.id, nodeId: node.id,
...initialNodeExecutionState, ...deepClone(initialNodeExecutionState),
}; };
}); });
}, },
@ -786,7 +725,6 @@ export const {
selectedAll, selectedAll,
selectedEdgesChanged, selectedEdgesChanged,
selectedNodesChanged, selectedNodesChanged,
selectionCopied,
selectionPasted, selectionPasted,
viewportChanged, viewportChanged,
edgeAdded, edgeAdded,
@ -831,6 +769,10 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
edgeAdded edgeAdded
); );
export const $cursorPos = atom<XYPosition | null>(null);
export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
export const selectNodesSlice = (state: RootState) => state.nodes.present; export const selectNodesSlice = (state: RootState) => state.nodes.present;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ /* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@ -850,8 +792,6 @@ export const nodesPersistConfig: PersistConfig<NodesState> = {
'connectionStartFieldType', 'connectionStartFieldType',
'selectedNodes', 'selectedNodes',
'selectedEdges', 'selectedEdges',
'nodesToCopy',
'edgesToCopy',
'connectionMade', 'connectionMade',
'modifyingEdge', 'modifyingEdge',
'addNewNodePosition', 'addNewNodePosition',

View File

@ -21,8 +21,6 @@ export type NodesState = {
selectedEdges: string[]; selectedEdges: string[];
nodeExecutionStates: Record<string, NodeExecutionState>; nodeExecutionStates: Record<string, NodeExecutionState>;
viewport: Viewport; viewport: Viewport;
nodesToCopy: AnyNode[];
edgesToCopy: InvocationNodeEdge[];
isAddNodePopoverOpen: boolean; isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null; addNewNodePosition: XYPosition | null;
}; };

View File

@ -4,8 +4,8 @@ export const findUnoccupiedPosition = (nodes: Node[], x: number, y: number) => {
let newX = x; let newX = x;
let newY = y; let newY = y;
while (nodes.find((n) => n.position.x === newX && n.position.y === newY)) { while (nodes.find((n) => n.position.x === newX && n.position.y === newY)) {
newX = newX + 50; newX = Math.floor(newX + 50);
newY = newY + 50; newY = Math.floor(newY + 50);
} }
return { x: newX, y: newY }; return { x: newX, y: newY };
}; };