Workflow editor improvements - add node from empty connection and auto-connect to empy handle. (#4684)

* Initial commit of edge drag feature.

* Fixed build warnings

* code cleanup and drag to existing node

* improved isValidConnection check

* fixed build issues, removed cyclic dependency

* edge created nodes now spawn at cursor

* Add Node popover will no longer show when using drag to delete an edge.

* Fixed collection handling, added priority for handles matching name of source handle, removed current image/notes nodes from filtered list

* Fixed not properly clearing startParams when closing the Add Node popover

* fix(ui): do not allow Collect -> Iterate connection

This can be removed when #3956 is resolved

* feat(ui): use existing node validation logic in add-node-on-drop

This logic handles a number of special cases

---------

Co-authored-by: Millun Atluri <Millu@users.noreply.github.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
CrypticWit 2023-09-29 21:12:57 +13:00 committed by GitHub
parent 95fd2ee6ff
commit dc1e804887
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 308 additions and 50 deletions

View File

@ -17,14 +17,15 @@ import {
addNodePopoverOpened,
nodeAdded,
} from 'features/nodes/store/nodesSlice';
import { map } from 'lodash-es';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import { filter, map, some } from 'lodash-es';
import { memo, useCallback, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import 'reactflow/dist/style.css';
import { AnyInvocationType } from 'services/events/types';
import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem';
import { useTranslation } from 'react-i18next';
type NodeTemplate = {
label: string;
@ -33,7 +34,7 @@ type NodeTemplate = {
tags: string[];
};
const filter = (value: string, item: NodeTemplate) => {
const selectFilter = (value: string, item: NodeTemplate) => {
const regex = new RegExp(
value
.trim()
@ -55,10 +56,34 @@ const AddNodePopover = () => {
const toaster = useAppToaster();
const { t } = useTranslation();
const fieldFilter = useAppSelector(
(state) => state.nodes.currentConnectionFieldType
);
const handleFilter = useAppSelector(
(state) => state.nodes.connectionStartParams?.handleType
);
const selector = createSelector(
[stateSelector],
({ nodes }) => {
const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => {
// If we have a connection in progress, we need to filter the node choices
const filteredNodeTemplates = fieldFilter
? filter(nodes.nodeTemplates, (template) => {
const handles =
handleFilter == 'source' ? template.inputs : template.outputs;
return some(handles, (handle) => {
const sourceType =
handleFilter == 'source' ? fieldFilter : handle.type;
const targetType =
handleFilter == 'target' ? fieldFilter : handle.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
})
: map(nodes.nodeTemplates);
const data: NodeTemplate[] = map(filteredNodeTemplates, (template) => {
return {
label: template.title,
value: template.type,
@ -67,19 +92,22 @@ const AddNodePopover = () => {
};
});
data.push({
label: t('nodes.currentImage'),
value: 'current_image',
description: t('nodes.currentImageDescription'),
tags: ['progress'],
});
//We only want these nodes if we're not filtered
if (fieldFilter === null) {
data.push({
label: t('nodes.currentImage'),
value: 'current_image',
description: t('nodes.currentImageDescription'),
tags: ['progress'],
});
data.push({
label: t('nodes.notes'),
value: 'notes',
description: t('nodes.notesDescription'),
tags: ['notes'],
});
data.push({
label: t('nodes.notes'),
value: 'notes',
description: t('nodes.notesDescription'),
tags: ['notes'],
});
}
data.sort((a, b) => a.label.localeCompare(b.label));
@ -190,7 +218,7 @@ const AddNodePopover = () => {
maxDropdownHeight={400}
nothingFound={t('nodes.noMatchingNodes')}
itemComponent={AddNodePopoverSelectItem}
filter={filter}
filter={selectFilter}
onChange={handleChange}
hoverOnSearchChange={true}
onDropdownClose={onClose}

View File

@ -30,6 +30,7 @@ import {
connectionEnded,
connectionMade,
connectionStarted,
edgeChangeStarted,
edgeAdded,
edgeDeleted,
edgesChanged,
@ -119,7 +120,7 @@ export const Flow = () => {
);
const onConnectEnd: OnConnectEnd = useCallback(() => {
dispatch(connectionEnded());
dispatch(connectionEnded({ cursorPosition: cursorPosition.current }));
}, [dispatch]);
const onEdgesDelete: OnEdgesDelete = useCallback(
@ -194,6 +195,7 @@ export const Flow = () => {
edgeUpdateMouseEvent.current = e;
// always delete the edge when starting an updated
dispatch(edgeDeleted(edge.id));
dispatch(edgeChangeStarted());
},
[dispatch]
);

View File

@ -1,9 +1,9 @@
// TODO: enable this at some point
import graphlib from '@dagrejs/graphlib';
import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
import { Connection, Node, useReactFlow } from 'reactflow';
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic';
import { InvocationNodeData } from '../types/types';
/**
@ -87,27 +87,3 @@ export const useIsValidConnection = () => {
return isValidConnection;
};
export const getIsGraphAcyclic = (
source: string,
target: string,
nodes: Node[],
edges: Edge[]
) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// add the candidate edge
g.setEdge(source, target);
// check if the graph is acyclic
return graphlib.alg.isAcyclic(g);
};

View File

@ -12,4 +12,7 @@ export const nodesPersistDenylist: (keyof NodesState)[] = [
'isReady',
'nodesToCopy',
'edgesToCopy',
'connectionMade',
'modifyingEdge',
'addNewNodePosition',
];

View File

@ -60,6 +60,7 @@ import {
} from '../types/types';
import { NodesState } from './types';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
export const WORKFLOW_FORMAT_VERSION = '1.0.0';
@ -92,6 +93,9 @@ export const initialNodesState: NodesState = {
isReady: false,
connectionStartParams: null,
currentConnectionFieldType: null,
connectionMade: false,
modifyingEdge: false,
addNewNodePosition: null,
shouldShowFieldTypeLegend: false,
shouldShowMinimapPanel: true,
shouldValidateGraph: true,
@ -153,8 +157,8 @@ const nodesSlice = createSlice({
const node = action.payload;
const position = findUnoccupiedPosition(
state.nodes,
node.position.x,
node.position.y
state.addNewNodePosition?.x ?? node.position.x,
state.addNewNodePosition?.y ?? node.position.y
);
node.position = position;
node.selected = true;
@ -179,6 +183,38 @@ const nodesSlice = createSlice({
nodeId: node.id,
...initialNodeExecutionState,
};
if (state.connectionStartParams) {
const { nodeId, handleId, handleType } = state.connectionStartParams;
if (
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
) {
const newConnection = findConnectionToValidHandle(
node,
state.nodes,
state.edges,
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
);
if (newConnection) {
state.edges = addEdge(
{ ...newConnection, type: 'default' },
state.edges
);
}
}
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
},
edgeChangeStarted: (state) => {
state.modifyingEdge = true;
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
@ -195,6 +231,7 @@ const nodesSlice = createSlice({
},
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
state.connectionStartParams = action.payload;
state.connectionMade = state.modifyingEdge;
const { nodeId, handleId, handleType } = action.payload;
if (!nodeId || !handleId) {
return;
@ -219,10 +256,53 @@ const nodesSlice = createSlice({
{ ...action.payload, type: 'default' },
state.edges
);
state.connectionMade = true;
},
connectionEnded: (state) => {
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
connectionEnded: (state, action) => {
if (!state.connectionMade) {
if (state.mouseOverNode) {
const nodeIndex = state.nodes.findIndex(
(n) => n.id === state.mouseOverNode
);
const mouseOverNode = state.nodes?.[nodeIndex];
if (mouseOverNode && state.connectionStartParams) {
const { nodeId, handleId, handleType } =
state.connectionStartParams;
if (
nodeId &&
handleId &&
handleType &&
state.currentConnectionFieldType
) {
const newConnection = findConnectionToValidHandle(
mouseOverNode,
state.nodes,
state.edges,
nodeId,
handleId,
handleType,
state.currentConnectionFieldType
);
if (newConnection) {
state.edges = addEdge(
{ ...newConnection, type: 'default' },
state.edges
);
}
}
}
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
} else {
state.addNewNodePosition = action.payload.cursorPosition;
state.isAddNodePopoverOpen = true;
}
} else {
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
}
state.modifyingEdge = false;
},
workflowExposedFieldAdded: (
state,
@ -835,10 +915,15 @@ const nodesSlice = createSlice({
});
},
addNodePopoverOpened: (state) => {
state.addNewNodePosition = null; //Create the node in viewport center by default
state.isAddNodePopoverOpen = true;
},
addNodePopoverClosed: (state) => {
state.isAddNodePopoverOpen = false;
//Make sure these get reset if we close the popover and haven't selected a node
state.connectionStartParams = null;
state.currentConnectionFieldType = null;
},
addNodePopoverToggled: (state) => {
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
@ -913,6 +998,7 @@ export const {
connectionMade,
connectionStarted,
edgeDeleted,
edgeChangeStarted,
edgesChanged,
edgesDeleted,
edgeUpdated,

View File

@ -4,6 +4,7 @@ import {
OnConnectStartParams,
SelectionMode,
Viewport,
XYPosition,
} from 'reactflow';
import {
FieldIdentifier,
@ -21,6 +22,8 @@ export type NodesState = {
nodeTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
currentConnectionFieldType: FieldType | null;
connectionMade: boolean;
modifyingEdge: boolean;
shouldShowFieldTypeLegend: boolean;
shouldShowMinimapPanel: boolean;
shouldValidateGraph: boolean;
@ -39,5 +42,6 @@ export type NodesState = {
nodesToCopy: Node<NodeData>[];
edgesToCopy: Edge<InvocationEdgeExtra>[];
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode;
};

View File

@ -0,0 +1,126 @@
import { Connection, HandleType } from 'reactflow';
import { Node, Edge } from 'reactflow';
import {
FieldType,
InputFieldValue,
OutputFieldValue,
} from 'features/nodes/types/types';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
const isValidConnection = (
edges: Edge[],
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType,
node: Node,
handle: InputFieldValue | OutputFieldValue
) => {
let isValidConnection = true;
if (handleCurrentType === 'source') {
if (
edges.find((edge) => {
return edge.target === node.id && edge.targetHandle === handle.name;
})
) {
isValidConnection = false;
}
} else {
if (
edges.find((edge) => {
return edge.source === node.id && edge.sourceHandle === handle.name;
})
) {
isValidConnection = false;
}
}
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
isValidConnection = false;
}
return isValidConnection;
};
export const findConnectionToValidHandle = (
node: Node,
nodes: Node[],
edges: Edge[],
handleCurrentNodeId: string,
handleCurrentName: string,
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType
): Connection | null => {
if (node.id === handleCurrentNodeId) {
return null;
}
const handles =
handleCurrentType == 'source' ? node.data.inputs : node.data.outputs;
//Prioritize handles whos name matches the node we're coming from
if (handles[handleCurrentName]) {
const handle = handles[handleCurrentName];
const sourceID =
handleCurrentType == 'source' ? handleCurrentNodeId : node.id;
const targetID =
handleCurrentType == 'source' ? node.id : handleCurrentNodeId;
const sourceHandle =
handleCurrentType == 'source' ? handleCurrentName : handle.name;
const targetHandle =
handleCurrentType == 'source' ? handle.name : handleCurrentName;
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
const valid = isValidConnection(
edges,
handleCurrentType,
handleCurrentFieldType,
node,
handle
);
if (isGraphAcyclic && valid) {
return {
source: sourceID,
sourceHandle: sourceHandle,
target: targetID,
targetHandle: targetHandle,
};
}
}
for (const handleName in handles) {
const handle = handles[handleName];
const sourceID =
handleCurrentType == 'source' ? handleCurrentNodeId : node.id;
const targetID =
handleCurrentType == 'source' ? node.id : handleCurrentNodeId;
const sourceHandle =
handleCurrentType == 'source' ? handleCurrentName : handle.name;
const targetHandle =
handleCurrentType == 'source' ? handle.name : handleCurrentName;
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
const valid = isValidConnection(
edges,
handleCurrentType,
handleCurrentFieldType,
node,
handle
);
if (isGraphAcyclic && valid) {
return {
source: sourceID,
sourceHandle: sourceHandle,
target: targetID,
targetHandle: targetHandle,
};
}
}
return null;
};

View File

@ -0,0 +1,26 @@
import graphlib from '@dagrejs/graphlib';
import { Edge, Node } from 'reactflow';
export const getIsGraphAcyclic = (
source: string,
target: string,
nodes: Node[],
edges: Edge[]
) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// add the candidate edge
g.setEdge(source, target);
// check if the graph is acyclic
return graphlib.alg.isAcyclic(g);
};

View File

@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { FieldType } from 'features/nodes/types/types';
import i18n from 'i18next';
import { HandleType } from 'reactflow';

View File

@ -10,6 +10,13 @@ export const validateSourceAndTargetTypes = (
sourceType: FieldType,
targetType: FieldType
) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType === 'Collection' && targetType === 'Collection') {
return false;
}
if (sourceType === targetType) {
return true;
}