mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Workflow editor improvements - add node from empty connection and auto-connect to empy handle. (#4684)
* Initial commit of edge drag feature. * Fixed build warnings * code cleanup and drag to existing node * improved isValidConnection check * fixed build issues, removed cyclic dependency * edge created nodes now spawn at cursor * Add Node popover will no longer show when using drag to delete an edge. * Fixed collection handling, added priority for handles matching name of source handle, removed current image/notes nodes from filtered list * Fixed not properly clearing startParams when closing the Add Node popover * fix(ui): do not allow Collect -> Iterate connection This can be removed when #3956 is resolved * feat(ui): use existing node validation logic in add-node-on-drop This logic handles a number of special cases --------- Co-authored-by: Millun Atluri <Millu@users.noreply.github.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
95fd2ee6ff
commit
dc1e804887
@ -17,14 +17,15 @@ import {
|
||||
addNodePopoverOpened,
|
||||
nodeAdded,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import { filter, map, some } from 'lodash-es';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import 'reactflow/dist/style.css';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type NodeTemplate = {
|
||||
label: string;
|
||||
@ -33,7 +34,7 @@ type NodeTemplate = {
|
||||
tags: string[];
|
||||
};
|
||||
|
||||
const filter = (value: string, item: NodeTemplate) => {
|
||||
const selectFilter = (value: string, item: NodeTemplate) => {
|
||||
const regex = new RegExp(
|
||||
value
|
||||
.trim()
|
||||
@ -55,10 +56,34 @@ const AddNodePopover = () => {
|
||||
const toaster = useAppToaster();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const fieldFilter = useAppSelector(
|
||||
(state) => state.nodes.currentConnectionFieldType
|
||||
);
|
||||
const handleFilter = useAppSelector(
|
||||
(state) => state.nodes.connectionStartParams?.handleType
|
||||
);
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ nodes }) => {
|
||||
const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
const filteredNodeTemplates = fieldFilter
|
||||
? filter(nodes.nodeTemplates, (template) => {
|
||||
const handles =
|
||||
handleFilter == 'source' ? template.inputs : template.outputs;
|
||||
|
||||
return some(handles, (handle) => {
|
||||
const sourceType =
|
||||
handleFilter == 'source' ? fieldFilter : handle.type;
|
||||
const targetType =
|
||||
handleFilter == 'target' ? fieldFilter : handle.type;
|
||||
|
||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||
});
|
||||
})
|
||||
: map(nodes.nodeTemplates);
|
||||
|
||||
const data: NodeTemplate[] = map(filteredNodeTemplates, (template) => {
|
||||
return {
|
||||
label: template.title,
|
||||
value: template.type,
|
||||
@ -67,19 +92,22 @@ const AddNodePopover = () => {
|
||||
};
|
||||
});
|
||||
|
||||
data.push({
|
||||
label: t('nodes.currentImage'),
|
||||
value: 'current_image',
|
||||
description: t('nodes.currentImageDescription'),
|
||||
tags: ['progress'],
|
||||
});
|
||||
//We only want these nodes if we're not filtered
|
||||
if (fieldFilter === null) {
|
||||
data.push({
|
||||
label: t('nodes.currentImage'),
|
||||
value: 'current_image',
|
||||
description: t('nodes.currentImageDescription'),
|
||||
tags: ['progress'],
|
||||
});
|
||||
|
||||
data.push({
|
||||
label: t('nodes.notes'),
|
||||
value: 'notes',
|
||||
description: t('nodes.notesDescription'),
|
||||
tags: ['notes'],
|
||||
});
|
||||
data.push({
|
||||
label: t('nodes.notes'),
|
||||
value: 'notes',
|
||||
description: t('nodes.notesDescription'),
|
||||
tags: ['notes'],
|
||||
});
|
||||
}
|
||||
|
||||
data.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
@ -190,7 +218,7 @@ const AddNodePopover = () => {
|
||||
maxDropdownHeight={400}
|
||||
nothingFound={t('nodes.noMatchingNodes')}
|
||||
itemComponent={AddNodePopoverSelectItem}
|
||||
filter={filter}
|
||||
filter={selectFilter}
|
||||
onChange={handleChange}
|
||||
hoverOnSearchChange={true}
|
||||
onDropdownClose={onClose}
|
||||
|
@ -30,6 +30,7 @@ import {
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgeChangeStarted,
|
||||
edgeAdded,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
@ -119,7 +120,7 @@ export const Flow = () => {
|
||||
);
|
||||
|
||||
const onConnectEnd: OnConnectEnd = useCallback(() => {
|
||||
dispatch(connectionEnded());
|
||||
dispatch(connectionEnded({ cursorPosition: cursorPosition.current }));
|
||||
}, [dispatch]);
|
||||
|
||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||
@ -194,6 +195,7 @@ export const Flow = () => {
|
||||
edgeUpdateMouseEvent.current = e;
|
||||
// always delete the edge when starting an updated
|
||||
dispatch(edgeDeleted(edge.id));
|
||||
dispatch(edgeChangeStarted());
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
@ -1,9 +1,9 @@
|
||||
// TODO: enable this at some point
|
||||
import graphlib from '@dagrejs/graphlib';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
|
||||
import { Connection, Node, useReactFlow } from 'reactflow';
|
||||
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
|
||||
import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic';
|
||||
import { InvocationNodeData } from '../types/types';
|
||||
|
||||
/**
|
||||
@ -87,27 +87,3 @@ export const useIsValidConnection = () => {
|
||||
|
||||
return isValidConnection;
|
||||
};
|
||||
|
||||
export const getIsGraphAcyclic = (
|
||||
source: string,
|
||||
target: string,
|
||||
nodes: Node[],
|
||||
edges: Edge[]
|
||||
) => {
|
||||
// construct graphlib graph from editor state
|
||||
const g = new graphlib.Graph();
|
||||
|
||||
nodes.forEach((n) => {
|
||||
g.setNode(n.id);
|
||||
});
|
||||
|
||||
edges.forEach((e) => {
|
||||
g.setEdge(e.source, e.target);
|
||||
});
|
||||
|
||||
// add the candidate edge
|
||||
g.setEdge(source, target);
|
||||
|
||||
// check if the graph is acyclic
|
||||
return graphlib.alg.isAcyclic(g);
|
||||
};
|
||||
|
@ -12,4 +12,7 @@ export const nodesPersistDenylist: (keyof NodesState)[] = [
|
||||
'isReady',
|
||||
'nodesToCopy',
|
||||
'edgesToCopy',
|
||||
'connectionMade',
|
||||
'modifyingEdge',
|
||||
'addNewNodePosition',
|
||||
];
|
||||
|
@ -60,6 +60,7 @@ import {
|
||||
} from '../types/types';
|
||||
import { NodesState } from './types';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
|
||||
|
||||
export const WORKFLOW_FORMAT_VERSION = '1.0.0';
|
||||
|
||||
@ -92,6 +93,9 @@ export const initialNodesState: NodesState = {
|
||||
isReady: false,
|
||||
connectionStartParams: null,
|
||||
currentConnectionFieldType: null,
|
||||
connectionMade: false,
|
||||
modifyingEdge: false,
|
||||
addNewNodePosition: null,
|
||||
shouldShowFieldTypeLegend: false,
|
||||
shouldShowMinimapPanel: true,
|
||||
shouldValidateGraph: true,
|
||||
@ -153,8 +157,8 @@ const nodesSlice = createSlice({
|
||||
const node = action.payload;
|
||||
const position = findUnoccupiedPosition(
|
||||
state.nodes,
|
||||
node.position.x,
|
||||
node.position.y
|
||||
state.addNewNodePosition?.x ?? node.position.x,
|
||||
state.addNewNodePosition?.y ?? node.position.y
|
||||
);
|
||||
node.position = position;
|
||||
node.selected = true;
|
||||
@ -179,6 +183,38 @@ const nodesSlice = createSlice({
|
||||
nodeId: node.id,
|
||||
...initialNodeExecutionState,
|
||||
};
|
||||
|
||||
if (state.connectionStartParams) {
|
||||
const { nodeId, handleId, handleType } = state.connectionStartParams;
|
||||
if (
|
||||
nodeId &&
|
||||
handleId &&
|
||||
handleType &&
|
||||
state.currentConnectionFieldType
|
||||
) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
node,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.currentConnectionFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge(
|
||||
{ ...newConnection, type: 'default' },
|
||||
state.edges
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
},
|
||||
edgeChangeStarted: (state) => {
|
||||
state.modifyingEdge = true;
|
||||
},
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
@ -195,6 +231,7 @@ const nodesSlice = createSlice({
|
||||
},
|
||||
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
||||
state.connectionStartParams = action.payload;
|
||||
state.connectionMade = state.modifyingEdge;
|
||||
const { nodeId, handleId, handleType } = action.payload;
|
||||
if (!nodeId || !handleId) {
|
||||
return;
|
||||
@ -219,10 +256,53 @@ const nodesSlice = createSlice({
|
||||
{ ...action.payload, type: 'default' },
|
||||
state.edges
|
||||
);
|
||||
|
||||
state.connectionMade = true;
|
||||
},
|
||||
connectionEnded: (state) => {
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
connectionEnded: (state, action) => {
|
||||
if (!state.connectionMade) {
|
||||
if (state.mouseOverNode) {
|
||||
const nodeIndex = state.nodes.findIndex(
|
||||
(n) => n.id === state.mouseOverNode
|
||||
);
|
||||
const mouseOverNode = state.nodes?.[nodeIndex];
|
||||
if (mouseOverNode && state.connectionStartParams) {
|
||||
const { nodeId, handleId, handleType } =
|
||||
state.connectionStartParams;
|
||||
if (
|
||||
nodeId &&
|
||||
handleId &&
|
||||
handleType &&
|
||||
state.currentConnectionFieldType
|
||||
) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
mouseOverNode,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.currentConnectionFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge(
|
||||
{ ...newConnection, type: 'default' },
|
||||
state.edges
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
} else {
|
||||
state.addNewNodePosition = action.payload.cursorPosition;
|
||||
state.isAddNodePopoverOpen = true;
|
||||
}
|
||||
} else {
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
}
|
||||
state.modifyingEdge = false;
|
||||
},
|
||||
workflowExposedFieldAdded: (
|
||||
state,
|
||||
@ -835,10 +915,15 @@ const nodesSlice = createSlice({
|
||||
});
|
||||
},
|
||||
addNodePopoverOpened: (state) => {
|
||||
state.addNewNodePosition = null; //Create the node in viewport center by default
|
||||
state.isAddNodePopoverOpen = true;
|
||||
},
|
||||
addNodePopoverClosed: (state) => {
|
||||
state.isAddNodePopoverOpen = false;
|
||||
|
||||
//Make sure these get reset if we close the popover and haven't selected a node
|
||||
state.connectionStartParams = null;
|
||||
state.currentConnectionFieldType = null;
|
||||
},
|
||||
addNodePopoverToggled: (state) => {
|
||||
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
|
||||
@ -913,6 +998,7 @@ export const {
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgeDeleted,
|
||||
edgeChangeStarted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
edgeUpdated,
|
||||
|
@ -4,6 +4,7 @@ import {
|
||||
OnConnectStartParams,
|
||||
SelectionMode,
|
||||
Viewport,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import {
|
||||
FieldIdentifier,
|
||||
@ -21,6 +22,8 @@ export type NodesState = {
|
||||
nodeTemplates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
currentConnectionFieldType: FieldType | null;
|
||||
connectionMade: boolean;
|
||||
modifyingEdge: boolean;
|
||||
shouldShowFieldTypeLegend: boolean;
|
||||
shouldShowMinimapPanel: boolean;
|
||||
shouldValidateGraph: boolean;
|
||||
@ -39,5 +42,6 @@ export type NodesState = {
|
||||
nodesToCopy: Node<NodeData>[];
|
||||
edgesToCopy: Edge<InvocationEdgeExtra>[];
|
||||
isAddNodePopoverOpen: boolean;
|
||||
addNewNodePosition: XYPosition | null;
|
||||
selectionMode: SelectionMode;
|
||||
};
|
||||
|
@ -0,0 +1,126 @@
|
||||
import { Connection, HandleType } from 'reactflow';
|
||||
import { Node, Edge } from 'reactflow';
|
||||
import {
|
||||
FieldType,
|
||||
InputFieldValue,
|
||||
OutputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
|
||||
const isValidConnection = (
|
||||
edges: Edge[],
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType,
|
||||
node: Node,
|
||||
handle: InputFieldValue | OutputFieldValue
|
||||
) => {
|
||||
let isValidConnection = true;
|
||||
if (handleCurrentType === 'source') {
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === node.id && edge.targetHandle === handle.name;
|
||||
})
|
||||
) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
} else {
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.source === node.id && edge.sourceHandle === handle.name;
|
||||
})
|
||||
) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
|
||||
return isValidConnection;
|
||||
};
|
||||
|
||||
export const findConnectionToValidHandle = (
|
||||
node: Node,
|
||||
nodes: Node[],
|
||||
edges: Edge[],
|
||||
handleCurrentNodeId: string,
|
||||
handleCurrentName: string,
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType
|
||||
): Connection | null => {
|
||||
if (node.id === handleCurrentNodeId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handles =
|
||||
handleCurrentType == 'source' ? node.data.inputs : node.data.outputs;
|
||||
|
||||
//Prioritize handles whos name matches the node we're coming from
|
||||
if (handles[handleCurrentName]) {
|
||||
const handle = handles[handleCurrentName];
|
||||
|
||||
const sourceID =
|
||||
handleCurrentType == 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID =
|
||||
handleCurrentType == 'source' ? node.id : handleCurrentNodeId;
|
||||
const sourceHandle =
|
||||
handleCurrentType == 'source' ? handleCurrentName : handle.name;
|
||||
const targetHandle =
|
||||
handleCurrentType == 'source' ? handle.name : handleCurrentName;
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
||||
|
||||
const valid = isValidConnection(
|
||||
edges,
|
||||
handleCurrentType,
|
||||
handleCurrentFieldType,
|
||||
node,
|
||||
handle
|
||||
);
|
||||
|
||||
if (isGraphAcyclic && valid) {
|
||||
return {
|
||||
source: sourceID,
|
||||
sourceHandle: sourceHandle,
|
||||
target: targetID,
|
||||
targetHandle: targetHandle,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
for (const handleName in handles) {
|
||||
const handle = handles[handleName];
|
||||
|
||||
const sourceID =
|
||||
handleCurrentType == 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID =
|
||||
handleCurrentType == 'source' ? node.id : handleCurrentNodeId;
|
||||
const sourceHandle =
|
||||
handleCurrentType == 'source' ? handleCurrentName : handle.name;
|
||||
const targetHandle =
|
||||
handleCurrentType == 'source' ? handle.name : handleCurrentName;
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
||||
|
||||
const valid = isValidConnection(
|
||||
edges,
|
||||
handleCurrentType,
|
||||
handleCurrentFieldType,
|
||||
node,
|
||||
handle
|
||||
);
|
||||
|
||||
if (isGraphAcyclic && valid) {
|
||||
return {
|
||||
source: sourceID,
|
||||
sourceHandle: sourceHandle,
|
||||
target: targetID,
|
||||
targetHandle: targetHandle,
|
||||
};
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
@ -0,0 +1,26 @@
|
||||
import graphlib from '@dagrejs/graphlib';
|
||||
import { Edge, Node } from 'reactflow';
|
||||
|
||||
export const getIsGraphAcyclic = (
|
||||
source: string,
|
||||
target: string,
|
||||
nodes: Node[],
|
||||
edges: Edge[]
|
||||
) => {
|
||||
// construct graphlib graph from editor state
|
||||
const g = new graphlib.Graph();
|
||||
|
||||
nodes.forEach((n) => {
|
||||
g.setNode(n.id);
|
||||
});
|
||||
|
||||
edges.forEach((e) => {
|
||||
g.setEdge(e.source, e.target);
|
||||
});
|
||||
|
||||
// add the candidate edge
|
||||
g.setEdge(source, target);
|
||||
|
||||
// check if the graph is acyclic
|
||||
return graphlib.alg.isAcyclic(g);
|
||||
};
|
@ -1,6 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
import i18n from 'i18next';
|
||||
import { HandleType } from 'reactflow';
|
||||
|
@ -10,6 +10,13 @@ export const validateSourceAndTargetTypes = (
|
||||
sourceType: FieldType,
|
||||
targetType: FieldType
|
||||
) => {
|
||||
// TODO: There's a bug with Collect -> Iterate nodes:
|
||||
// https://github.com/invoke-ai/InvokeAI/issues/3956
|
||||
// Once this is resolved, we can remove this check.
|
||||
if (sourceType === 'Collection' && targetType === 'Collection') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sourceType === targetType) {
|
||||
return true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user