fix(ui): handling for in-progress edge updates during conection validation

This commit is contained in:
psychedelicious 2024-05-19 01:37:54 +10:00
parent fc31dddbf7
commit 3605b6b1a3
7 changed files with 28 additions and 22 deletions

View File

@ -9,8 +9,8 @@ import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import { import {
$cursorPos, $cursorPos,
$didUpdateEdge, $didUpdateEdge,
$edgePendingUpdate,
$isAddNodePopoverOpen, $isAddNodePopoverOpen,
$isUpdatingEdge,
$lastEdgeUpdateMouseEvent, $lastEdgeUpdateMouseEvent,
$pendingConnection, $pendingConnection,
$viewport, $viewport,
@ -160,8 +160,8 @@ export const Flow = memo(() => {
* where the edge is deleted if you click it accidentally). * where the edge is deleted if you click it accidentally).
*/ */
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback((e, _edge, _handleType) => { const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback((e, edge, _handleType) => {
$isUpdatingEdge.set(true); $edgePendingUpdate.set(edge);
$didUpdateEdge.set(false); $didUpdateEdge.set(false);
$lastEdgeUpdateMouseEvent.set(e); $lastEdgeUpdateMouseEvent.set(e);
}, []); }, []);
@ -196,7 +196,7 @@ export const Flow = memo(() => {
dispatch(edgeDeleted(edge.id)); dispatch(edgeDeleted(edge.id));
} }
$isUpdatingEdge.set(false); $edgePendingUpdate.set(null);
$didUpdateEdge.set(false); $didUpdateEdge.set(false);
$pendingConnection.set(null); $pendingConnection.set(null);
$lastEdgeUpdateMouseEvent.set(null); $lastEdgeUpdateMouseEvent.set(null);
@ -259,7 +259,7 @@ export const Flow = memo(() => {
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey); useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey);
const onEscapeHotkey = useCallback(() => { const onEscapeHotkey = useCallback(() => {
if (!$isUpdatingEdge.get()) { if (!$edgePendingUpdate.get()) {
$pendingConnection.set(null); $pendingConnection.set(null);
$isAddNodePopoverOpen.set(false); $isAddNodePopoverOpen.set(false);
cancelConnection(); cancelConnection();

View File

@ -2,8 +2,8 @@ import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks'; import { useAppStore } from 'app/store/storeHooks';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { import {
$edgePendingUpdate,
$isAddNodePopoverOpen, $isAddNodePopoverOpen,
$isUpdatingEdge,
$pendingConnection, $pendingConnection,
$templates, $templates,
connectionMade, connectionMade,
@ -52,12 +52,12 @@ export const useConnection = () => {
const onConnectEnd = useCallback<OnConnectEnd>(() => { const onConnectEnd = useCallback<OnConnectEnd>(() => {
const { dispatch } = store; const { dispatch } = store;
const pendingConnection = $pendingConnection.get(); const pendingConnection = $pendingConnection.get();
const isUpdatingEdge = $isUpdatingEdge.get(); const edgePendingUpdate = $edgePendingUpdate.get();
const mouseOverNodeId = $mouseOverNode.get(); const mouseOverNodeId = $mouseOverNode.get();
// If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge // If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge
// update logic can finish up // update logic can finish up
if (isUpdatingEdge && !mouseOverNodeId) { if (edgePendingUpdate && !mouseOverNodeId) {
$pendingConnection.set(null); $pendingConnection.set(null);
return; return;
} }
@ -80,7 +80,8 @@ export const useConnection = () => {
edges, edges,
pendingConnection, pendingConnection,
candidateNode, candidateNode,
candidateTemplate candidateTemplate,
edgePendingUpdate
); );
if (connection) { if (connection) {
dispatch(connectionMade(connection)); dispatch(connectionMade(connection));

View File

@ -1,7 +1,7 @@
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { $edgePendingUpdate, $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector'; import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
import { useMemo } from 'react'; import { useMemo } from 'react';
@ -14,6 +14,7 @@ type UseConnectionStateProps = {
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
const pendingConnection = useStore($pendingConnection); const pendingConnection = useStore($pendingConnection);
const templates = useStore($templates); const templates = useStore($templates);
const edgePendingUpdate = useStore($edgePendingUpdate);
const selectIsConnected = useMemo( const selectIsConnected = useMemo(
() => () =>
@ -47,7 +48,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
); );
}, [fieldName, kind, nodeId, pendingConnection]); }, [fieldName, kind, nodeId, pendingConnection]);
const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection)); const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection, edgePendingUpdate));
const shouldDim = useMemo( const shouldDim = useMemo(
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField), () => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),

View File

@ -1,7 +1,7 @@
// TODO: enable this at some point // TODO: enable this at some point
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice'; import { $edgePendingUpdate, $templates } from 'features/nodes/store/nodesSlice';
import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { validateConnection } from 'features/nodes/store/util/validateConnection';
import { useCallback } from 'react'; import { useCallback } from 'react';
import type { Connection } from 'reactflow'; import type { Connection } from 'reactflow';
@ -21,7 +21,7 @@ export const useIsValidConnection = () => {
if (!(source && sourceHandle && target && targetHandle)) { if (!(source && sourceHandle && target && targetHandle)) {
return false; return false;
} }
const edgePendingUpdate = $edgePendingUpdate.get();
const { nodes, edges } = store.getState().nodes.present; const { nodes, edges } = store.getState().nodes.present;
const validationResult = validateConnection( const validationResult = validateConnection(
@ -29,7 +29,7 @@ export const useIsValidConnection = () => {
nodes, nodes,
edges, edges,
templates, templates,
null, edgePendingUpdate,
shouldValidateGraph shouldValidateGraph
); );

View File

@ -503,7 +503,7 @@ export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]); export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]); export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]);
export const $pendingConnection = atom<PendingConnection | null>(null); export const $pendingConnection = atom<PendingConnection | null>(null);
export const $isUpdatingEdge = atom(false); export const $edgePendingUpdate = atom<Edge | null>(null);
export const $didUpdateEdge = atom(false); export const $didUpdateEdge = atom(false);
export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null); export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);

View File

@ -2,7 +2,7 @@ import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { differenceWith, map } from 'lodash-es'; import { differenceWith, map } from 'lodash-es';
import type { Connection } from 'reactflow'; import type { Connection, Edge } from 'reactflow';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { areTypesEqual } from './areTypesEqual'; import { areTypesEqual } from './areTypesEqual';
@ -26,7 +26,8 @@ export const getFirstValidConnection = (
edges: InvocationNodeEdge[], edges: InvocationNodeEdge[],
pendingConnection: PendingConnection, pendingConnection: PendingConnection,
candidateNode: InvocationNode, candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate candidateTemplate: InvocationTemplate,
edgePendingUpdate: Edge | null
): Connection | null => { ): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) { if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self // Cannot connect to self
@ -52,7 +53,7 @@ export const getFirstValidConnection = (
// Only one connection per target field is allowed - look for an unconnected target field // Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs); const candidateFields = map(candidateTemplate.inputs);
const candidateConnectedFields = edges const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id) .filter((edge) => edge.target === candidateNode.id || edge.id === edgePendingUpdate?.id)
.map((edge) => { .map((edge) => {
// Edges must always have a targetHandle, safe to assert here // Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle); assert(edge.targetHandle);
@ -63,7 +64,8 @@ export const getFirstValidConnection = (
candidateConnectedFields, candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName (field, connectedFieldName) => field.name === connectedFieldName
); );
const candidateField = candidateUnconnectedFields.find((field) => validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) const candidateField = candidateUnconnectedFields.find((field) =>
validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type)
); );
if (candidateField) { if (candidateField) {
return { return {

View File

@ -4,7 +4,7 @@ import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { validateConnection } from 'features/nodes/store/util/validateConnection';
import i18n from 'i18next'; import i18n from 'i18next';
import type { HandleType } from 'reactflow'; import type { Edge, HandleType } from 'reactflow';
/** /**
* Creates a selector that validates a pending connection. * Creates a selector that validates a pending connection.
@ -27,7 +27,9 @@ export const makeConnectionErrorSelector = (
return createMemoizedSelector( return createMemoizedSelector(
selectNodesSlice, selectNodesSlice,
(state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection,
(nodesSlice: NodesState, pendingConnection: PendingConnection | null) => { (state: RootState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) =>
edgePendingUpdate,
(nodesSlice: NodesState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => {
const { nodes, edges } = nodesSlice; const { nodes, edges } = nodesSlice;
if (!pendingConnection) { if (!pendingConnection) {
@ -61,7 +63,7 @@ export const makeConnectionErrorSelector = (
nodes, nodes,
edges, edges,
templates, templates,
null edgePendingUpdate
); );
if (!validationResult.isValid) { if (!validationResult.isValid) {