mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): recreate edge autoconnect logic
This commit is contained in:
parent
708c68413d
commit
2c1fa30639
@ -1,14 +1,14 @@
|
|||||||
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||||
|
import { useStore } from '@nanostores/react';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
|
||||||
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||||
import {
|
import {
|
||||||
$cursorPos,
|
$cursorPos,
|
||||||
connectionEnded,
|
$pendingConnection,
|
||||||
connectionMade,
|
connectionMade,
|
||||||
connectionStarted,
|
|
||||||
edgeAdded,
|
edgeAdded,
|
||||||
edgeChangeStarted,
|
edgeChangeStarted,
|
||||||
edgeDeleted,
|
edgeDeleted,
|
||||||
@ -28,9 +28,6 @@ import type { CSSProperties, MouseEvent } from 'react';
|
|||||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import type {
|
import type {
|
||||||
OnConnect,
|
|
||||||
OnConnectEnd,
|
|
||||||
OnConnectStart,
|
|
||||||
OnEdgesChange,
|
OnEdgesChange,
|
||||||
OnEdgesDelete,
|
OnEdgesDelete,
|
||||||
OnEdgeUpdateFunc,
|
OnEdgeUpdateFunc,
|
||||||
@ -76,6 +73,7 @@ export const Flow = memo(() => {
|
|||||||
const viewport = useAppSelector((s) => s.nodes.present.viewport);
|
const viewport = useAppSelector((s) => s.nodes.present.viewport);
|
||||||
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
|
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
|
||||||
const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode);
|
const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode);
|
||||||
|
const { onConnectStart, onConnect, onConnectEnd } = useConnection();
|
||||||
const flowWrapper = useRef<HTMLDivElement>(null);
|
const flowWrapper = useRef<HTMLDivElement>(null);
|
||||||
const isValidConnection = useIsValidConnection();
|
const isValidConnection = useIsValidConnection();
|
||||||
useWorkflowWatcher();
|
useWorkflowWatcher();
|
||||||
@ -102,32 +100,35 @@ export const Flow = memo(() => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const onConnectStart: OnConnectStart = useCallback(
|
// const onConnectStart: OnConnectStart = useCallback(
|
||||||
(event, params) => {
|
// (event, params) => {
|
||||||
dispatch(connectionStarted(params));
|
// dispatch(connectionStarted(params));
|
||||||
},
|
// },
|
||||||
[dispatch]
|
// [dispatch]
|
||||||
);
|
// );
|
||||||
|
|
||||||
const onConnect: OnConnect = useCallback(
|
// const onConnect: OnConnect = useCallback(
|
||||||
(connection) => {
|
// (connection) => {
|
||||||
dispatch(connectionMade(connection));
|
// dispatch(connectionMade(connection));
|
||||||
},
|
// },
|
||||||
[dispatch]
|
// [dispatch]
|
||||||
);
|
// );
|
||||||
|
|
||||||
const onConnectEnd: OnConnectEnd = useCallback(() => {
|
// const onConnectEnd: OnConnectEnd = useCallback(() => {
|
||||||
const cursorPosition = $cursorPos.get();
|
// const cursorPosition = $cursorPos.get();
|
||||||
if (!cursorPosition) {
|
// if (!cursorPosition) {
|
||||||
return;
|
// return;
|
||||||
}
|
// }
|
||||||
dispatch(
|
// dispatch(
|
||||||
connectionEnded({
|
// connectionEnded({
|
||||||
cursorPosition,
|
// cursorPosition,
|
||||||
mouseOverNodeId: $mouseOverNode.get(),
|
// mouseOverNodeId: $mouseOverNode.get(),
|
||||||
})
|
// })
|
||||||
);
|
// );
|
||||||
}, [dispatch]);
|
// }, [dispatch]);
|
||||||
|
|
||||||
|
const pendingConnection = useStore($pendingConnection);
|
||||||
|
console.log(pendingConnection)
|
||||||
|
|
||||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||||
(edges) => {
|
(edges) => {
|
||||||
|
@ -0,0 +1,68 @@
|
|||||||
|
import { useStore } from '@nanostores/react';
|
||||||
|
import { useAppStore } from 'app/store/storeHooks';
|
||||||
|
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||||
|
import { $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||||
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
export const useConnection = () => {
|
||||||
|
const store = useAppStore();
|
||||||
|
const templates = useStore($templates);
|
||||||
|
|
||||||
|
const onConnectStart = useCallback<OnConnectStart>(
|
||||||
|
(event, params) => {
|
||||||
|
const nodes = store.getState().nodes.present.nodes;
|
||||||
|
const { nodeId, handleId, handleType } = params;
|
||||||
|
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
|
||||||
|
const node = nodes.find((n) => n.id === nodeId);
|
||||||
|
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
|
||||||
|
const template = templates[node.data.type];
|
||||||
|
assert(template, `Template not found for node type: ${node.data.type}`);
|
||||||
|
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
|
||||||
|
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
|
||||||
|
$pendingConnection.set({
|
||||||
|
node,
|
||||||
|
template,
|
||||||
|
fieldTemplate,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[store, templates]
|
||||||
|
);
|
||||||
|
const onConnect = useCallback<OnConnect>(
|
||||||
|
(connection) => {
|
||||||
|
const { dispatch } = store;
|
||||||
|
dispatch(connectionMade(connection));
|
||||||
|
$pendingConnection.set(null);
|
||||||
|
},
|
||||||
|
[store]
|
||||||
|
);
|
||||||
|
const onConnectEnd = useCallback<OnConnectEnd>(() => {
|
||||||
|
const { dispatch } = store;
|
||||||
|
const pendingConnection = $pendingConnection.get();
|
||||||
|
if (!pendingConnection) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const mouseOverNodeId = $mouseOverNode.get();
|
||||||
|
const { nodes, edges } = store.getState().nodes.present;
|
||||||
|
if (mouseOverNodeId) {
|
||||||
|
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
|
||||||
|
if (!candidateNode) {
|
||||||
|
// The mouse is over a non-invocation node - bail
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const candidateTemplate = templates[candidateNode.data.type];
|
||||||
|
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||||
|
const connection = getFirstValidConnection(nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
||||||
|
if (connection) {
|
||||||
|
dispatch(connectionMade(connection));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
$pendingConnection.set(null);
|
||||||
|
}, [store, templates]);
|
||||||
|
|
||||||
|
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
|
||||||
|
return api;
|
||||||
|
};
|
@ -69,7 +69,7 @@ import {
|
|||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
import type { z } from 'zod';
|
import type { z } from 'zod';
|
||||||
|
|
||||||
import type { NodesState, Templates } from './types';
|
import type { NodesState, PendingConnection, Templates } from './types';
|
||||||
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
|
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
|
||||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||||
|
|
||||||
@ -215,13 +215,7 @@ export const nodesSlice = createSlice({
|
|||||||
state.connectionStartFieldType = field?.type ?? null;
|
state.connectionStartFieldType = field?.type ?? null;
|
||||||
},
|
},
|
||||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||||
const fieldType = state.connectionStartFieldType;
|
|
||||||
if (!fieldType) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
|
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
|
||||||
|
|
||||||
state.connectionMade = true;
|
|
||||||
},
|
},
|
||||||
connectionEnded: (
|
connectionEnded: (
|
||||||
state,
|
state,
|
||||||
@ -764,6 +758,8 @@ export const $cursorPos = atom<XYPosition | null>(null);
|
|||||||
export const $templates = atom<Templates>({});
|
export const $templates = atom<Templates>({});
|
||||||
export const $copiedNodes = atom<AnyNode[]>([]);
|
export const $copiedNodes = atom<AnyNode[]>([]);
|
||||||
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
||||||
|
export const $pendingConnection = atom<PendingConnection | null>(null);
|
||||||
|
export const $isModifyingEdge = atom(false);
|
||||||
|
|
||||||
export const selectNodesSlice = (state: RootState) => state.nodes.present;
|
export const selectNodesSlice = (state: RootState) => state.nodes.present;
|
||||||
|
|
||||||
|
@ -1,6 +1,13 @@
|
|||||||
import type { FieldIdentifier, FieldType, StatefulFieldValue } from 'features/nodes/types/field';
|
import type {
|
||||||
|
FieldIdentifier,
|
||||||
|
FieldInputTemplate,
|
||||||
|
FieldOutputTemplate,
|
||||||
|
FieldType,
|
||||||
|
StatefulFieldValue,
|
||||||
|
} from 'features/nodes/types/field';
|
||||||
import type {
|
import type {
|
||||||
AnyNode,
|
AnyNode,
|
||||||
|
InvocationNode,
|
||||||
InvocationNodeEdge,
|
InvocationNodeEdge,
|
||||||
InvocationTemplate,
|
InvocationTemplate,
|
||||||
NodeExecutionState,
|
NodeExecutionState,
|
||||||
@ -10,6 +17,12 @@ import type { OnConnectStartParams, Viewport, XYPosition } from 'reactflow';
|
|||||||
|
|
||||||
export type Templates = Record<string, InvocationTemplate>;
|
export type Templates = Record<string, InvocationTemplate>;
|
||||||
|
|
||||||
|
export type PendingConnection = {
|
||||||
|
node: InvocationNode;
|
||||||
|
template: InvocationTemplate;
|
||||||
|
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
||||||
|
};
|
||||||
|
|
||||||
export type NodesState = {
|
export type NodesState = {
|
||||||
_version: 1;
|
_version: 1;
|
||||||
nodes: AnyNode[];
|
nodes: AnyNode[];
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import type { Templates } from 'features/nodes/store/types';
|
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||||
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
||||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
|
import { differenceWith, map } from 'lodash-es';
|
||||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||||
@ -111,3 +113,98 @@ export const findConnectionToValidHandle = (
|
|||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const getFirstValidConnection = (
|
||||||
|
nodes: AnyNode[],
|
||||||
|
edges: InvocationNodeEdge[],
|
||||||
|
pendingConnection: PendingConnection,
|
||||||
|
candidateNode: InvocationNode,
|
||||||
|
candidateTemplate: InvocationTemplate
|
||||||
|
): Connection | null => {
|
||||||
|
if (pendingConnection.node.id === candidateNode.id) {
|
||||||
|
// Cannot connect to self
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||||
|
|
||||||
|
if (pendingFieldKind === 'source') {
|
||||||
|
// Connecting from a source to a target
|
||||||
|
if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (candidateNode.data.type === 'collect') {
|
||||||
|
// Special handling for collect node - the `item` field takes any number of connections
|
||||||
|
return {
|
||||||
|
source: pendingConnection.node.id,
|
||||||
|
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||||
|
target: candidateNode.id,
|
||||||
|
targetHandle: 'item',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
// Only one connection per target field is allowed - look for an unconnected target field
|
||||||
|
const candidateFields = map(candidateTemplate.inputs);
|
||||||
|
const candidateConnectedFields = edges
|
||||||
|
.filter((edge) => edge.target === candidateNode.id)
|
||||||
|
.map((edge) => {
|
||||||
|
// Edges must always have a targetHandle, safe to assert here
|
||||||
|
assert(edge.targetHandle);
|
||||||
|
return edge.targetHandle;
|
||||||
|
});
|
||||||
|
const candidateUnconnectedFields = differenceWith(
|
||||||
|
candidateFields,
|
||||||
|
candidateConnectedFields,
|
||||||
|
(field, connectedFieldName) => field.name === connectedFieldName
|
||||||
|
);
|
||||||
|
const candidateField = candidateUnconnectedFields.find((field) =>
|
||||||
|
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
|
||||||
|
);
|
||||||
|
if (candidateField) {
|
||||||
|
return {
|
||||||
|
source: pendingConnection.node.id,
|
||||||
|
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||||
|
target: candidateNode.id,
|
||||||
|
targetHandle: candidateField.name,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Connecting from a target to a source
|
||||||
|
// Ensure we there is not already an edge to the target
|
||||||
|
if (
|
||||||
|
edges.some(
|
||||||
|
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (candidateNode.data.type === 'collect') {
|
||||||
|
// Special handling for collect node - connect to the `collection` field
|
||||||
|
return {
|
||||||
|
source: candidateNode.id,
|
||||||
|
sourceHandle: 'collection',
|
||||||
|
target: pendingConnection.node.id,
|
||||||
|
targetHandle: pendingConnection.fieldTemplate.name,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
// Sources/outputs can have any number of edges, we can take the first matching output field
|
||||||
|
const candidateFields = map(candidateTemplate.outputs);
|
||||||
|
const candidateField = candidateFields.find((field) =>
|
||||||
|
validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type)
|
||||||
|
);
|
||||||
|
if (candidateField) {
|
||||||
|
return {
|
||||||
|
source: candidateNode.id,
|
||||||
|
sourceHandle: candidateField.name,
|
||||||
|
target: pendingConnection.node.id,
|
||||||
|
targetHandle: pendingConnection.fieldTemplate.name,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user