mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): recreate edge autoconnect logic
This commit is contained in:
parent
708c68413d
commit
2c1fa30639
@ -1,14 +1,14 @@
|
||||
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||
import {
|
||||
$cursorPos,
|
||||
connectionEnded,
|
||||
$pendingConnection,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgeAdded,
|
||||
edgeChangeStarted,
|
||||
edgeDeleted,
|
||||
@ -28,9 +28,6 @@ import type { CSSProperties, MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import type {
|
||||
OnConnect,
|
||||
OnConnectEnd,
|
||||
OnConnectStart,
|
||||
OnEdgesChange,
|
||||
OnEdgesDelete,
|
||||
OnEdgeUpdateFunc,
|
||||
@ -76,6 +73,7 @@ export const Flow = memo(() => {
|
||||
const viewport = useAppSelector((s) => s.nodes.present.viewport);
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
|
||||
const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode);
|
||||
const { onConnectStart, onConnect, onConnectEnd } = useConnection();
|
||||
const flowWrapper = useRef<HTMLDivElement>(null);
|
||||
const isValidConnection = useIsValidConnection();
|
||||
useWorkflowWatcher();
|
||||
@ -102,32 +100,35 @@ export const Flow = memo(() => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnectStart: OnConnectStart = useCallback(
|
||||
(event, params) => {
|
||||
dispatch(connectionStarted(params));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
// const onConnectStart: OnConnectStart = useCallback(
|
||||
// (event, params) => {
|
||||
// dispatch(connectionStarted(params));
|
||||
// },
|
||||
// [dispatch]
|
||||
// );
|
||||
|
||||
const onConnect: OnConnect = useCallback(
|
||||
(connection) => {
|
||||
dispatch(connectionMade(connection));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
// const onConnect: OnConnect = useCallback(
|
||||
// (connection) => {
|
||||
// dispatch(connectionMade(connection));
|
||||
// },
|
||||
// [dispatch]
|
||||
// );
|
||||
|
||||
const onConnectEnd: OnConnectEnd = useCallback(() => {
|
||||
const cursorPosition = $cursorPos.get();
|
||||
if (!cursorPosition) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
connectionEnded({
|
||||
cursorPosition,
|
||||
mouseOverNodeId: $mouseOverNode.get(),
|
||||
})
|
||||
);
|
||||
}, [dispatch]);
|
||||
// const onConnectEnd: OnConnectEnd = useCallback(() => {
|
||||
// const cursorPosition = $cursorPos.get();
|
||||
// if (!cursorPosition) {
|
||||
// return;
|
||||
// }
|
||||
// dispatch(
|
||||
// connectionEnded({
|
||||
// cursorPosition,
|
||||
// mouseOverNodeId: $mouseOverNode.get(),
|
||||
// })
|
||||
// );
|
||||
// }, [dispatch]);
|
||||
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
console.log(pendingConnection)
|
||||
|
||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||
(edges) => {
|
||||
|
@ -0,0 +1,68 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useConnection = () => {
|
||||
const store = useAppStore();
|
||||
const templates = useStore($templates);
|
||||
|
||||
const onConnectStart = useCallback<OnConnectStart>(
|
||||
(event, params) => {
|
||||
const nodes = store.getState().nodes.present.nodes;
|
||||
const { nodeId, handleId, handleType } = params;
|
||||
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
|
||||
const node = nodes.find((n) => n.id === nodeId);
|
||||
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
|
||||
const template = templates[node.data.type];
|
||||
assert(template, `Template not found for node type: ${node.data.type}`);
|
||||
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
|
||||
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
|
||||
$pendingConnection.set({
|
||||
node,
|
||||
template,
|
||||
fieldTemplate,
|
||||
});
|
||||
},
|
||||
[store, templates]
|
||||
);
|
||||
const onConnect = useCallback<OnConnect>(
|
||||
(connection) => {
|
||||
const { dispatch } = store;
|
||||
dispatch(connectionMade(connection));
|
||||
$pendingConnection.set(null);
|
||||
},
|
||||
[store]
|
||||
);
|
||||
const onConnectEnd = useCallback<OnConnectEnd>(() => {
|
||||
const { dispatch } = store;
|
||||
const pendingConnection = $pendingConnection.get();
|
||||
if (!pendingConnection) {
|
||||
return;
|
||||
}
|
||||
const mouseOverNodeId = $mouseOverNode.get();
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
if (mouseOverNodeId) {
|
||||
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
|
||||
if (!candidateNode) {
|
||||
// The mouse is over a non-invocation node - bail
|
||||
return;
|
||||
}
|
||||
const candidateTemplate = templates[candidateNode.data.type];
|
||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||
const connection = getFirstValidConnection(nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
}
|
||||
$pendingConnection.set(null);
|
||||
}, [store, templates]);
|
||||
|
||||
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
|
||||
return api;
|
||||
};
|
@ -69,7 +69,7 @@ import {
|
||||
} from 'services/events/actions';
|
||||
import type { z } from 'zod';
|
||||
|
||||
import type { NodesState, Templates } from './types';
|
||||
import type { NodesState, PendingConnection, Templates } from './types';
|
||||
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
|
||||
@ -215,13 +215,7 @@ export const nodesSlice = createSlice({
|
||||
state.connectionStartFieldType = field?.type ?? null;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
const fieldType = state.connectionStartFieldType;
|
||||
if (!fieldType) {
|
||||
return;
|
||||
}
|
||||
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
|
||||
|
||||
state.connectionMade = true;
|
||||
},
|
||||
connectionEnded: (
|
||||
state,
|
||||
@ -764,6 +758,8 @@ export const $cursorPos = atom<XYPosition | null>(null);
|
||||
export const $templates = atom<Templates>({});
|
||||
export const $copiedNodes = atom<AnyNode[]>([]);
|
||||
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
||||
export const $pendingConnection = atom<PendingConnection | null>(null);
|
||||
export const $isModifyingEdge = atom(false);
|
||||
|
||||
export const selectNodesSlice = (state: RootState) => state.nodes.present;
|
||||
|
||||
|
@ -1,6 +1,13 @@
|
||||
import type { FieldIdentifier, FieldType, StatefulFieldValue } from 'features/nodes/types/field';
|
||||
import type {
|
||||
FieldIdentifier,
|
||||
FieldInputTemplate,
|
||||
FieldOutputTemplate,
|
||||
FieldType,
|
||||
StatefulFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type {
|
||||
AnyNode,
|
||||
InvocationNode,
|
||||
InvocationNodeEdge,
|
||||
InvocationTemplate,
|
||||
NodeExecutionState,
|
||||
@ -10,6 +17,12 @@ import type { OnConnectStartParams, Viewport, XYPosition } from 'reactflow';
|
||||
|
||||
export type Templates = Record<string, InvocationTemplate>;
|
||||
|
||||
export type PendingConnection = {
|
||||
node: InvocationNode;
|
||||
template: InvocationTemplate;
|
||||
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
||||
};
|
||||
|
||||
export type NodesState = {
|
||||
_version: 1;
|
||||
nodes: AnyNode[];
|
||||
|
@ -1,8 +1,10 @@
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { differenceWith, map } from 'lodash-es';
|
||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
@ -111,3 +113,98 @@ export const findConnectionToValidHandle = (
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
export const getFirstValidConnection = (
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
pendingConnection: PendingConnection,
|
||||
candidateNode: InvocationNode,
|
||||
candidateTemplate: InvocationTemplate
|
||||
): Connection | null => {
|
||||
if (pendingConnection.node.id === candidateNode.id) {
|
||||
// Cannot connect to self
|
||||
return null;
|
||||
}
|
||||
|
||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||
|
||||
if (pendingFieldKind === 'source') {
|
||||
// Connecting from a source to a target
|
||||
if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
|
||||
return null;
|
||||
}
|
||||
if (candidateNode.data.type === 'collect') {
|
||||
// Special handling for collect node - the `item` field takes any number of connections
|
||||
return {
|
||||
source: pendingConnection.node.id,
|
||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||
target: candidateNode.id,
|
||||
targetHandle: 'item',
|
||||
};
|
||||
}
|
||||
// Only one connection per target field is allowed - look for an unconnected target field
|
||||
const candidateFields = map(candidateTemplate.inputs);
|
||||
const candidateConnectedFields = edges
|
||||
.filter((edge) => edge.target === candidateNode.id)
|
||||
.map((edge) => {
|
||||
// Edges must always have a targetHandle, safe to assert here
|
||||
assert(edge.targetHandle);
|
||||
return edge.targetHandle;
|
||||
});
|
||||
const candidateUnconnectedFields = differenceWith(
|
||||
candidateFields,
|
||||
candidateConnectedFields,
|
||||
(field, connectedFieldName) => field.name === connectedFieldName
|
||||
);
|
||||
const candidateField = candidateUnconnectedFields.find((field) =>
|
||||
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
|
||||
);
|
||||
if (candidateField) {
|
||||
return {
|
||||
source: pendingConnection.node.id,
|
||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||
target: candidateNode.id,
|
||||
targetHandle: candidateField.name,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// Connecting from a target to a source
|
||||
// Ensure we there is not already an edge to the target
|
||||
if (
|
||||
edges.some(
|
||||
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
||||
)
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (candidateNode.data.type === 'collect') {
|
||||
// Special handling for collect node - connect to the `collection` field
|
||||
return {
|
||||
source: candidateNode.id,
|
||||
sourceHandle: 'collection',
|
||||
target: pendingConnection.node.id,
|
||||
targetHandle: pendingConnection.fieldTemplate.name,
|
||||
};
|
||||
}
|
||||
// Sources/outputs can have any number of edges, we can take the first matching output field
|
||||
const candidateFields = map(candidateTemplate.outputs);
|
||||
const candidateField = candidateFields.find((field) =>
|
||||
validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type)
|
||||
);
|
||||
if (candidateField) {
|
||||
return {
|
||||
source: candidateNode.id,
|
||||
sourceHandle: candidateField.name,
|
||||
target: pendingConnection.node.id,
|
||||
targetHandle: pendingConnection.fieldTemplate.name,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user