feat(ui): rework pendingConnection

This commit is contained in:
psychedelicious 2024-05-19 11:49:40 +10:00
parent 4bda174eb9
commit a80e3448f5
7 changed files with 73 additions and 69 deletions

View File

@ -9,6 +9,7 @@ import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import { import {
$cursorPos, $cursorPos,
$edgePendingUpdate,
$isAddNodePopoverOpen, $isAddNodePopoverOpen,
$pendingConnection, $pendingConnection,
$templates, $templates,
@ -28,7 +29,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters'; import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
import { assert } from 'tsafe';
const createRegex = memoize( const createRegex = memoize(
(inputValue: string) => (inputValue: string) =>
@ -68,16 +68,18 @@ const AddNodePopover = () => {
const filteredTemplates = useMemo(() => { const filteredTemplates = useMemo(() => {
// If we have a connection in progress, we need to filter the node choices // If we have a connection in progress, we need to filter the node choices
const templatesArray = map(templates);
if (!pendingConnection) { if (!pendingConnection) {
return map(templates); return templatesArray;
} }
return filter(templates, (template) => { return filter(templates, (template) => {
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind; const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs; return some(candidateFields, (field) => {
return some(fields, (field) => { const sourceType =
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type; pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type; const targetType =
pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
return validateConnectionTypes(sourceType, targetType); return validateConnectionTypes(sourceType, targetType);
}); });
}); });
@ -144,10 +146,25 @@ const AddNodePopover = () => {
// Auto-connect an edge if we just added a node and have a pending connection // Auto-connect an edge if we just added a node and have a pending connection
if (pendingConnection && isInvocationNode(node)) { if (pendingConnection && isInvocationNode(node)) {
const template = templates[node.data.type]; const edgePendingUpdate = $edgePendingUpdate.get();
assert(template, 'Template not found'); const { handleType } = pendingConnection;
const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
const { nodes, edges } = store.getState().nodes.present; const { nodes, edges } = store.getState().nodes.present;
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template); const connection = getFirstValidConnection(
source,
sourceHandle,
target,
targetHandle,
nodes,
edges,
templates,
edgePendingUpdate
);
if (connection) { if (connection) {
dispatch(connectionMade(connection)); dispatch(connectionMade(connection));
} }

View File

@ -9,10 +9,10 @@ import {
connectionMade, connectionMade,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { isString } from 'lodash-es'; import { isString } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { type OnConnect, type OnConnectEnd, type OnConnectStart, useUpdateNodeInternals } from 'reactflow'; import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { useUpdateNodeInternals } from 'reactflow';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const useConnection = () => { export const useConnection = () => {
@ -21,21 +21,27 @@ export const useConnection = () => {
const updateNodeInternals = useUpdateNodeInternals(); const updateNodeInternals = useUpdateNodeInternals();
const onConnectStart = useCallback<OnConnectStart>( const onConnectStart = useCallback<OnConnectStart>(
(event, params) => { (event, { nodeId, handleId, handleType }) => {
assert(nodeId && handleId && handleType, 'Invalid connection start event');
const nodes = store.getState().nodes.present.nodes; const nodes = store.getState().nodes.present.nodes;
const { nodeId, handleId, handleType } = params;
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
const node = nodes.find((n) => n.id === nodeId); const node = nodes.find((n) => n.id === nodeId);
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`); if (!node) {
return;
}
const template = templates[node.data.type]; const template = templates[node.data.type];
assert(template, `Template not found for node type: ${node.data.type}`); if (!template) {
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId]; return;
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`); }
$pendingConnection.set({
node, const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
template, const fieldTemplate = fieldTemplates[handleId];
fieldTemplate, if (!fieldTemplate) {
}); return;
}
$pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate });
}, },
[store, templates] [store, templates]
); );
@ -67,20 +73,20 @@ export const useConnection = () => {
} }
const { nodes, edges } = store.getState().nodes.present; const { nodes, edges } = store.getState().nodes.present;
if (mouseOverNodeId) { if (mouseOverNodeId) {
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId); const { handleType } = pendingConnection;
if (!candidateNode) { const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId;
// The mouse is over a non-invocation node - bail const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
return; const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId;
} const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
const candidateTemplate = templates[candidateNode.data.type];
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
const connection = getFirstValidConnection( const connection = getFirstValidConnection(
templates, source,
sourceHandle,
target,
targetHandle,
nodes, nodes,
edges, edges,
pendingConnection, templates,
candidateNode,
candidateTemplate,
edgePendingUpdate edgePendingUpdate
); );
if (connection) { if (connection) {

View File

@ -43,8 +43,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
return false; return false;
} }
return ( return (
pendingConnection.node.id === nodeId && pendingConnection.nodeId === nodeId &&
pendingConnection.fieldTemplate.name === fieldName && pendingConnection.handleId === fieldName &&
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
); );
}, [fieldName, kind, nodeId, pendingConnection]); }, [fieldName, kind, nodeId, pendingConnection]);

View File

@ -6,19 +6,20 @@ import type {
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import type { import type {
AnyNode, AnyNode,
InvocationNode,
InvocationNodeEdge, InvocationNodeEdge,
InvocationTemplate, InvocationTemplate,
NodeExecutionState, NodeExecutionState,
} from 'features/nodes/types/invocation'; } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { HandleType } from 'reactflow';
export type Templates = Record<string, InvocationTemplate>; export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>; export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
export type PendingConnection = { export type PendingConnection = {
node: InvocationNode; nodeId: string;
template: InvocationTemplate; handleId: string;
handleType: HandleType;
fieldTemplate: FieldInputTemplate | FieldOutputTemplate; fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
}; };

View File

@ -36,9 +36,7 @@ export const makeConnectionErrorSelector = (
return i18n.t('nodes.noConnectionInProgress'); return i18n.t('nodes.noConnectionInProgress');
} }
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; if (handleType === pendingConnection.handleType) {
if (handleType === connectionHandleType) {
if (handleType === 'source') { if (handleType === 'source') {
return i18n.t('nodes.cannotConnectOutputToOutput'); return i18n.t('nodes.cannotConnectOutputToOutput');
} }
@ -46,10 +44,10 @@ export const makeConnectionErrorSelector = (
} }
// we have to figure out which is the target and which is the source // we have to figure out which is the target and which is the source
const source = handleType === 'source' ? nodeId : pendingConnection.node.id; const source = handleType === 'source' ? nodeId : pendingConnection.nodeId;
const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.fieldTemplate.name; const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.handleId;
const target = handleType === 'target' ? nodeId : pendingConnection.node.id; const target = handleType === 'target' ? nodeId : pendingConnection.nodeId;
const targetHandle = handleType === 'target' ? fieldName : pendingConnection.fieldTemplate.name; const targetHandle = handleType === 'target' ? fieldName : pendingConnection.handleId;
const validationResult = validateConnection( const validationResult = validateConnection(
{ {

View File

@ -2,7 +2,7 @@ import type { Templates } from 'features/nodes/store/types';
import type { InvocationTemplate } from 'features/nodes/types/invocation'; import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import type { OpenAPIV3_1 } from 'openapi-types'; import type { OpenAPIV3_1 } from 'openapi-types';
import type { Edge, XYPosition } from 'reactflow'; import type { Edge } from 'reactflow';
export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({ export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({
source, source,
@ -13,8 +13,6 @@ export const buildEdge = (source: string, sourceHandle: string, target: string,
id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`, id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`,
}); });
export const position: XYPosition = { x: 0, y: 0 };
export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template); export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template);
export const add: InvocationTemplate = { export const add: InvocationTemplate = {
@ -176,7 +174,7 @@ export const collect: InvocationTemplate = {
classification: 'stable', classification: 'stable',
}; };
export const scheduler: InvocationTemplate = { const scheduler: InvocationTemplate = {
title: 'Scheduler', title: 'Scheduler',
type: 'scheduler', type: 'scheduler',
version: '1.0.0', version: '1.0.0',

View File

@ -6,11 +6,10 @@ import { validateConnectionTypes } from 'features/nodes/store/util/validateConne
import type { AnyNode } from 'features/nodes/types/invocation'; import type { AnyNode } from 'features/nodes/types/invocation';
import type { Connection as NullableConnection, Edge } from 'reactflow'; import type { Connection as NullableConnection, Edge } from 'reactflow';
import type { O } from 'ts-toolbelt'; import type { O } from 'ts-toolbelt';
import { assert } from 'tsafe';
type Connection = O.NonNullable<NullableConnection>; type Connection = O.NonNullable<NullableConnection>;
export type ValidateConnectionResult = type ValidateConnectionResult =
| { | {
isValid: true; isValid: true;
messageTKey?: string; messageTKey?: string;
@ -20,7 +19,7 @@ export type ValidateConnectionResult =
messageTKey: string; messageTKey: string;
}; };
export type ValidateConnectionFunc = ( type ValidateConnectionFunc = (
connection: Connection, connection: Connection,
nodes: AnyNode[], nodes: AnyNode[],
edges: Edge[], edges: Edge[],
@ -29,21 +28,6 @@ export type ValidateConnectionFunc = (
strict?: boolean strict?: boolean
) => ValidateConnectionResult; ) => ValidateConnectionResult;
export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => {
if (isValid) {
return {
isValid,
messageTKey,
};
} else {
assert(messageTKey !== undefined);
return {
isValid,
messageTKey,
};
}
};
const getEqualityPredicate = const getEqualityPredicate =
(c: Connection) => (c: Connection) =>
(e: Edge): boolean => { (e: Edge): boolean => {