mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): handling for in-progress edge updates during conection validation
This commit is contained in:
parent
fc31dddbf7
commit
3605b6b1a3
@ -9,8 +9,8 @@ import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
|||||||
import {
|
import {
|
||||||
$cursorPos,
|
$cursorPos,
|
||||||
$didUpdateEdge,
|
$didUpdateEdge,
|
||||||
|
$edgePendingUpdate,
|
||||||
$isAddNodePopoverOpen,
|
$isAddNodePopoverOpen,
|
||||||
$isUpdatingEdge,
|
|
||||||
$lastEdgeUpdateMouseEvent,
|
$lastEdgeUpdateMouseEvent,
|
||||||
$pendingConnection,
|
$pendingConnection,
|
||||||
$viewport,
|
$viewport,
|
||||||
@ -160,8 +160,8 @@ export const Flow = memo(() => {
|
|||||||
* where the edge is deleted if you click it accidentally).
|
* where the edge is deleted if you click it accidentally).
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback((e, _edge, _handleType) => {
|
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback((e, edge, _handleType) => {
|
||||||
$isUpdatingEdge.set(true);
|
$edgePendingUpdate.set(edge);
|
||||||
$didUpdateEdge.set(false);
|
$didUpdateEdge.set(false);
|
||||||
$lastEdgeUpdateMouseEvent.set(e);
|
$lastEdgeUpdateMouseEvent.set(e);
|
||||||
}, []);
|
}, []);
|
||||||
@ -196,7 +196,7 @@ export const Flow = memo(() => {
|
|||||||
dispatch(edgeDeleted(edge.id));
|
dispatch(edgeDeleted(edge.id));
|
||||||
}
|
}
|
||||||
|
|
||||||
$isUpdatingEdge.set(false);
|
$edgePendingUpdate.set(null);
|
||||||
$didUpdateEdge.set(false);
|
$didUpdateEdge.set(false);
|
||||||
$pendingConnection.set(null);
|
$pendingConnection.set(null);
|
||||||
$lastEdgeUpdateMouseEvent.set(null);
|
$lastEdgeUpdateMouseEvent.set(null);
|
||||||
@ -259,7 +259,7 @@ export const Flow = memo(() => {
|
|||||||
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey);
|
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey);
|
||||||
|
|
||||||
const onEscapeHotkey = useCallback(() => {
|
const onEscapeHotkey = useCallback(() => {
|
||||||
if (!$isUpdatingEdge.get()) {
|
if (!$edgePendingUpdate.get()) {
|
||||||
$pendingConnection.set(null);
|
$pendingConnection.set(null);
|
||||||
$isAddNodePopoverOpen.set(false);
|
$isAddNodePopoverOpen.set(false);
|
||||||
cancelConnection();
|
cancelConnection();
|
||||||
|
@ -2,8 +2,8 @@ import { useStore } from '@nanostores/react';
|
|||||||
import { useAppStore } from 'app/store/storeHooks';
|
import { useAppStore } from 'app/store/storeHooks';
|
||||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||||
import {
|
import {
|
||||||
|
$edgePendingUpdate,
|
||||||
$isAddNodePopoverOpen,
|
$isAddNodePopoverOpen,
|
||||||
$isUpdatingEdge,
|
|
||||||
$pendingConnection,
|
$pendingConnection,
|
||||||
$templates,
|
$templates,
|
||||||
connectionMade,
|
connectionMade,
|
||||||
@ -52,12 +52,12 @@ export const useConnection = () => {
|
|||||||
const onConnectEnd = useCallback<OnConnectEnd>(() => {
|
const onConnectEnd = useCallback<OnConnectEnd>(() => {
|
||||||
const { dispatch } = store;
|
const { dispatch } = store;
|
||||||
const pendingConnection = $pendingConnection.get();
|
const pendingConnection = $pendingConnection.get();
|
||||||
const isUpdatingEdge = $isUpdatingEdge.get();
|
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||||
const mouseOverNodeId = $mouseOverNode.get();
|
const mouseOverNodeId = $mouseOverNode.get();
|
||||||
|
|
||||||
// If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge
|
// If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge
|
||||||
// update logic can finish up
|
// update logic can finish up
|
||||||
if (isUpdatingEdge && !mouseOverNodeId) {
|
if (edgePendingUpdate && !mouseOverNodeId) {
|
||||||
$pendingConnection.set(null);
|
$pendingConnection.set(null);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -80,7 +80,8 @@ export const useConnection = () => {
|
|||||||
edges,
|
edges,
|
||||||
pendingConnection,
|
pendingConnection,
|
||||||
candidateNode,
|
candidateNode,
|
||||||
candidateTemplate
|
candidateTemplate,
|
||||||
|
edgePendingUpdate
|
||||||
);
|
);
|
||||||
if (connection) {
|
if (connection) {
|
||||||
dispatch(connectionMade(connection));
|
dispatch(connectionMade(connection));
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { $edgePendingUpdate, $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
|
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
@ -14,6 +14,7 @@ type UseConnectionStateProps = {
|
|||||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||||
const pendingConnection = useStore($pendingConnection);
|
const pendingConnection = useStore($pendingConnection);
|
||||||
const templates = useStore($templates);
|
const templates = useStore($templates);
|
||||||
|
const edgePendingUpdate = useStore($edgePendingUpdate);
|
||||||
|
|
||||||
const selectIsConnected = useMemo(
|
const selectIsConnected = useMemo(
|
||||||
() =>
|
() =>
|
||||||
@ -47,7 +48,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
|||||||
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
||||||
);
|
);
|
||||||
}, [fieldName, kind, nodeId, pendingConnection]);
|
}, [fieldName, kind, nodeId, pendingConnection]);
|
||||||
const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection));
|
const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection, edgePendingUpdate));
|
||||||
|
|
||||||
const shouldDim = useMemo(
|
const shouldDim = useMemo(
|
||||||
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),
|
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// TODO: enable this at some point
|
// TODO: enable this at some point
|
||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
import { $edgePendingUpdate, $templates } from 'features/nodes/store/nodesSlice';
|
||||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import type { Connection } from 'reactflow';
|
import type { Connection } from 'reactflow';
|
||||||
@ -21,7 +21,7 @@ export const useIsValidConnection = () => {
|
|||||||
if (!(source && sourceHandle && target && targetHandle)) {
|
if (!(source && sourceHandle && target && targetHandle)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||||
const { nodes, edges } = store.getState().nodes.present;
|
const { nodes, edges } = store.getState().nodes.present;
|
||||||
|
|
||||||
const validationResult = validateConnection(
|
const validationResult = validateConnection(
|
||||||
@ -29,7 +29,7 @@ export const useIsValidConnection = () => {
|
|||||||
nodes,
|
nodes,
|
||||||
edges,
|
edges,
|
||||||
templates,
|
templates,
|
||||||
null,
|
edgePendingUpdate,
|
||||||
shouldValidateGraph
|
shouldValidateGraph
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -503,7 +503,7 @@ export const $copiedNodes = atom<AnyNode[]>([]);
|
|||||||
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
||||||
export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]);
|
export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]);
|
||||||
export const $pendingConnection = atom<PendingConnection | null>(null);
|
export const $pendingConnection = atom<PendingConnection | null>(null);
|
||||||
export const $isUpdatingEdge = atom(false);
|
export const $edgePendingUpdate = atom<Edge | null>(null);
|
||||||
export const $didUpdateEdge = atom(false);
|
export const $didUpdateEdge = atom(false);
|
||||||
export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
|
export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
|||||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||||
import { differenceWith, map } from 'lodash-es';
|
import { differenceWith, map } from 'lodash-es';
|
||||||
import type { Connection } from 'reactflow';
|
import type { Connection, Edge } from 'reactflow';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
import { areTypesEqual } from './areTypesEqual';
|
import { areTypesEqual } from './areTypesEqual';
|
||||||
@ -26,7 +26,8 @@ export const getFirstValidConnection = (
|
|||||||
edges: InvocationNodeEdge[],
|
edges: InvocationNodeEdge[],
|
||||||
pendingConnection: PendingConnection,
|
pendingConnection: PendingConnection,
|
||||||
candidateNode: InvocationNode,
|
candidateNode: InvocationNode,
|
||||||
candidateTemplate: InvocationTemplate
|
candidateTemplate: InvocationTemplate,
|
||||||
|
edgePendingUpdate: Edge | null
|
||||||
): Connection | null => {
|
): Connection | null => {
|
||||||
if (pendingConnection.node.id === candidateNode.id) {
|
if (pendingConnection.node.id === candidateNode.id) {
|
||||||
// Cannot connect to self
|
// Cannot connect to self
|
||||||
@ -52,7 +53,7 @@ export const getFirstValidConnection = (
|
|||||||
// Only one connection per target field is allowed - look for an unconnected target field
|
// Only one connection per target field is allowed - look for an unconnected target field
|
||||||
const candidateFields = map(candidateTemplate.inputs);
|
const candidateFields = map(candidateTemplate.inputs);
|
||||||
const candidateConnectedFields = edges
|
const candidateConnectedFields = edges
|
||||||
.filter((edge) => edge.target === candidateNode.id)
|
.filter((edge) => edge.target === candidateNode.id || edge.id === edgePendingUpdate?.id)
|
||||||
.map((edge) => {
|
.map((edge) => {
|
||||||
// Edges must always have a targetHandle, safe to assert here
|
// Edges must always have a targetHandle, safe to assert here
|
||||||
assert(edge.targetHandle);
|
assert(edge.targetHandle);
|
||||||
@ -63,7 +64,8 @@ export const getFirstValidConnection = (
|
|||||||
candidateConnectedFields,
|
candidateConnectedFields,
|
||||||
(field, connectedFieldName) => field.name === connectedFieldName
|
(field, connectedFieldName) => field.name === connectedFieldName
|
||||||
);
|
);
|
||||||
const candidateField = candidateUnconnectedFields.find((field) => validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type)
|
const candidateField = candidateUnconnectedFields.find((field) =>
|
||||||
|
validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type)
|
||||||
);
|
);
|
||||||
if (candidateField) {
|
if (candidateField) {
|
||||||
return {
|
return {
|
||||||
|
@ -4,7 +4,7 @@ import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
|||||||
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
||||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||||
import i18n from 'i18next';
|
import i18n from 'i18next';
|
||||||
import type { HandleType } from 'reactflow';
|
import type { Edge, HandleType } from 'reactflow';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a selector that validates a pending connection.
|
* Creates a selector that validates a pending connection.
|
||||||
@ -27,7 +27,9 @@ export const makeConnectionErrorSelector = (
|
|||||||
return createMemoizedSelector(
|
return createMemoizedSelector(
|
||||||
selectNodesSlice,
|
selectNodesSlice,
|
||||||
(state: RootState, pendingConnection: PendingConnection | null) => pendingConnection,
|
(state: RootState, pendingConnection: PendingConnection | null) => pendingConnection,
|
||||||
(nodesSlice: NodesState, pendingConnection: PendingConnection | null) => {
|
(state: RootState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) =>
|
||||||
|
edgePendingUpdate,
|
||||||
|
(nodesSlice: NodesState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => {
|
||||||
const { nodes, edges } = nodesSlice;
|
const { nodes, edges } = nodesSlice;
|
||||||
|
|
||||||
if (!pendingConnection) {
|
if (!pendingConnection) {
|
||||||
@ -61,7 +63,7 @@ export const makeConnectionErrorSelector = (
|
|||||||
nodes,
|
nodes,
|
||||||
edges,
|
edges,
|
||||||
templates,
|
templates,
|
||||||
null
|
edgePendingUpdate
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!validationResult.isValid) {
|
if (!validationResult.isValid) {
|
||||||
|
Loading…
Reference in New Issue
Block a user