mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Workflow editor improvements - add node from empty connection and auto-connect to empy handle. (#4684)
* Initial commit of edge drag feature. * Fixed build warnings * code cleanup and drag to existing node * improved isValidConnection check * fixed build issues, removed cyclic dependency * edge created nodes now spawn at cursor * Add Node popover will no longer show when using drag to delete an edge. * Fixed collection handling, added priority for handles matching name of source handle, removed current image/notes nodes from filtered list * Fixed not properly clearing startParams when closing the Add Node popover * fix(ui): do not allow Collect -> Iterate connection This can be removed when #3956 is resolved * feat(ui): use existing node validation logic in add-node-on-drop This logic handles a number of special cases --------- Co-authored-by: Millun Atluri <Millu@users.noreply.github.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
95fd2ee6ff
commit
dc1e804887
@ -17,14 +17,15 @@ import {
|
|||||||
addNodePopoverOpened,
|
addNodePopoverOpened,
|
||||||
nodeAdded,
|
nodeAdded,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { map } from 'lodash-es';
|
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||||
|
import { filter, map, some } from 'lodash-es';
|
||||||
import { memo, useCallback, useRef } from 'react';
|
import { memo, useCallback, useRef } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
import { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
import 'reactflow/dist/style.css';
|
import 'reactflow/dist/style.css';
|
||||||
import { AnyInvocationType } from 'services/events/types';
|
import { AnyInvocationType } from 'services/events/types';
|
||||||
import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem';
|
import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
type NodeTemplate = {
|
type NodeTemplate = {
|
||||||
label: string;
|
label: string;
|
||||||
@ -33,7 +34,7 @@ type NodeTemplate = {
|
|||||||
tags: string[];
|
tags: string[];
|
||||||
};
|
};
|
||||||
|
|
||||||
const filter = (value: string, item: NodeTemplate) => {
|
const selectFilter = (value: string, item: NodeTemplate) => {
|
||||||
const regex = new RegExp(
|
const regex = new RegExp(
|
||||||
value
|
value
|
||||||
.trim()
|
.trim()
|
||||||
@ -55,10 +56,34 @@ const AddNodePopover = () => {
|
|||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const fieldFilter = useAppSelector(
|
||||||
|
(state) => state.nodes.currentConnectionFieldType
|
||||||
|
);
|
||||||
|
const handleFilter = useAppSelector(
|
||||||
|
(state) => state.nodes.connectionStartParams?.handleType
|
||||||
|
);
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ nodes }) => {
|
({ nodes }) => {
|
||||||
const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => {
|
// If we have a connection in progress, we need to filter the node choices
|
||||||
|
const filteredNodeTemplates = fieldFilter
|
||||||
|
? filter(nodes.nodeTemplates, (template) => {
|
||||||
|
const handles =
|
||||||
|
handleFilter == 'source' ? template.inputs : template.outputs;
|
||||||
|
|
||||||
|
return some(handles, (handle) => {
|
||||||
|
const sourceType =
|
||||||
|
handleFilter == 'source' ? fieldFilter : handle.type;
|
||||||
|
const targetType =
|
||||||
|
handleFilter == 'target' ? fieldFilter : handle.type;
|
||||||
|
|
||||||
|
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||||
|
});
|
||||||
|
})
|
||||||
|
: map(nodes.nodeTemplates);
|
||||||
|
|
||||||
|
const data: NodeTemplate[] = map(filteredNodeTemplates, (template) => {
|
||||||
return {
|
return {
|
||||||
label: template.title,
|
label: template.title,
|
||||||
value: template.type,
|
value: template.type,
|
||||||
@ -67,19 +92,22 @@ const AddNodePopover = () => {
|
|||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
data.push({
|
//We only want these nodes if we're not filtered
|
||||||
label: t('nodes.currentImage'),
|
if (fieldFilter === null) {
|
||||||
value: 'current_image',
|
data.push({
|
||||||
description: t('nodes.currentImageDescription'),
|
label: t('nodes.currentImage'),
|
||||||
tags: ['progress'],
|
value: 'current_image',
|
||||||
});
|
description: t('nodes.currentImageDescription'),
|
||||||
|
tags: ['progress'],
|
||||||
|
});
|
||||||
|
|
||||||
data.push({
|
data.push({
|
||||||
label: t('nodes.notes'),
|
label: t('nodes.notes'),
|
||||||
value: 'notes',
|
value: 'notes',
|
||||||
description: t('nodes.notesDescription'),
|
description: t('nodes.notesDescription'),
|
||||||
tags: ['notes'],
|
tags: ['notes'],
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
data.sort((a, b) => a.label.localeCompare(b.label));
|
data.sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
@ -190,7 +218,7 @@ const AddNodePopover = () => {
|
|||||||
maxDropdownHeight={400}
|
maxDropdownHeight={400}
|
||||||
nothingFound={t('nodes.noMatchingNodes')}
|
nothingFound={t('nodes.noMatchingNodes')}
|
||||||
itemComponent={AddNodePopoverSelectItem}
|
itemComponent={AddNodePopoverSelectItem}
|
||||||
filter={filter}
|
filter={selectFilter}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
hoverOnSearchChange={true}
|
hoverOnSearchChange={true}
|
||||||
onDropdownClose={onClose}
|
onDropdownClose={onClose}
|
||||||
|
@ -30,6 +30,7 @@ import {
|
|||||||
connectionEnded,
|
connectionEnded,
|
||||||
connectionMade,
|
connectionMade,
|
||||||
connectionStarted,
|
connectionStarted,
|
||||||
|
edgeChangeStarted,
|
||||||
edgeAdded,
|
edgeAdded,
|
||||||
edgeDeleted,
|
edgeDeleted,
|
||||||
edgesChanged,
|
edgesChanged,
|
||||||
@ -119,7 +120,7 @@ export const Flow = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const onConnectEnd: OnConnectEnd = useCallback(() => {
|
const onConnectEnd: OnConnectEnd = useCallback(() => {
|
||||||
dispatch(connectionEnded());
|
dispatch(connectionEnded({ cursorPosition: cursorPosition.current }));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||||
@ -194,6 +195,7 @@ export const Flow = () => {
|
|||||||
edgeUpdateMouseEvent.current = e;
|
edgeUpdateMouseEvent.current = e;
|
||||||
// always delete the edge when starting an updated
|
// always delete the edge when starting an updated
|
||||||
dispatch(edgeDeleted(edge.id));
|
dispatch(edgeDeleted(edge.id));
|
||||||
|
dispatch(edgeChangeStarted());
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
// TODO: enable this at some point
|
// TODO: enable this at some point
|
||||||
import graphlib from '@dagrejs/graphlib';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
|
import { Connection, Node, useReactFlow } from 'reactflow';
|
||||||
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
|
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
|
||||||
|
import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic';
|
||||||
import { InvocationNodeData } from '../types/types';
|
import { InvocationNodeData } from '../types/types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -87,27 +87,3 @@ export const useIsValidConnection = () => {
|
|||||||
|
|
||||||
return isValidConnection;
|
return isValidConnection;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getIsGraphAcyclic = (
|
|
||||||
source: string,
|
|
||||||
target: string,
|
|
||||||
nodes: Node[],
|
|
||||||
edges: Edge[]
|
|
||||||
) => {
|
|
||||||
// construct graphlib graph from editor state
|
|
||||||
const g = new graphlib.Graph();
|
|
||||||
|
|
||||||
nodes.forEach((n) => {
|
|
||||||
g.setNode(n.id);
|
|
||||||
});
|
|
||||||
|
|
||||||
edges.forEach((e) => {
|
|
||||||
g.setEdge(e.source, e.target);
|
|
||||||
});
|
|
||||||
|
|
||||||
// add the candidate edge
|
|
||||||
g.setEdge(source, target);
|
|
||||||
|
|
||||||
// check if the graph is acyclic
|
|
||||||
return graphlib.alg.isAcyclic(g);
|
|
||||||
};
|
|
||||||
|
@ -12,4 +12,7 @@ export const nodesPersistDenylist: (keyof NodesState)[] = [
|
|||||||
'isReady',
|
'isReady',
|
||||||
'nodesToCopy',
|
'nodesToCopy',
|
||||||
'edgesToCopy',
|
'edgesToCopy',
|
||||||
|
'connectionMade',
|
||||||
|
'modifyingEdge',
|
||||||
|
'addNewNodePosition',
|
||||||
];
|
];
|
||||||
|
@ -60,6 +60,7 @@ import {
|
|||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
import { NodesState } from './types';
|
import { NodesState } from './types';
|
||||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||||
|
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
|
||||||
|
|
||||||
export const WORKFLOW_FORMAT_VERSION = '1.0.0';
|
export const WORKFLOW_FORMAT_VERSION = '1.0.0';
|
||||||
|
|
||||||
@ -92,6 +93,9 @@ export const initialNodesState: NodesState = {
|
|||||||
isReady: false,
|
isReady: false,
|
||||||
connectionStartParams: null,
|
connectionStartParams: null,
|
||||||
currentConnectionFieldType: null,
|
currentConnectionFieldType: null,
|
||||||
|
connectionMade: false,
|
||||||
|
modifyingEdge: false,
|
||||||
|
addNewNodePosition: null,
|
||||||
shouldShowFieldTypeLegend: false,
|
shouldShowFieldTypeLegend: false,
|
||||||
shouldShowMinimapPanel: true,
|
shouldShowMinimapPanel: true,
|
||||||
shouldValidateGraph: true,
|
shouldValidateGraph: true,
|
||||||
@ -153,8 +157,8 @@ const nodesSlice = createSlice({
|
|||||||
const node = action.payload;
|
const node = action.payload;
|
||||||
const position = findUnoccupiedPosition(
|
const position = findUnoccupiedPosition(
|
||||||
state.nodes,
|
state.nodes,
|
||||||
node.position.x,
|
state.addNewNodePosition?.x ?? node.position.x,
|
||||||
node.position.y
|
state.addNewNodePosition?.y ?? node.position.y
|
||||||
);
|
);
|
||||||
node.position = position;
|
node.position = position;
|
||||||
node.selected = true;
|
node.selected = true;
|
||||||
@ -179,6 +183,38 @@ const nodesSlice = createSlice({
|
|||||||
nodeId: node.id,
|
nodeId: node.id,
|
||||||
...initialNodeExecutionState,
|
...initialNodeExecutionState,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (state.connectionStartParams) {
|
||||||
|
const { nodeId, handleId, handleType } = state.connectionStartParams;
|
||||||
|
if (
|
||||||
|
nodeId &&
|
||||||
|
handleId &&
|
||||||
|
handleType &&
|
||||||
|
state.currentConnectionFieldType
|
||||||
|
) {
|
||||||
|
const newConnection = findConnectionToValidHandle(
|
||||||
|
node,
|
||||||
|
state.nodes,
|
||||||
|
state.edges,
|
||||||
|
nodeId,
|
||||||
|
handleId,
|
||||||
|
handleType,
|
||||||
|
state.currentConnectionFieldType
|
||||||
|
);
|
||||||
|
if (newConnection) {
|
||||||
|
state.edges = addEdge(
|
||||||
|
{ ...newConnection, type: 'default' },
|
||||||
|
state.edges
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
state.connectionStartParams = null;
|
||||||
|
state.currentConnectionFieldType = null;
|
||||||
|
},
|
||||||
|
edgeChangeStarted: (state) => {
|
||||||
|
state.modifyingEdge = true;
|
||||||
},
|
},
|
||||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||||
@ -195,6 +231,7 @@ const nodesSlice = createSlice({
|
|||||||
},
|
},
|
||||||
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
||||||
state.connectionStartParams = action.payload;
|
state.connectionStartParams = action.payload;
|
||||||
|
state.connectionMade = state.modifyingEdge;
|
||||||
const { nodeId, handleId, handleType } = action.payload;
|
const { nodeId, handleId, handleType } = action.payload;
|
||||||
if (!nodeId || !handleId) {
|
if (!nodeId || !handleId) {
|
||||||
return;
|
return;
|
||||||
@ -219,10 +256,53 @@ const nodesSlice = createSlice({
|
|||||||
{ ...action.payload, type: 'default' },
|
{ ...action.payload, type: 'default' },
|
||||||
state.edges
|
state.edges
|
||||||
);
|
);
|
||||||
|
|
||||||
|
state.connectionMade = true;
|
||||||
},
|
},
|
||||||
connectionEnded: (state) => {
|
connectionEnded: (state, action) => {
|
||||||
state.connectionStartParams = null;
|
if (!state.connectionMade) {
|
||||||
state.currentConnectionFieldType = null;
|
if (state.mouseOverNode) {
|
||||||
|
const nodeIndex = state.nodes.findIndex(
|
||||||
|
(n) => n.id === state.mouseOverNode
|
||||||
|
);
|
||||||
|
const mouseOverNode = state.nodes?.[nodeIndex];
|
||||||
|
if (mouseOverNode && state.connectionStartParams) {
|
||||||
|
const { nodeId, handleId, handleType } =
|
||||||
|
state.connectionStartParams;
|
||||||
|
if (
|
||||||
|
nodeId &&
|
||||||
|
handleId &&
|
||||||
|
handleType &&
|
||||||
|
state.currentConnectionFieldType
|
||||||
|
) {
|
||||||
|
const newConnection = findConnectionToValidHandle(
|
||||||
|
mouseOverNode,
|
||||||
|
state.nodes,
|
||||||
|
state.edges,
|
||||||
|
nodeId,
|
||||||
|
handleId,
|
||||||
|
handleType,
|
||||||
|
state.currentConnectionFieldType
|
||||||
|
);
|
||||||
|
if (newConnection) {
|
||||||
|
state.edges = addEdge(
|
||||||
|
{ ...newConnection, type: 'default' },
|
||||||
|
state.edges
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.connectionStartParams = null;
|
||||||
|
state.currentConnectionFieldType = null;
|
||||||
|
} else {
|
||||||
|
state.addNewNodePosition = action.payload.cursorPosition;
|
||||||
|
state.isAddNodePopoverOpen = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
state.connectionStartParams = null;
|
||||||
|
state.currentConnectionFieldType = null;
|
||||||
|
}
|
||||||
|
state.modifyingEdge = false;
|
||||||
},
|
},
|
||||||
workflowExposedFieldAdded: (
|
workflowExposedFieldAdded: (
|
||||||
state,
|
state,
|
||||||
@ -835,10 +915,15 @@ const nodesSlice = createSlice({
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
addNodePopoverOpened: (state) => {
|
addNodePopoverOpened: (state) => {
|
||||||
|
state.addNewNodePosition = null; //Create the node in viewport center by default
|
||||||
state.isAddNodePopoverOpen = true;
|
state.isAddNodePopoverOpen = true;
|
||||||
},
|
},
|
||||||
addNodePopoverClosed: (state) => {
|
addNodePopoverClosed: (state) => {
|
||||||
state.isAddNodePopoverOpen = false;
|
state.isAddNodePopoverOpen = false;
|
||||||
|
|
||||||
|
//Make sure these get reset if we close the popover and haven't selected a node
|
||||||
|
state.connectionStartParams = null;
|
||||||
|
state.currentConnectionFieldType = null;
|
||||||
},
|
},
|
||||||
addNodePopoverToggled: (state) => {
|
addNodePopoverToggled: (state) => {
|
||||||
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
|
state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen;
|
||||||
@ -913,6 +998,7 @@ export const {
|
|||||||
connectionMade,
|
connectionMade,
|
||||||
connectionStarted,
|
connectionStarted,
|
||||||
edgeDeleted,
|
edgeDeleted,
|
||||||
|
edgeChangeStarted,
|
||||||
edgesChanged,
|
edgesChanged,
|
||||||
edgesDeleted,
|
edgesDeleted,
|
||||||
edgeUpdated,
|
edgeUpdated,
|
||||||
|
@ -4,6 +4,7 @@ import {
|
|||||||
OnConnectStartParams,
|
OnConnectStartParams,
|
||||||
SelectionMode,
|
SelectionMode,
|
||||||
Viewport,
|
Viewport,
|
||||||
|
XYPosition,
|
||||||
} from 'reactflow';
|
} from 'reactflow';
|
||||||
import {
|
import {
|
||||||
FieldIdentifier,
|
FieldIdentifier,
|
||||||
@ -21,6 +22,8 @@ export type NodesState = {
|
|||||||
nodeTemplates: Record<string, InvocationTemplate>;
|
nodeTemplates: Record<string, InvocationTemplate>;
|
||||||
connectionStartParams: OnConnectStartParams | null;
|
connectionStartParams: OnConnectStartParams | null;
|
||||||
currentConnectionFieldType: FieldType | null;
|
currentConnectionFieldType: FieldType | null;
|
||||||
|
connectionMade: boolean;
|
||||||
|
modifyingEdge: boolean;
|
||||||
shouldShowFieldTypeLegend: boolean;
|
shouldShowFieldTypeLegend: boolean;
|
||||||
shouldShowMinimapPanel: boolean;
|
shouldShowMinimapPanel: boolean;
|
||||||
shouldValidateGraph: boolean;
|
shouldValidateGraph: boolean;
|
||||||
@ -39,5 +42,6 @@ export type NodesState = {
|
|||||||
nodesToCopy: Node<NodeData>[];
|
nodesToCopy: Node<NodeData>[];
|
||||||
edgesToCopy: Edge<InvocationEdgeExtra>[];
|
edgesToCopy: Edge<InvocationEdgeExtra>[];
|
||||||
isAddNodePopoverOpen: boolean;
|
isAddNodePopoverOpen: boolean;
|
||||||
|
addNewNodePosition: XYPosition | null;
|
||||||
selectionMode: SelectionMode;
|
selectionMode: SelectionMode;
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,126 @@
|
|||||||
|
import { Connection, HandleType } from 'reactflow';
|
||||||
|
import { Node, Edge } from 'reactflow';
|
||||||
|
import {
|
||||||
|
FieldType,
|
||||||
|
InputFieldValue,
|
||||||
|
OutputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
|
||||||
|
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||||
|
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||||
|
|
||||||
|
const isValidConnection = (
|
||||||
|
edges: Edge[],
|
||||||
|
handleCurrentType: HandleType,
|
||||||
|
handleCurrentFieldType: FieldType,
|
||||||
|
node: Node,
|
||||||
|
handle: InputFieldValue | OutputFieldValue
|
||||||
|
) => {
|
||||||
|
let isValidConnection = true;
|
||||||
|
if (handleCurrentType === 'source') {
|
||||||
|
if (
|
||||||
|
edges.find((edge) => {
|
||||||
|
return edge.target === node.id && edge.targetHandle === handle.name;
|
||||||
|
})
|
||||||
|
) {
|
||||||
|
isValidConnection = false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (
|
||||||
|
edges.find((edge) => {
|
||||||
|
return edge.source === node.id && edge.sourceHandle === handle.name;
|
||||||
|
})
|
||||||
|
) {
|
||||||
|
isValidConnection = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
|
||||||
|
isValidConnection = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return isValidConnection;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const findConnectionToValidHandle = (
|
||||||
|
node: Node,
|
||||||
|
nodes: Node[],
|
||||||
|
edges: Edge[],
|
||||||
|
handleCurrentNodeId: string,
|
||||||
|
handleCurrentName: string,
|
||||||
|
handleCurrentType: HandleType,
|
||||||
|
handleCurrentFieldType: FieldType
|
||||||
|
): Connection | null => {
|
||||||
|
if (node.id === handleCurrentNodeId) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const handles =
|
||||||
|
handleCurrentType == 'source' ? node.data.inputs : node.data.outputs;
|
||||||
|
|
||||||
|
//Prioritize handles whos name matches the node we're coming from
|
||||||
|
if (handles[handleCurrentName]) {
|
||||||
|
const handle = handles[handleCurrentName];
|
||||||
|
|
||||||
|
const sourceID =
|
||||||
|
handleCurrentType == 'source' ? handleCurrentNodeId : node.id;
|
||||||
|
const targetID =
|
||||||
|
handleCurrentType == 'source' ? node.id : handleCurrentNodeId;
|
||||||
|
const sourceHandle =
|
||||||
|
handleCurrentType == 'source' ? handleCurrentName : handle.name;
|
||||||
|
const targetHandle =
|
||||||
|
handleCurrentType == 'source' ? handle.name : handleCurrentName;
|
||||||
|
|
||||||
|
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
||||||
|
|
||||||
|
const valid = isValidConnection(
|
||||||
|
edges,
|
||||||
|
handleCurrentType,
|
||||||
|
handleCurrentFieldType,
|
||||||
|
node,
|
||||||
|
handle
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isGraphAcyclic && valid) {
|
||||||
|
return {
|
||||||
|
source: sourceID,
|
||||||
|
sourceHandle: sourceHandle,
|
||||||
|
target: targetID,
|
||||||
|
targetHandle: targetHandle,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const handleName in handles) {
|
||||||
|
const handle = handles[handleName];
|
||||||
|
|
||||||
|
const sourceID =
|
||||||
|
handleCurrentType == 'source' ? handleCurrentNodeId : node.id;
|
||||||
|
const targetID =
|
||||||
|
handleCurrentType == 'source' ? node.id : handleCurrentNodeId;
|
||||||
|
const sourceHandle =
|
||||||
|
handleCurrentType == 'source' ? handleCurrentName : handle.name;
|
||||||
|
const targetHandle =
|
||||||
|
handleCurrentType == 'source' ? handle.name : handleCurrentName;
|
||||||
|
|
||||||
|
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
||||||
|
|
||||||
|
const valid = isValidConnection(
|
||||||
|
edges,
|
||||||
|
handleCurrentType,
|
||||||
|
handleCurrentFieldType,
|
||||||
|
node,
|
||||||
|
handle
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isGraphAcyclic && valid) {
|
||||||
|
return {
|
||||||
|
source: sourceID,
|
||||||
|
sourceHandle: sourceHandle,
|
||||||
|
target: targetID,
|
||||||
|
targetHandle: targetHandle,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
};
|
@ -0,0 +1,26 @@
|
|||||||
|
import graphlib from '@dagrejs/graphlib';
|
||||||
|
import { Edge, Node } from 'reactflow';
|
||||||
|
|
||||||
|
export const getIsGraphAcyclic = (
|
||||||
|
source: string,
|
||||||
|
target: string,
|
||||||
|
nodes: Node[],
|
||||||
|
edges: Edge[]
|
||||||
|
) => {
|
||||||
|
// construct graphlib graph from editor state
|
||||||
|
const g = new graphlib.Graph();
|
||||||
|
|
||||||
|
nodes.forEach((n) => {
|
||||||
|
g.setNode(n.id);
|
||||||
|
});
|
||||||
|
|
||||||
|
edges.forEach((e) => {
|
||||||
|
g.setEdge(e.source, e.target);
|
||||||
|
});
|
||||||
|
|
||||||
|
// add the candidate edge
|
||||||
|
g.setEdge(source, target);
|
||||||
|
|
||||||
|
// check if the graph is acyclic
|
||||||
|
return graphlib.alg.isAcyclic(g);
|
||||||
|
};
|
@ -1,6 +1,6 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
|
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||||
import { FieldType } from 'features/nodes/types/types';
|
import { FieldType } from 'features/nodes/types/types';
|
||||||
import i18n from 'i18next';
|
import i18n from 'i18next';
|
||||||
import { HandleType } from 'reactflow';
|
import { HandleType } from 'reactflow';
|
||||||
|
@ -10,6 +10,13 @@ export const validateSourceAndTargetTypes = (
|
|||||||
sourceType: FieldType,
|
sourceType: FieldType,
|
||||||
targetType: FieldType
|
targetType: FieldType
|
||||||
) => {
|
) => {
|
||||||
|
// TODO: There's a bug with Collect -> Iterate nodes:
|
||||||
|
// https://github.com/invoke-ai/InvokeAI/issues/3956
|
||||||
|
// Once this is resolved, we can remove this check.
|
||||||
|
if (sourceType === 'Collection' && targetType === 'Collection') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (sourceType === targetType) {
|
if (sourceType === targetType) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user