mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): rework node and edge mutation logic
Remove our DIY'd reducers, consolidating all node and edge mutations to use `edgesChanged` and `nodesChanged`, which are called by reactflow. This makes the API for manipulating nodes and edges less tangly and error-prone.
This commit is contained in:
parent
504ac82077
commit
26029108f7
@ -14,11 +14,12 @@ import {
|
|||||||
$pendingConnection,
|
$pendingConnection,
|
||||||
$templates,
|
$templates,
|
||||||
closeAddNodePopover,
|
closeAddNodePopover,
|
||||||
connectionMade,
|
edgesChanged,
|
||||||
nodeAdded,
|
nodeAdded,
|
||||||
openAddNodePopover,
|
openAddNodePopover,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||||
|
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
@ -166,7 +167,8 @@ const AddNodePopover = () => {
|
|||||||
edgePendingUpdate
|
edgePendingUpdate
|
||||||
);
|
);
|
||||||
if (connection) {
|
if (connection) {
|
||||||
dispatch(connectionMade(connection));
|
const newEdge = connectionToEdge(connection);
|
||||||
|
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,29 +14,24 @@ import {
|
|||||||
$lastEdgeUpdateMouseEvent,
|
$lastEdgeUpdateMouseEvent,
|
||||||
$pendingConnection,
|
$pendingConnection,
|
||||||
$viewport,
|
$viewport,
|
||||||
connectionMade,
|
|
||||||
edgeDeleted,
|
|
||||||
edgesChanged,
|
edgesChanged,
|
||||||
edgesDeleted,
|
|
||||||
nodesChanged,
|
nodesChanged,
|
||||||
nodesDeleted,
|
|
||||||
redo,
|
redo,
|
||||||
selectedAll,
|
selectedAll,
|
||||||
|
selectionDeleted,
|
||||||
undo,
|
undo,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||||
import { isString } from 'lodash-es';
|
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||||
import type { CSSProperties, MouseEvent } from 'react';
|
import type { CSSProperties, MouseEvent } from 'react';
|
||||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import type {
|
import type {
|
||||||
OnEdgesChange,
|
OnEdgesChange,
|
||||||
OnEdgesDelete,
|
|
||||||
OnEdgeUpdateFunc,
|
OnEdgeUpdateFunc,
|
||||||
OnInit,
|
OnInit,
|
||||||
OnMoveEnd,
|
OnMoveEnd,
|
||||||
OnNodesChange,
|
OnNodesChange,
|
||||||
OnNodesDelete,
|
|
||||||
ProOptions,
|
ProOptions,
|
||||||
ReactFlowProps,
|
ReactFlowProps,
|
||||||
ReactFlowState,
|
ReactFlowState,
|
||||||
@ -50,8 +45,6 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode';
|
|||||||
import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper';
|
import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper';
|
||||||
import NotesNode from './nodes/Notes/NotesNode';
|
import NotesNode from './nodes/Notes/NotesNode';
|
||||||
|
|
||||||
const DELETE_KEYS = ['Delete', 'Backspace'];
|
|
||||||
|
|
||||||
const edgeTypes = {
|
const edgeTypes = {
|
||||||
collapsed: InvocationCollapsedEdge,
|
collapsed: InvocationCollapsedEdge,
|
||||||
default: InvocationDefaultEdge,
|
default: InvocationDefaultEdge,
|
||||||
@ -109,20 +102,6 @@ export const Flow = memo(() => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
|
||||||
(edges) => {
|
|
||||||
dispatch(edgesDeleted(edges));
|
|
||||||
},
|
|
||||||
[dispatch]
|
|
||||||
);
|
|
||||||
|
|
||||||
const onNodesDelete: OnNodesDelete = useCallback(
|
|
||||||
(nodes) => {
|
|
||||||
dispatch(nodesDeleted(nodes));
|
|
||||||
},
|
|
||||||
[dispatch]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => {
|
const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => {
|
||||||
$viewport.set(viewport);
|
$viewport.set(viewport);
|
||||||
}, []);
|
}, []);
|
||||||
@ -167,16 +146,20 @@ export const Flow = memo(() => {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
|
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
|
||||||
(edge, newConnection) => {
|
(oldEdge, newConnection) => {
|
||||||
// This event is fired when an edge update is successful
|
// This event is fired when an edge update is successful
|
||||||
$didUpdateEdge.set(true);
|
$didUpdateEdge.set(true);
|
||||||
// When an edge update is successful, we need to delete the old edge and create a new one
|
// When an edge update is successful, we need to delete the old edge and create a new one
|
||||||
dispatch(edgeDeleted(edge.id));
|
const newEdge = connectionToEdge(newConnection);
|
||||||
dispatch(connectionMade(newConnection));
|
dispatch(
|
||||||
|
edgesChanged([
|
||||||
|
{ type: 'remove', id: oldEdge.id },
|
||||||
|
{ type: 'add', item: newEdge },
|
||||||
|
])
|
||||||
|
);
|
||||||
// Because we shift the position of handles depending on whether a field is connected or not, we must use
|
// Because we shift the position of handles depending on whether a field is connected or not, we must use
|
||||||
// updateNodeInternals to tell reactflow to recalculate the positions of the handles
|
// updateNodeInternals to tell reactflow to recalculate the positions of the handles
|
||||||
const nodesToUpdate = [edge.source, edge.target, newConnection.source, newConnection.target].filter(isString);
|
updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]);
|
||||||
updateNodeInternals(nodesToUpdate);
|
|
||||||
},
|
},
|
||||||
[dispatch, updateNodeInternals]
|
[dispatch, updateNodeInternals]
|
||||||
);
|
);
|
||||||
@ -193,7 +176,7 @@ export const Flow = memo(() => {
|
|||||||
// If we got this far and did not successfully update an edge, and the mouse moved away from the handle,
|
// If we got this far and did not successfully update an edge, and the mouse moved away from the handle,
|
||||||
// the user probably intended to delete the edge
|
// the user probably intended to delete the edge
|
||||||
if (!didUpdateEdge && didMouseMove) {
|
if (!didUpdateEdge && didMouseMove) {
|
||||||
dispatch(edgeDeleted(edge.id));
|
dispatch(edgesChanged([{ type: 'remove', id: edge.id }]));
|
||||||
}
|
}
|
||||||
|
|
||||||
$edgePendingUpdate.set(null);
|
$edgePendingUpdate.set(null);
|
||||||
@ -267,6 +250,11 @@ export const Flow = memo(() => {
|
|||||||
}, [cancelConnection]);
|
}, [cancelConnection]);
|
||||||
useHotkeys('esc', onEscapeHotkey);
|
useHotkeys('esc', onEscapeHotkey);
|
||||||
|
|
||||||
|
const onDeleteHotkey = useCallback(() => {
|
||||||
|
dispatch(selectionDeleted());
|
||||||
|
}, [dispatch]);
|
||||||
|
useHotkeys(['delete', 'backspace'], onDeleteHotkey);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ReactFlow
|
<ReactFlow
|
||||||
id="workflow-editor"
|
id="workflow-editor"
|
||||||
@ -280,11 +268,9 @@ export const Flow = memo(() => {
|
|||||||
onMouseMove={onMouseMove}
|
onMouseMove={onMouseMove}
|
||||||
onNodesChange={onNodesChange}
|
onNodesChange={onNodesChange}
|
||||||
onEdgesChange={onEdgesChange}
|
onEdgesChange={onEdgesChange}
|
||||||
onEdgesDelete={onEdgesDelete}
|
|
||||||
onEdgeUpdate={onEdgeUpdate}
|
onEdgeUpdate={onEdgeUpdate}
|
||||||
onEdgeUpdateStart={onEdgeUpdateStart}
|
onEdgeUpdateStart={onEdgeUpdateStart}
|
||||||
onEdgeUpdateEnd={onEdgeUpdateEnd}
|
onEdgeUpdateEnd={onEdgeUpdateEnd}
|
||||||
onNodesDelete={onNodesDelete}
|
|
||||||
onConnectStart={onConnectStart}
|
onConnectStart={onConnectStart}
|
||||||
onConnect={onConnect}
|
onConnect={onConnect}
|
||||||
onConnectEnd={onConnectEnd}
|
onConnectEnd={onConnectEnd}
|
||||||
@ -298,7 +284,7 @@ export const Flow = memo(() => {
|
|||||||
proOptions={proOptions}
|
proOptions={proOptions}
|
||||||
style={flowStyles}
|
style={flowStyles}
|
||||||
onPaneClick={handlePaneClick}
|
onPaneClick={handlePaneClick}
|
||||||
deleteKeyCode={DELETE_KEYS}
|
deleteKeyCode={null}
|
||||||
selectionMode={selectionMode}
|
selectionMode={selectionMode}
|
||||||
elevateEdgesOnSelect
|
elevateEdgesOnSelect
|
||||||
>
|
>
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
import { useDoesFieldExist } from 'features/nodes/hooks/useDoesFieldExist';
|
||||||
|
import type { PropsWithChildren } from 'react';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
type Props = PropsWithChildren<{
|
||||||
|
nodeId: string;
|
||||||
|
fieldName?: string;
|
||||||
|
}>;
|
||||||
|
|
||||||
|
export const MissingFallback = memo((props: Props) => {
|
||||||
|
// We must be careful here to avoid race conditions where a deleted node is still referenced as an exposed field
|
||||||
|
const exists = useDoesFieldExist(props.nodeId, props.fieldName);
|
||||||
|
if (!exists) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return props.children;
|
||||||
|
});
|
||||||
|
|
||||||
|
MissingFallback.displayName = 'MissingFallback';
|
@ -3,6 +3,7 @@ import { CSS } from '@dnd-kit/utilities';
|
|||||||
import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
|
import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||||
|
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
|
||||||
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
|
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
|
||||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||||
import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice';
|
import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice';
|
||||||
@ -20,7 +21,7 @@ type Props = {
|
|||||||
fieldName: string;
|
fieldName: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
|
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
|
||||||
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
|
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
|
||||||
@ -99,4 +100,12 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||||
|
return (
|
||||||
|
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
|
||||||
|
<LinearViewFieldInternal nodeId={nodeId} fieldName={fieldName} />
|
||||||
|
</MissingFallback>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
export default memo(LinearViewField);
|
export default memo(LinearViewField);
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
|
import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
|
||||||
import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent';
|
import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent';
|
||||||
import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
|
import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
|
||||||
|
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
|
||||||
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
|
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
|
||||||
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
|
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
|
||||||
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
|
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
|
||||||
@ -14,7 +15,7 @@ type Props = {
|
|||||||
fieldName: string;
|
fieldName: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const WorkflowField = ({ nodeId, fieldName }: Props) => {
|
const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
|
||||||
const label = useFieldLabel(nodeId, fieldName);
|
const label = useFieldLabel(nodeId, fieldName);
|
||||||
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs');
|
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs');
|
||||||
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
|
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
|
||||||
@ -50,4 +51,12 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const WorkflowField = ({ nodeId, fieldName }: Props) => {
|
||||||
|
return (
|
||||||
|
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
|
||||||
|
<WorkflowFieldInternal nodeId={nodeId} fieldName={fieldName} />
|
||||||
|
</MissingFallback>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
export default memo(WorkflowField);
|
export default memo(WorkflowField);
|
||||||
|
@ -6,10 +6,10 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
|||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import DndSortable from 'features/dnd/components/DndSortable';
|
import DndSortable from 'features/dnd/components/DndSortable';
|
||||||
import type { DragEndEvent } from 'features/dnd/types';
|
import type { DragEndEvent } from 'features/dnd/types';
|
||||||
import LinearViewField from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
|
import LinearViewFieldInternal from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
|
||||||
import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice';
|
import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice';
|
||||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
||||||
|
|
||||||
@ -40,16 +40,18 @@ const WorkflowLinearTab = () => {
|
|||||||
[dispatch, fields]
|
[dispatch, fields]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const items = useMemo(() => fields.map((field) => `${field.nodeId}.${field.fieldName}`), [fields]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box position="relative" w="full" h="full">
|
<Box position="relative" w="full" h="full">
|
||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<DndSortable onDragEnd={handleDragEnd} items={fields.map((field) => `${field.nodeId}.${field.fieldName}`)}>
|
<DndSortable onDragEnd={handleDragEnd} items={items}>
|
||||||
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
|
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
|
||||||
{isLoading ? (
|
{isLoading ? (
|
||||||
<IAINoContentFallback label={t('nodes.loadingNodes')} icon={null} />
|
<IAINoContentFallback label={t('nodes.loadingNodes')} icon={null} />
|
||||||
) : fields.length ? (
|
) : fields.length ? (
|
||||||
fields.map(({ nodeId, fieldName }) => (
|
fields.map(({ nodeId, fieldName }) => (
|
||||||
<LinearViewField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
|
<LinearViewFieldInternal key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
|
||||||
))
|
))
|
||||||
) : (
|
) : (
|
||||||
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />
|
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />
|
||||||
|
@ -7,13 +7,12 @@ import {
|
|||||||
$isAddNodePopoverOpen,
|
$isAddNodePopoverOpen,
|
||||||
$pendingConnection,
|
$pendingConnection,
|
||||||
$templates,
|
$templates,
|
||||||
connectionMade,
|
edgesChanged,
|
||||||
edgeDeleted,
|
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||||
import { isString } from 'lodash-es';
|
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||||
import { useUpdateNodeInternals } from 'reactflow';
|
import { useUpdateNodeInternals } from 'reactflow';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
@ -50,9 +49,9 @@ export const useConnection = () => {
|
|||||||
const onConnect = useCallback<OnConnect>(
|
const onConnect = useCallback<OnConnect>(
|
||||||
(connection) => {
|
(connection) => {
|
||||||
const { dispatch } = store;
|
const { dispatch } = store;
|
||||||
dispatch(connectionMade(connection));
|
const newEdge = connectionToEdge(connection);
|
||||||
const nodesToUpdate = [connection.source, connection.target].filter(isString);
|
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
|
||||||
updateNodeInternals(nodesToUpdate);
|
updateNodeInternals([newEdge.source, newEdge.target]);
|
||||||
$pendingConnection.set(null);
|
$pendingConnection.set(null);
|
||||||
},
|
},
|
||||||
[store, updateNodeInternals]
|
[store, updateNodeInternals]
|
||||||
@ -92,13 +91,17 @@ export const useConnection = () => {
|
|||||||
edgePendingUpdate
|
edgePendingUpdate
|
||||||
);
|
);
|
||||||
if (connection) {
|
if (connection) {
|
||||||
dispatch(connectionMade(connection));
|
const newEdge = connectionToEdge(connection);
|
||||||
const nodesToUpdate = [connection.source, connection.target].filter(isString);
|
const changes: EdgeChange[] = [{ type: 'add', item: newEdge }];
|
||||||
updateNodeInternals(nodesToUpdate);
|
|
||||||
|
const nodesToUpdate = [newEdge.source, newEdge.target];
|
||||||
if (edgePendingUpdate) {
|
if (edgePendingUpdate) {
|
||||||
dispatch(edgeDeleted(edgePendingUpdate.id));
|
|
||||||
$didUpdateEdge.set(true);
|
$didUpdateEdge.set(true);
|
||||||
|
changes.push({ type: 'remove', id: edgePendingUpdate.id });
|
||||||
|
nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target);
|
||||||
}
|
}
|
||||||
|
dispatch(edgesChanged(changes));
|
||||||
|
updateNodeInternals(nodesToUpdate);
|
||||||
}
|
}
|
||||||
$pendingConnection.set(null);
|
$pendingConnection.set(null);
|
||||||
} else {
|
} else {
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
|
|
||||||
|
export const useDoesFieldExist = (nodeId: string, fieldName?: string) => {
|
||||||
|
const doesFieldExist = useAppSelector((s) => {
|
||||||
|
const node = s.nodes.present.nodes.find((n) => n.id === nodeId);
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (fieldName === undefined) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!node.data.inputs[fieldName]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
|
return doesFieldExist;
|
||||||
|
};
|
@ -1,6 +1,7 @@
|
|||||||
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||||
import type {
|
import type {
|
||||||
@ -48,8 +49,8 @@ import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocatio
|
|||||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||||
import { atom } from 'nanostores';
|
import { atom } from 'nanostores';
|
||||||
import type { MouseEvent } from 'react';
|
import type { MouseEvent } from 'react';
|
||||||
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
|
import type { Edge, EdgeChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
|
||||||
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
|
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
|
||||||
import type { UndoableOptions } from 'redux-undo';
|
import type { UndoableOptions } from 'redux-undo';
|
||||||
import type { z } from 'zod';
|
import type { z } from 'zod';
|
||||||
|
|
||||||
@ -124,10 +125,27 @@ export const nodesSlice = createSlice({
|
|||||||
state.nodes.push(node);
|
state.nodes.push(node);
|
||||||
},
|
},
|
||||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
const changes = deepClone(action.payload);
|
||||||
},
|
action.payload.forEach((change) => {
|
||||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
if (change.type === 'remove' || change.type === 'select') {
|
||||||
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
|
const edge = state.edges.find((e) => e.id === change.id);
|
||||||
|
// If we deleted or selected a collapsed edge, we need to find its "hidden" edges and do the same to them
|
||||||
|
if (edge && edge.type === 'collapsed') {
|
||||||
|
const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target);
|
||||||
|
if (change.type === 'remove') {
|
||||||
|
hiddenEdges.forEach((e) => {
|
||||||
|
changes.push({ type: 'remove', id: e.id });
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (change.type === 'select') {
|
||||||
|
hiddenEdges.forEach((e) => {
|
||||||
|
changes.push({ type: 'select', id: e.id, selected: change.selected });
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
state.edges = applyEdgeChanges(changes, state.edges);
|
||||||
},
|
},
|
||||||
fieldLabelChanged: (
|
fieldLabelChanged: (
|
||||||
state,
|
state,
|
||||||
@ -264,33 +282,6 @@ export const nodesSlice = createSlice({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
edgeDeleted: (state, action: PayloadAction<string>) => {
|
|
||||||
state.edges = state.edges.filter((e) => e.id !== action.payload);
|
|
||||||
},
|
|
||||||
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
|
|
||||||
const edges = action.payload;
|
|
||||||
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
|
|
||||||
|
|
||||||
// if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes
|
|
||||||
if (collapsedEdges.length) {
|
|
||||||
const edgeChanges: EdgeRemoveChange[] = [];
|
|
||||||
collapsedEdges.forEach((collapsedEdge) => {
|
|
||||||
state.edges.forEach((edge) => {
|
|
||||||
if (edge.source === collapsedEdge.source && edge.target === collapsedEdge.target) {
|
|
||||||
edgeChanges.push({ id: edge.id, type: 'remove' });
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
nodesDeleted: (state, action: PayloadAction<AnyNode[]>) => {
|
|
||||||
action.payload.forEach((node) => {
|
|
||||||
if (!isInvocationNode(node)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
|
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
|
||||||
const { nodeId, label } = action.payload;
|
const { nodeId, label } = action.payload;
|
||||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||||
@ -435,6 +426,23 @@ export const nodesSlice = createSlice({
|
|||||||
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
|
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
|
||||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||||
},
|
},
|
||||||
|
selectionDeleted: (state) => {
|
||||||
|
const selectedNodes = state.nodes.filter((n) => n.selected);
|
||||||
|
const selectedEdges = state.edges.filter((e) => e.selected);
|
||||||
|
|
||||||
|
const nodeChanges: NodeChange[] = selectedNodes.map((n) => ({
|
||||||
|
id: n.id,
|
||||||
|
type: 'remove',
|
||||||
|
}));
|
||||||
|
|
||||||
|
const edgeChanges: EdgeChange[] = selectedEdges.map((e) => ({
|
||||||
|
id: e.id,
|
||||||
|
type: 'remove',
|
||||||
|
}));
|
||||||
|
|
||||||
|
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
|
||||||
|
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||||
|
},
|
||||||
undo: (state) => state,
|
undo: (state) => state,
|
||||||
redo: (state) => state,
|
redo: (state) => state,
|
||||||
},
|
},
|
||||||
@ -457,10 +465,7 @@ export const nodesSlice = createSlice({
|
|||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
connectionMade,
|
|
||||||
edgeDeleted,
|
|
||||||
edgesChanged,
|
edgesChanged,
|
||||||
edgesDeleted,
|
|
||||||
fieldValueReset,
|
fieldValueReset,
|
||||||
fieldBoardValueChanged,
|
fieldBoardValueChanged,
|
||||||
fieldBooleanValueChanged,
|
fieldBooleanValueChanged,
|
||||||
@ -488,11 +493,11 @@ export const {
|
|||||||
nodeLabelChanged,
|
nodeLabelChanged,
|
||||||
nodeNotesChanged,
|
nodeNotesChanged,
|
||||||
nodesChanged,
|
nodesChanged,
|
||||||
nodesDeleted,
|
|
||||||
nodeUseCacheChanged,
|
nodeUseCacheChanged,
|
||||||
notesNodeValueChanged,
|
notesNodeValueChanged,
|
||||||
selectedAll,
|
selectedAll,
|
||||||
selectionPasted,
|
selectionPasted,
|
||||||
|
selectionDeleted,
|
||||||
undo,
|
undo,
|
||||||
redo,
|
redo,
|
||||||
} = nodesSlice.actions;
|
} = nodesSlice.actions;
|
||||||
@ -580,10 +585,7 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
|
|||||||
|
|
||||||
// This is used for tracking `state.workflow.isTouched`
|
// This is used for tracking `state.workflow.isTouched`
|
||||||
export const isAnyNodeOrEdgeMutation = isAnyOf(
|
export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||||
connectionMade,
|
|
||||||
edgeDeleted,
|
|
||||||
edgesChanged,
|
edgesChanged,
|
||||||
edgesDeleted,
|
|
||||||
fieldBoardValueChanged,
|
fieldBoardValueChanged,
|
||||||
fieldBooleanValueChanged,
|
fieldBooleanValueChanged,
|
||||||
fieldColorValueChanged,
|
fieldColorValueChanged,
|
||||||
@ -601,13 +603,14 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
|||||||
fieldStringValueChanged,
|
fieldStringValueChanged,
|
||||||
fieldVaeModelValueChanged,
|
fieldVaeModelValueChanged,
|
||||||
nodeAdded,
|
nodeAdded,
|
||||||
|
nodesChanged,
|
||||||
nodeReplaced,
|
nodeReplaced,
|
||||||
nodeIsIntermediateChanged,
|
nodeIsIntermediateChanged,
|
||||||
nodeIsOpenChanged,
|
nodeIsOpenChanged,
|
||||||
nodeLabelChanged,
|
nodeLabelChanged,
|
||||||
nodeNotesChanged,
|
nodeNotesChanged,
|
||||||
nodesDeleted,
|
|
||||||
nodeUseCacheChanged,
|
nodeUseCacheChanged,
|
||||||
notesNodeValueChanged,
|
notesNodeValueChanged,
|
||||||
selectionPasted
|
selectionPasted,
|
||||||
|
selectionDeleted
|
||||||
);
|
);
|
||||||
|
@ -0,0 +1,32 @@
|
|||||||
|
import type { Connection, Edge } from 'reactflow';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the edge id for a connection
|
||||||
|
* Copied from: https://github.com/xyflow/xyflow/blob/v11/packages/core/src/utils/graph.ts#L44-L45
|
||||||
|
* Requested for this to be exported in: https://github.com/xyflow/xyflow/issues/4290
|
||||||
|
* @param connection The connection to get the id for
|
||||||
|
* @returns The edge id
|
||||||
|
*/
|
||||||
|
const getEdgeId = (connection: Connection): string => {
|
||||||
|
const { source, sourceHandle, target, targetHandle } = connection;
|
||||||
|
return `reactflow__edge-${source}${sourceHandle || ''}-${target}${targetHandle || ''}`;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a connection to an edge
|
||||||
|
* @param connection The connection to convert to an edge
|
||||||
|
* @returns The edge
|
||||||
|
* @throws If the connection is invalid (e.g. missing source, sourcehandle, target, or targetHandle)
|
||||||
|
*/
|
||||||
|
export const connectionToEdge = (connection: Connection): Edge => {
|
||||||
|
const { source, sourceHandle, target, targetHandle } = connection;
|
||||||
|
assert(source && sourceHandle && target && targetHandle, 'Invalid connection');
|
||||||
|
return {
|
||||||
|
source,
|
||||||
|
sourceHandle,
|
||||||
|
target,
|
||||||
|
targetHandle,
|
||||||
|
id: getEdgeId({ source, sourceHandle, target, targetHandle }),
|
||||||
|
};
|
||||||
|
};
|
@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
|||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { deepClone } from 'common/util/deepClone';
|
import { deepClone } from 'common/util/deepClone';
|
||||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||||
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice';
|
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type {
|
import type {
|
||||||
FieldIdentifierWithValue,
|
FieldIdentifierWithValue,
|
||||||
WorkflowMode,
|
WorkflowMode,
|
||||||
@ -139,16 +139,16 @@ export const workflowSlice = createSlice({
|
|||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(nodesDeleted, (state, action) => {
|
|
||||||
action.payload.forEach((node) => {
|
|
||||||
state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== node.id);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));
|
builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));
|
||||||
|
|
||||||
builder.addCase(nodesChanged, (state, action) => {
|
builder.addCase(nodesChanged, (state, action) => {
|
||||||
// Not all changes to nodes should result in the workflow being marked touched
|
// Not all changes to nodes should result in the workflow being marked touched
|
||||||
|
action.payload.forEach((change) => {
|
||||||
|
if (change.type === 'remove') {
|
||||||
|
state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== change.id);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
const filteredChanges = action.payload.filter((change) => {
|
const filteredChanges = action.payload.filter((change) => {
|
||||||
// We always want to mark the workflow as touched if a node is added, removed, or reset
|
// We always want to mark the workflow as touched if a node is added, removed, or reset
|
||||||
if (['add', 'remove', 'reset'].includes(change.type)) {
|
if (['add', 'remove', 'reset'].includes(change.type)) {
|
||||||
|
Loading…
Reference in New Issue
Block a user