Workflow editor improvements - add node from empty connection and auto-connect to empy handle. (#4684)

* Initial commit of edge drag feature.

* Fixed build warnings

* code cleanup and drag to existing node

* improved isValidConnection check

* fixed build issues, removed cyclic dependency

* edge created nodes now spawn at cursor

* Add Node popover will no longer show when using drag to delete an edge.

* Fixed collection handling, added priority for handles matching name of source handle, removed current image/notes nodes from filtered list

* Fixed not properly clearing startParams when closing the Add Node popover

* fix(ui): do not allow Collect -> Iterate connection

This can be removed when #3956 is resolved

* feat(ui): use existing node validation logic in add-node-on-drop

This logic handles a number of special cases

---------

Co-authored-by: Millun Atluri <Millu@users.noreply.github.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
CrypticWit 2023-09-29 21:12:57 +13:00 committed by GitHub
parent 95fd2ee6ff
commit dc1e804887
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 308 additions and 50 deletions

View File

@ -17,14 +17,15 @@ import {
addNodePopoverOpened, addNodePopoverOpened,
nodeAdded, nodeAdded,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
import { map } from 'lodash-es'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import { filter, map, some } from 'lodash-es';
import { memo, useCallback, useRef } from 'react'; import { memo, useCallback, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import 'reactflow/dist/style.css'; import 'reactflow/dist/style.css';
import { AnyInvocationType } from 'services/events/types'; import { AnyInvocationType } from 'services/events/types';
import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem'; import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem';
import { useTranslation } from 'react-i18next';
type NodeTemplate = { type NodeTemplate = {
label: string; label: string;
@ -33,7 +34,7 @@ type NodeTemplate = {
tags: string[]; tags: string[];
}; };
const filter = (value: string, item: NodeTemplate) => { const selectFilter = (value: string, item: NodeTemplate) => {
const regex = new RegExp( const regex = new RegExp(
value value
.trim() .trim()
@ -55,10 +56,34 @@ const AddNodePopover = () => {
const toaster = useAppToaster(); const toaster = useAppToaster();
const { t } = useTranslation(); const { t } = useTranslation();
const fieldFilter = useAppSelector(
(state) => state.nodes.currentConnectionFieldType
);
const handleFilter = useAppSelector(
(state) => state.nodes.connectionStartParams?.handleType
);
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
({ nodes }) => { ({ nodes }) => {
const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => { // If we have a connection in progress, we need to filter the node choices
const filteredNodeTemplates = fieldFilter
? filter(nodes.nodeTemplates, (template) => {
const handles =
handleFilter == 'source' ? template.inputs : template.outputs;
return some(handles, (handle) => {
const sourceType =
handleFilter == 'source' ? fieldFilter : handle.type;
const targetType =
handleFilter == 'target' ? fieldFilter : handle.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
})
: map(nodes.nodeTemplates);
const data: NodeTemplate[] = map(filteredNodeTemplates, (template) => {
return { return {
label: template.title, label: template.title,
value: template.type, value: template.type,
@ -67,19 +92,22 @@ const AddNodePopover = () => {
}; };
}); });
data.push({ //We only want these nodes if we're not filtered
label: t('nodes.currentImage'), if (fieldFilter === null) {
value: 'current_image', data.push({
description: t('nodes.currentImageDescription'), label: t('nodes.currentImage'),
tags: ['progress'], value: 'current_image',
}); description: t('nodes.currentImageDescription'),
tags: ['progress'],
});
data.push({ data.push({
label: t('nodes.notes'), label: t('nodes.notes'),
value: 'notes', value: 'notes',
description: t('nodes.notesDescription'), description: t('nodes.notesDescription'),
tags: ['notes'], tags: ['notes'],
}); });
}
data.sort((a, b) => a.label.localeCompare(b.label)); data.sort((a, b) => a.label.localeCompare(b.label));
@ -190,7 +218,7 @@ const AddNodePopover = () => {
maxDropdownHeight={400} maxDropdownHeight={400}
nothingFound={t('nodes.noMatchingNodes')} nothingFound={t('nodes.noMatchingNodes')}
itemComponent={AddNodePopoverSelectItem} itemComponent={AddNodePopoverSelectItem}
filter={filter} filter={selectFilter}
onChange={handleChange} onChange={handleChange}
hoverOnSearchChange={true} hoverOnSearchChange={true}
onDropdownClose={onClose} onDropdownClose={onClose}

View File

@ -30,6 +30,7 @@ import {
connectionEnded, connectionEnded,
connectionMade, connectionMade,
connectionStarted, connectionStarted,
edgeChangeStarted,
edgeAdded, edgeAdded,
edgeDeleted, edgeDeleted,
edgesChanged, edgesChanged,
@ -119,7 +120,7 @@ export const Flow = () => {
); );
const onConnectEnd: OnConnectEnd = useCallback(() => { const onConnectEnd: OnConnectEnd = useCallback(() => {
dispatch(connectionEnded()); dispatch(connectionEnded({ cursorPosition: cursorPosition.current }));
}, [dispatch]); }, [dispatch]);
const onEdgesDelete: OnEdgesDelete = useCallback( const onEdgesDelete: OnEdgesDelete = useCallback(
@ -194,6 +195,7 @@ export const Flow = () => {
edgeUpdateMouseEvent.current = e; edgeUpdateMouseEvent.current = e;
// always delete the edge when starting an updated // always delete the edge when starting an updated
dispatch(edgeDeleted(edge.id)); dispatch(edgeDeleted(edge.id));
dispatch(edgeChangeStarted());
}, },
[dispatch] [dispatch]
); );

View File

@ -1,9 +1,9 @@
// TODO: enable this at some point // TODO: enable this at some point
import graphlib from '@dagrejs/graphlib';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { Connection, Edge, Node, useReactFlow } from 'reactflow'; import { Connection, Node, useReactFlow } from 'reactflow';
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes'; import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic';
import { InvocationNodeData } from '../types/types'; import { InvocationNodeData } from '../types/types';
/** /**
@ -87,27 +87,3 @@ export const useIsValidConnection = () => {
return isValidConnection; return isValidConnection;
}; };
export const getIsGraphAcyclic = (
source: string,
target: string,
nodes: Node[],
edges: Edge[]
) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// add the candidate edge
g.setEdge(source, target);
// check if the graph is acyclic
return graphlib.alg.isAcyclic(g);
};

View File

@ -12,4 +12,7 @@ export const nodesPersistDenylist: (keyof NodesState)[] = [
'isReady', 'isReady',
'nodesToCopy', 'nodesToCopy',
'edgesToCopy', 'edgesToCopy',
'connectionMade',
'modifyingEdge',
'addNewNodePosition',
]; ];

View File

@ -60,6 +60,7 @@ import {
} from '../types/types'; } from '../types/types';
import { NodesState } from './types'; import { NodesState } from './types';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
export const WORKFLOW_FORMAT_VERSION = '1.0.0'; export const WORKFLOW_FORMAT_VERSION = '1.0.0';
@ -92,6 +93,9 @@ export const initialNodesState: NodesState = {
isReady: false, isReady: false,
connectionStartParams: null, connectionStartParams: null,
currentConnectionFieldType: null, currentConnectionFieldType: null,
connectionMade: false,
modifyingEdge: false,
addNewNodePosition: null,
shouldShowFieldTypeLegend: false, shouldShowFieldTypeLegend: false,
shouldShowMinimapPanel: true, shouldShowMinimapPanel: true,
shouldValidateGraph: true, shouldValidateGraph: true,
@ -153,8 +157,8 @@ const nodesSlice = createSlice({
const node = action.payload; const node = action.payload;
const position = findUnoccupiedPosition( const position = findUnoccupiedPosition(
state.nodes, state.nodes,
node.position.x, state.addNewNodePosition?.x ?? node.position.x,
node.position.y state.addNewNodePosition?.y ?? node.position.y
); );
node.position = position; node.position = position;
node.selected = true; node.selected = true;
@ -179,6 +183,38 @@ const nodesSlice = createSlice({
nodeId: node.id, nodeId: node.id,
...initialNodeExecutionState, ...initialNodeExecutionState,
}; };
if (state.connectionStartParams) {
const { nodeId, handleId, handleType } = state.connectionStartParams;
if (
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
) {
const newConnection = findConnectionToValidHandle(
node,
state.nodes,
state.edges,
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
);
if (newConnection) {
state.edges = addEdge(
{ ...newConnection, type: 'default' },
state.edges
);
}
}
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
},
edgeChangeStarted: (state) => {
state.modifyingEdge = true;
}, },
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => { edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges); state.edges = applyEdgeChanges(action.payload, state.edges);
@ -195,6 +231,7 @@ const nodesSlice = createSlice({
}, },
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => { connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
state.connectionStartParams = action.payload; state.connectionStartParams = action.payload;
state.connectionMade = state.modifyingEdge;
const { nodeId, handleId, handleType } = action.payload; const { nodeId, handleId, handleType } = action.payload;
if (!nodeId || !handleId) { if (!nodeId || !handleId) {
return; return;
@ -219,10 +256,53 @@ const nodesSlice = createSlice({
{ ...action.payload, type: 'default' }, { ...action.payload, type: 'default' },
state.edges state.edges
); );
state.connectionMade = true;
}, },
connectionEnded: (state) => { connectionEnded: (state, action) => {
state.connectionStartParams = null; if (!state.connectionMade) {
state.currentConnectionFieldType = null; if (state.mouseOverNode) {
const nodeIndex = state.nodes.findIndex(
(n) => n.id === state.mouseOverNode
);
const mouseOverNode = state.nodes?.[nodeIndex];
if (mouseOverNode && state.connectionStartParams) {
const { nodeId, handleId, handleType } =
state.connectionStartParams;
if (
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
) {
const newConnection = findConnectionToValidHandle(
mouseOverNode,
state.nodes,
state.edges,
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
);
if (newConnection) {
state.edges = addEdge(
{ ...newConnection, type: 'default' },
state.edges
);
}
}
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
} else {
state.addNewNodePosition = action.payload.cursorPosition;
state.isAddNodePopoverOpen = true;
}
} else {
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
}
state.modifyingEdge = false;
}, },
workflowExposedFieldAdded: ( workflowExposedFieldAdded: (
state, state,
@ -835,10 +915,15 @@ const nodesSlice = createSlice({
}); });
}, },
addNodePopoverOpened: (state) => { addNodePopoverOpened: (state) => {
state.addNewNodePosition = null; //Create the node in viewport center by default
state.isAddNodePopoverOpen = true; state.isAddNodePopoverOpen = true;
}, },
addNodePopoverClosed: (state) => { addNodePopoverClosed: (state) => {
state.isAddNodePopoverOpen = false; state.isAddNodePopoverOpen = false;
//Make sure these get reset if we close the popover and haven't selected a node
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
}, },
addNodePopoverToggled: (state) => { addNodePopoverToggled: (state) => {
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen; state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
@ -913,6 +998,7 @@ export const {
connectionMade, connectionMade,
connectionStarted, connectionStarted,
edgeDeleted, edgeDeleted,
edgeChangeStarted,
edgesChanged, edgesChanged,
edgesDeleted, edgesDeleted,
edgeUpdated, edgeUpdated,

View File

@ -4,6 +4,7 @@ import {
OnConnectStartParams, OnConnectStartParams,
SelectionMode, SelectionMode,
Viewport, Viewport,
XYPosition,
} from 'reactflow'; } from 'reactflow';
import { import {
FieldIdentifier, FieldIdentifier,
@ -21,6 +22,8 @@ export type NodesState = {
nodeTemplates: Record<string, InvocationTemplate>; nodeTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null; connectionStartParams: OnConnectStartParams | null;
currentConnectionFieldType: FieldType | null; currentConnectionFieldType: FieldType | null;
connectionMade: boolean;
modifyingEdge: boolean;
shouldShowFieldTypeLegend: boolean; shouldShowFieldTypeLegend: boolean;
shouldShowMinimapPanel: boolean; shouldShowMinimapPanel: boolean;
shouldValidateGraph: boolean; shouldValidateGraph: boolean;
@ -39,5 +42,6 @@ export type NodesState = {
nodesToCopy: Node<NodeData>[]; nodesToCopy: Node<NodeData>[];
edgesToCopy: Edge<InvocationEdgeExtra>[]; edgesToCopy: Edge<InvocationEdgeExtra>[];
isAddNodePopoverOpen: boolean; isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode; selectionMode: SelectionMode;
}; };

View File

@ -0,0 +1,126 @@
import { Connection, HandleType } from 'reactflow';
import { Node, Edge } from 'reactflow';
import {
FieldType,
InputFieldValue,
OutputFieldValue,
} from 'features/nodes/types/types';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
const isValidConnection = (
edges: Edge[],
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType,
node: Node,
handle: InputFieldValue | OutputFieldValue
) => {
let isValidConnection = true;
if (handleCurrentType === 'source') {
if (
edges.find((edge) => {
return edge.target === node.id && edge.targetHandle === handle.name;
})
) {
isValidConnection = false;
}
} else {
if (
edges.find((edge) => {
return edge.source === node.id && edge.sourceHandle === handle.name;
})
) {
isValidConnection = false;
}
}
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
isValidConnection = false;
}
return isValidConnection;
};
export const findConnectionToValidHandle = (
node: Node,
nodes: Node[],
edges: Edge[],
handleCurrentNodeId: string,
handleCurrentName: string,
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType
): Connection | null => {
if (node.id === handleCurrentNodeId) {
return null;
}
const handles =
handleCurrentType == 'source' ? node.data.inputs : node.data.outputs;
//Prioritize handles whos name matches the node we're coming from
if (handles[handleCurrentName]) {
const handle = handles[handleCurrentName];
const sourceID =
handleCurrentType == 'source' ? handleCurrentNodeId : node.id;
const targetID =
handleCurrentType == 'source' ? node.id : handleCurrentNodeId;
const sourceHandle =
handleCurrentType == 'source' ? handleCurrentName : handle.name;
const targetHandle =
handleCurrentType == 'source' ? handle.name : handleCurrentName;
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
const valid = isValidConnection(
edges,
handleCurrentType,
handleCurrentFieldType,
node,
handle
);
if (isGraphAcyclic && valid) {
return {
source: sourceID,
sourceHandle: sourceHandle,
target: targetID,
targetHandle: targetHandle,
};
}
}
for (const handleName in handles) {
const handle = handles[handleName];
const sourceID =
handleCurrentType == 'source' ? handleCurrentNodeId : node.id;
const targetID =
handleCurrentType == 'source' ? node.id : handleCurrentNodeId;
const sourceHandle =
handleCurrentType == 'source' ? handleCurrentName : handle.name;
const targetHandle =
handleCurrentType == 'source' ? handle.name : handleCurrentName;
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
const valid = isValidConnection(
edges,
handleCurrentType,
handleCurrentFieldType,
node,
handle
);
if (isGraphAcyclic && valid) {
return {
source: sourceID,
sourceHandle: sourceHandle,
target: targetID,
targetHandle: targetHandle,
};
}
}
return null;
};

View File

@ -0,0 +1,26 @@
import graphlib from '@dagrejs/graphlib';
import { Edge, Node } from 'reactflow';
export const getIsGraphAcyclic = (
source: string,
target: string,
nodes: Node[],
edges: Edge[]
) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// add the candidate edge
g.setEdge(source, target);
// check if the graph is acyclic
return graphlib.alg.isAcyclic(g);
};

View File

@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection'; import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { FieldType } from 'features/nodes/types/types'; import { FieldType } from 'features/nodes/types/types';
import i18n from 'i18next'; import i18n from 'i18next';
import { HandleType } from 'reactflow'; import { HandleType } from 'reactflow';

View File

@ -10,6 +10,13 @@ export const validateSourceAndTargetTypes = (
sourceType: FieldType, sourceType: FieldType,
targetType: FieldType targetType: FieldType
) => { ) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType === 'Collection' && targetType === 'Collection') {
return false;
}
if (sourceType === targetType) { if (sourceType === targetType) {
return true; return true;
} }