feat(ui): rework pendingConnection

This commit is contained in:
psychedelicious 2024-05-19 11:49:40 +10:00
parent 4bda174eb9
commit a80e3448f5
7 changed files with 73 additions and 69 deletions

View File

@ -9,6 +9,7 @@ import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import {
$cursorPos,
$edgePendingUpdate,
$isAddNodePopoverOpen,
$pendingConnection,
$templates,
@ -28,7 +29,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
import { assert } from 'tsafe';
const createRegex = memoize(
(inputValue: string) =>
@ -68,16 +68,18 @@ const AddNodePopover = () => {
const filteredTemplates = useMemo(() => {
// If we have a connection in progress, we need to filter the node choices
const templatesArray = map(templates);
if (!pendingConnection) {
return map(templates);
return templatesArray;
}
return filter(templates, (template) => {
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
return some(fields, (field) => {
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
return some(candidateFields, (field) => {
const sourceType =
pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
const targetType =
pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
return validateConnectionTypes(sourceType, targetType);
});
});
@ -144,10 +146,25 @@ const AddNodePopover = () => {
// Auto-connect an edge if we just added a node and have a pending connection
if (pendingConnection && isInvocationNode(node)) {
const template = templates[node.data.type];
assert(template, 'Template not found');
const edgePendingUpdate = $edgePendingUpdate.get();
const { handleType } = pendingConnection;
const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
const { nodes, edges } = store.getState().nodes.present;
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
const connection = getFirstValidConnection(
source,
sourceHandle,
target,
targetHandle,
nodes,
edges,
templates,
edgePendingUpdate
);
if (connection) {
dispatch(connectionMade(connection));
}

View File

@ -9,10 +9,10 @@ import {
connectionMade,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { isString } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { type OnConnect, type OnConnectEnd, type OnConnectStart, useUpdateNodeInternals } from 'reactflow';
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { useUpdateNodeInternals } from 'reactflow';
import { assert } from 'tsafe';
export const useConnection = () => {
@ -21,21 +21,27 @@ export const useConnection = () => {
const updateNodeInternals = useUpdateNodeInternals();
const onConnectStart = useCallback<OnConnectStart>(
(event, params) => {
(event, { nodeId, handleId, handleType }) => {
assert(nodeId && handleId && handleType, 'Invalid connection start event');
const nodes = store.getState().nodes.present.nodes;
const { nodeId, handleId, handleType } = params;
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
const node = nodes.find((n) => n.id === nodeId);
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
if (!node) {
return;
}
const template = templates[node.data.type];
assert(template, `Template not found for node type: ${node.data.type}`);
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
$pendingConnection.set({
node,
template,
fieldTemplate,
});
if (!template) {
return;
}
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
const fieldTemplate = fieldTemplates[handleId];
if (!fieldTemplate) {
return;
}
$pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate });
},
[store, templates]
);
@ -67,20 +73,20 @@ export const useConnection = () => {
}
const { nodes, edges } = store.getState().nodes.present;
if (mouseOverNodeId) {
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
if (!candidateNode) {
// The mouse is over a non-invocation node - bail
return;
}
const candidateTemplate = templates[candidateNode.data.type];
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
const { handleType } = pendingConnection;
const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId;
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId;
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
const connection = getFirstValidConnection(
templates,
source,
sourceHandle,
target,
targetHandle,
nodes,
edges,
pendingConnection,
candidateNode,
candidateTemplate,
templates,
edgePendingUpdate
);
if (connection) {

View File

@ -43,8 +43,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
return false;
}
return (
pendingConnection.node.id === nodeId &&
pendingConnection.fieldTemplate.name === fieldName &&
pendingConnection.nodeId === nodeId &&
pendingConnection.handleId === fieldName &&
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
);
}, [fieldName, kind, nodeId, pendingConnection]);

View File

@ -6,19 +6,20 @@ import type {
} from 'features/nodes/types/field';
import type {
AnyNode,
InvocationNode,
InvocationNodeEdge,
InvocationTemplate,
NodeExecutionState,
} from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { HandleType } from 'reactflow';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
export type PendingConnection = {
node: InvocationNode;
template: InvocationTemplate;
nodeId: string;
handleId: string;
handleType: HandleType;
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
};

View File

@ -36,9 +36,7 @@ export const makeConnectionErrorSelector = (
return i18n.t('nodes.noConnectionInProgress');
}
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
if (handleType === connectionHandleType) {
if (handleType === pendingConnection.handleType) {
if (handleType === 'source') {
return i18n.t('nodes.cannotConnectOutputToOutput');
}
@ -46,10 +44,10 @@ export const makeConnectionErrorSelector = (
}
// we have to figure out which is the target and which is the source
const source = handleType === 'source' ? nodeId : pendingConnection.node.id;
const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.fieldTemplate.name;
const target = handleType === 'target' ? nodeId : pendingConnection.node.id;
const targetHandle = handleType === 'target' ? fieldName : pendingConnection.fieldTemplate.name;
const source = handleType === 'source' ? nodeId : pendingConnection.nodeId;
const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.handleId;
const target = handleType === 'target' ? nodeId : pendingConnection.nodeId;
const targetHandle = handleType === 'target' ? fieldName : pendingConnection.handleId;
const validationResult = validateConnection(
{

View File

@ -2,7 +2,7 @@ import type { Templates } from 'features/nodes/store/types';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import type { OpenAPIV3_1 } from 'openapi-types';
import type { Edge, XYPosition } from 'reactflow';
import type { Edge } from 'reactflow';
export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({
source,
@ -13,8 +13,6 @@ export const buildEdge = (source: string, sourceHandle: string, target: string,
id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`,
});
export const position: XYPosition = { x: 0, y: 0 };
export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template);
export const add: InvocationTemplate = {
@ -176,7 +174,7 @@ export const collect: InvocationTemplate = {
classification: 'stable',
};
export const scheduler: InvocationTemplate = {
const scheduler: InvocationTemplate = {
title: 'Scheduler',
type: 'scheduler',
version: '1.0.0',

View File

@ -6,11 +6,10 @@ import { validateConnectionTypes } from 'features/nodes/store/util/validateConne
import type { AnyNode } from 'features/nodes/types/invocation';
import type { Connection as NullableConnection, Edge } from 'reactflow';
import type { O } from 'ts-toolbelt';
import { assert } from 'tsafe';
type Connection = O.NonNullable<NullableConnection>;
export type ValidateConnectionResult =
type ValidateConnectionResult =
| {
isValid: true;
messageTKey?: string;
@ -20,7 +19,7 @@ export type ValidateConnectionResult =
messageTKey: string;
};
export type ValidateConnectionFunc = (
type ValidateConnectionFunc = (
connection: Connection,
nodes: AnyNode[],
edges: Edge[],
@ -29,21 +28,6 @@ export type ValidateConnectionFunc = (
strict?: boolean
) => ValidateConnectionResult;
export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => {
if (isValid) {
return {
isValid,
messageTKey,
};
} else {
assert(messageTKey !== undefined);
return {
isValid,
messageTKey,
};
}
};
const getEqualityPredicate =
(c: Connection) =>
(e: Edge): boolean => {