feat(ui): recreate edge autoconnect logic

This commit is contained in:
psychedelicious 2024-05-16 18:56:11 +10:00
parent 708c68413d
commit 2c1fa30639
5 changed files with 215 additions and 40 deletions

View File

@ -1,14 +1,14 @@
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste'; import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'; import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import { import {
$cursorPos, $cursorPos,
connectionEnded, $pendingConnection,
connectionMade, connectionMade,
connectionStarted,
edgeAdded, edgeAdded,
edgeChangeStarted, edgeChangeStarted,
edgeDeleted, edgeDeleted,
@ -28,9 +28,6 @@ import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import type { import type {
OnConnect,
OnConnectEnd,
OnConnectStart,
OnEdgesChange, OnEdgesChange,
OnEdgesDelete, OnEdgesDelete,
OnEdgeUpdateFunc, OnEdgeUpdateFunc,
@ -76,6 +73,7 @@ export const Flow = memo(() => {
const viewport = useAppSelector((s) => s.nodes.present.viewport); const viewport = useAppSelector((s) => s.nodes.present.viewport);
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid); const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode); const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode);
const { onConnectStart, onConnect, onConnectEnd } = useConnection();
const flowWrapper = useRef<HTMLDivElement>(null); const flowWrapper = useRef<HTMLDivElement>(null);
const isValidConnection = useIsValidConnection(); const isValidConnection = useIsValidConnection();
useWorkflowWatcher(); useWorkflowWatcher();
@ -102,32 +100,35 @@ export const Flow = memo(() => {
[dispatch] [dispatch]
); );
const onConnectStart: OnConnectStart = useCallback( // const onConnectStart: OnConnectStart = useCallback(
(event, params) => { // (event, params) => {
dispatch(connectionStarted(params)); // dispatch(connectionStarted(params));
}, // },
[dispatch] // [dispatch]
); // );
const onConnect: OnConnect = useCallback( // const onConnect: OnConnect = useCallback(
(connection) => { // (connection) => {
dispatch(connectionMade(connection)); // dispatch(connectionMade(connection));
}, // },
[dispatch] // [dispatch]
); // );
const onConnectEnd: OnConnectEnd = useCallback(() => { // const onConnectEnd: OnConnectEnd = useCallback(() => {
const cursorPosition = $cursorPos.get(); // const cursorPosition = $cursorPos.get();
if (!cursorPosition) { // if (!cursorPosition) {
return; // return;
} // }
dispatch( // dispatch(
connectionEnded({ // connectionEnded({
cursorPosition, // cursorPosition,
mouseOverNodeId: $mouseOverNode.get(), // mouseOverNodeId: $mouseOverNode.get(),
}) // })
); // );
}, [dispatch]); // }, [dispatch]);
const pendingConnection = useStore($pendingConnection);
console.log(pendingConnection)
const onEdgesDelete: OnEdgesDelete = useCallback( const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => { (edges) => {

View File

@ -0,0 +1,68 @@
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useCallback, useMemo } from 'react';
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { assert } from 'tsafe';
export const useConnection = () => {
const store = useAppStore();
const templates = useStore($templates);
const onConnectStart = useCallback<OnConnectStart>(
(event, params) => {
const nodes = store.getState().nodes.present.nodes;
const { nodeId, handleId, handleType } = params;
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
const node = nodes.find((n) => n.id === nodeId);
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
const template = templates[node.data.type];
assert(template, `Template not found for node type: ${node.data.type}`);
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
$pendingConnection.set({
node,
template,
fieldTemplate,
});
},
[store, templates]
);
const onConnect = useCallback<OnConnect>(
(connection) => {
const { dispatch } = store;
dispatch(connectionMade(connection));
$pendingConnection.set(null);
},
[store]
);
const onConnectEnd = useCallback<OnConnectEnd>(() => {
const { dispatch } = store;
const pendingConnection = $pendingConnection.get();
if (!pendingConnection) {
return;
}
const mouseOverNodeId = $mouseOverNode.get();
const { nodes, edges } = store.getState().nodes.present;
if (mouseOverNodeId) {
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
if (!candidateNode) {
// The mouse is over a non-invocation node - bail
return;
}
const candidateTemplate = templates[candidateNode.data.type];
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
const connection = getFirstValidConnection(nodes, edges, pendingConnection, candidateNode, candidateTemplate);
if (connection) {
dispatch(connectionMade(connection));
}
}
$pendingConnection.set(null);
}, [store, templates]);
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
return api;
};

View File

@ -69,7 +69,7 @@ import {
} from 'services/events/actions'; } from 'services/events/actions';
import type { z } from 'zod'; import type { z } from 'zod';
import type { NodesState, Templates } from './types'; import type { NodesState, PendingConnection, Templates } from './types';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle'; import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
@ -215,13 +215,7 @@ export const nodesSlice = createSlice({
state.connectionStartFieldType = field?.type ?? null; state.connectionStartFieldType = field?.type ?? null;
}, },
connectionMade: (state, action: PayloadAction<Connection>) => { connectionMade: (state, action: PayloadAction<Connection>) => {
const fieldType = state.connectionStartFieldType;
if (!fieldType) {
return;
}
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges); state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
state.connectionMade = true;
}, },
connectionEnded: ( connectionEnded: (
state, state,
@ -764,6 +758,8 @@ export const $cursorPos = atom<XYPosition | null>(null);
export const $templates = atom<Templates>({}); export const $templates = atom<Templates>({});
export const $copiedNodes = atom<AnyNode[]>([]); export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]); export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
export const $pendingConnection = atom<PendingConnection | null>(null);
export const $isModifyingEdge = atom(false);
export const selectNodesSlice = (state: RootState) => state.nodes.present; export const selectNodesSlice = (state: RootState) => state.nodes.present;

View File

@ -1,6 +1,13 @@
import type { FieldIdentifier, FieldType, StatefulFieldValue } from 'features/nodes/types/field'; import type {
FieldIdentifier,
FieldInputTemplate,
FieldOutputTemplate,
FieldType,
StatefulFieldValue,
} from 'features/nodes/types/field';
import type { import type {
AnyNode, AnyNode,
InvocationNode,
InvocationNodeEdge, InvocationNodeEdge,
InvocationTemplate, InvocationTemplate,
NodeExecutionState, NodeExecutionState,
@ -10,6 +17,12 @@ import type { OnConnectStartParams, Viewport, XYPosition } from 'reactflow';
export type Templates = Record<string, InvocationTemplate>; export type Templates = Record<string, InvocationTemplate>;
export type PendingConnection = {
node: InvocationNode;
template: InvocationTemplate;
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
};
export type NodesState = { export type NodesState = {
_version: 1; _version: 1;
nodes: AnyNode[]; nodes: AnyNode[];

View File

@ -1,8 +1,10 @@
import type { Templates } from 'features/nodes/store/types'; import type { PendingConnection, Templates } from 'features/nodes/store/types';
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field'; import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { differenceWith, map } from 'lodash-es';
import type { Connection, Edge, HandleType, Node } from 'reactflow'; import type { Connection, Edge, HandleType, Node } from 'reactflow';
import { assert } from 'tsafe';
import { getIsGraphAcyclic } from './getIsGraphAcyclic'; import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
@ -111,3 +113,98 @@ export const findConnectionToValidHandle = (
} }
return null; return null;
}; };
export const getFirstValidConnection = (
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate
): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
return null;
}
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs);
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
);
const candidateField = candidateUnconnectedFields.find((field) =>
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target
if (
edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
)
) {
return null;
}
if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - connect to the `collection` field
return {
source: candidateNode.id,
sourceHandle: 'collection',
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
// Sources/outputs can have any number of edges, we can take the first matching output field
const candidateFields = map(candidateTemplate.outputs);
const candidateField = candidateFields.find((field) =>
validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type)
);
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
}
return null;
};