mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): rework pendingConnection
This commit is contained in:
parent
4bda174eb9
commit
a80e3448f5
@ -9,6 +9,7 @@ import type { SelectInstance } from 'chakra-react-select';
|
|||||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||||
import {
|
import {
|
||||||
$cursorPos,
|
$cursorPos,
|
||||||
|
$edgePendingUpdate,
|
||||||
$isAddNodePopoverOpen,
|
$isAddNodePopoverOpen,
|
||||||
$pendingConnection,
|
$pendingConnection,
|
||||||
$templates,
|
$templates,
|
||||||
@ -28,7 +29,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
|||||||
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
|
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
|
||||||
import { assert } from 'tsafe';
|
|
||||||
|
|
||||||
const createRegex = memoize(
|
const createRegex = memoize(
|
||||||
(inputValue: string) =>
|
(inputValue: string) =>
|
||||||
@ -68,16 +68,18 @@ const AddNodePopover = () => {
|
|||||||
|
|
||||||
const filteredTemplates = useMemo(() => {
|
const filteredTemplates = useMemo(() => {
|
||||||
// If we have a connection in progress, we need to filter the node choices
|
// If we have a connection in progress, we need to filter the node choices
|
||||||
|
const templatesArray = map(templates);
|
||||||
if (!pendingConnection) {
|
if (!pendingConnection) {
|
||||||
return map(templates);
|
return templatesArray;
|
||||||
}
|
}
|
||||||
|
|
||||||
return filter(templates, (template) => {
|
return filter(templates, (template) => {
|
||||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
|
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
|
||||||
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
|
return some(candidateFields, (field) => {
|
||||||
return some(fields, (field) => {
|
const sourceType =
|
||||||
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
|
pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
|
||||||
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
|
const targetType =
|
||||||
|
pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
|
||||||
return validateConnectionTypes(sourceType, targetType);
|
return validateConnectionTypes(sourceType, targetType);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@ -144,10 +146,25 @@ const AddNodePopover = () => {
|
|||||||
|
|
||||||
// Auto-connect an edge if we just added a node and have a pending connection
|
// Auto-connect an edge if we just added a node and have a pending connection
|
||||||
if (pendingConnection && isInvocationNode(node)) {
|
if (pendingConnection && isInvocationNode(node)) {
|
||||||
const template = templates[node.data.type];
|
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||||
assert(template, 'Template not found');
|
const { handleType } = pendingConnection;
|
||||||
|
|
||||||
|
const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
|
||||||
|
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
|
||||||
|
const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
|
||||||
|
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
|
||||||
|
|
||||||
const { nodes, edges } = store.getState().nodes.present;
|
const { nodes, edges } = store.getState().nodes.present;
|
||||||
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
|
const connection = getFirstValidConnection(
|
||||||
|
source,
|
||||||
|
sourceHandle,
|
||||||
|
target,
|
||||||
|
targetHandle,
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
templates,
|
||||||
|
edgePendingUpdate
|
||||||
|
);
|
||||||
if (connection) {
|
if (connection) {
|
||||||
dispatch(connectionMade(connection));
|
dispatch(connectionMade(connection));
|
||||||
}
|
}
|
||||||
|
@ -9,10 +9,10 @@ import {
|
|||||||
connectionMade,
|
connectionMade,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
|
||||||
import { isString } from 'lodash-es';
|
import { isString } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { type OnConnect, type OnConnectEnd, type OnConnectStart, useUpdateNodeInternals } from 'reactflow';
|
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||||
|
import { useUpdateNodeInternals } from 'reactflow';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
export const useConnection = () => {
|
export const useConnection = () => {
|
||||||
@ -21,21 +21,27 @@ export const useConnection = () => {
|
|||||||
const updateNodeInternals = useUpdateNodeInternals();
|
const updateNodeInternals = useUpdateNodeInternals();
|
||||||
|
|
||||||
const onConnectStart = useCallback<OnConnectStart>(
|
const onConnectStart = useCallback<OnConnectStart>(
|
||||||
(event, params) => {
|
(event, { nodeId, handleId, handleType }) => {
|
||||||
|
assert(nodeId && handleId && handleType, 'Invalid connection start event');
|
||||||
const nodes = store.getState().nodes.present.nodes;
|
const nodes = store.getState().nodes.present.nodes;
|
||||||
const { nodeId, handleId, handleType } = params;
|
|
||||||
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
|
|
||||||
const node = nodes.find((n) => n.id === nodeId);
|
const node = nodes.find((n) => n.id === nodeId);
|
||||||
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
|
if (!node) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const template = templates[node.data.type];
|
const template = templates[node.data.type];
|
||||||
assert(template, `Template not found for node type: ${node.data.type}`);
|
if (!template) {
|
||||||
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
|
return;
|
||||||
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
|
}
|
||||||
$pendingConnection.set({
|
|
||||||
node,
|
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
|
||||||
template,
|
const fieldTemplate = fieldTemplates[handleId];
|
||||||
fieldTemplate,
|
if (!fieldTemplate) {
|
||||||
});
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
$pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate });
|
||||||
},
|
},
|
||||||
[store, templates]
|
[store, templates]
|
||||||
);
|
);
|
||||||
@ -67,20 +73,20 @@ export const useConnection = () => {
|
|||||||
}
|
}
|
||||||
const { nodes, edges } = store.getState().nodes.present;
|
const { nodes, edges } = store.getState().nodes.present;
|
||||||
if (mouseOverNodeId) {
|
if (mouseOverNodeId) {
|
||||||
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
|
const { handleType } = pendingConnection;
|
||||||
if (!candidateNode) {
|
const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId;
|
||||||
// The mouse is over a non-invocation node - bail
|
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
|
||||||
return;
|
const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId;
|
||||||
}
|
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
|
||||||
const candidateTemplate = templates[candidateNode.data.type];
|
|
||||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
|
||||||
const connection = getFirstValidConnection(
|
const connection = getFirstValidConnection(
|
||||||
templates,
|
source,
|
||||||
|
sourceHandle,
|
||||||
|
target,
|
||||||
|
targetHandle,
|
||||||
nodes,
|
nodes,
|
||||||
edges,
|
edges,
|
||||||
pendingConnection,
|
templates,
|
||||||
candidateNode,
|
|
||||||
candidateTemplate,
|
|
||||||
edgePendingUpdate
|
edgePendingUpdate
|
||||||
);
|
);
|
||||||
if (connection) {
|
if (connection) {
|
||||||
|
@ -43,8 +43,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
pendingConnection.node.id === nodeId &&
|
pendingConnection.nodeId === nodeId &&
|
||||||
pendingConnection.fieldTemplate.name === fieldName &&
|
pendingConnection.handleId === fieldName &&
|
||||||
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
||||||
);
|
);
|
||||||
}, [fieldName, kind, nodeId, pendingConnection]);
|
}, [fieldName, kind, nodeId, pendingConnection]);
|
||||||
|
@ -6,19 +6,20 @@ import type {
|
|||||||
} from 'features/nodes/types/field';
|
} from 'features/nodes/types/field';
|
||||||
import type {
|
import type {
|
||||||
AnyNode,
|
AnyNode,
|
||||||
InvocationNode,
|
|
||||||
InvocationNodeEdge,
|
InvocationNodeEdge,
|
||||||
InvocationTemplate,
|
InvocationTemplate,
|
||||||
NodeExecutionState,
|
NodeExecutionState,
|
||||||
} from 'features/nodes/types/invocation';
|
} from 'features/nodes/types/invocation';
|
||||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||||
|
import type { HandleType } from 'reactflow';
|
||||||
|
|
||||||
export type Templates = Record<string, InvocationTemplate>;
|
export type Templates = Record<string, InvocationTemplate>;
|
||||||
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
|
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
|
||||||
|
|
||||||
export type PendingConnection = {
|
export type PendingConnection = {
|
||||||
node: InvocationNode;
|
nodeId: string;
|
||||||
template: InvocationTemplate;
|
handleId: string;
|
||||||
|
handleType: HandleType;
|
||||||
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -36,9 +36,7 @@ export const makeConnectionErrorSelector = (
|
|||||||
return i18n.t('nodes.noConnectionInProgress');
|
return i18n.t('nodes.noConnectionInProgress');
|
||||||
}
|
}
|
||||||
|
|
||||||
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
if (handleType === pendingConnection.handleType) {
|
||||||
|
|
||||||
if (handleType === connectionHandleType) {
|
|
||||||
if (handleType === 'source') {
|
if (handleType === 'source') {
|
||||||
return i18n.t('nodes.cannotConnectOutputToOutput');
|
return i18n.t('nodes.cannotConnectOutputToOutput');
|
||||||
}
|
}
|
||||||
@ -46,10 +44,10 @@ export const makeConnectionErrorSelector = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// we have to figure out which is the target and which is the source
|
// we have to figure out which is the target and which is the source
|
||||||
const source = handleType === 'source' ? nodeId : pendingConnection.node.id;
|
const source = handleType === 'source' ? nodeId : pendingConnection.nodeId;
|
||||||
const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.fieldTemplate.name;
|
const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.handleId;
|
||||||
const target = handleType === 'target' ? nodeId : pendingConnection.node.id;
|
const target = handleType === 'target' ? nodeId : pendingConnection.nodeId;
|
||||||
const targetHandle = handleType === 'target' ? fieldName : pendingConnection.fieldTemplate.name;
|
const targetHandle = handleType === 'target' ? fieldName : pendingConnection.handleId;
|
||||||
|
|
||||||
const validationResult = validateConnection(
|
const validationResult = validateConnection(
|
||||||
{
|
{
|
||||||
|
@ -2,7 +2,7 @@ import type { Templates } from 'features/nodes/store/types';
|
|||||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||||
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||||
import type { OpenAPIV3_1 } from 'openapi-types';
|
import type { OpenAPIV3_1 } from 'openapi-types';
|
||||||
import type { Edge, XYPosition } from 'reactflow';
|
import type { Edge } from 'reactflow';
|
||||||
|
|
||||||
export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({
|
export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({
|
||||||
source,
|
source,
|
||||||
@ -13,8 +13,6 @@ export const buildEdge = (source: string, sourceHandle: string, target: string,
|
|||||||
id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`,
|
id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`,
|
||||||
});
|
});
|
||||||
|
|
||||||
export const position: XYPosition = { x: 0, y: 0 };
|
|
||||||
|
|
||||||
export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template);
|
export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template);
|
||||||
|
|
||||||
export const add: InvocationTemplate = {
|
export const add: InvocationTemplate = {
|
||||||
@ -176,7 +174,7 @@ export const collect: InvocationTemplate = {
|
|||||||
classification: 'stable',
|
classification: 'stable',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const scheduler: InvocationTemplate = {
|
const scheduler: InvocationTemplate = {
|
||||||
title: 'Scheduler',
|
title: 'Scheduler',
|
||||||
type: 'scheduler',
|
type: 'scheduler',
|
||||||
version: '1.0.0',
|
version: '1.0.0',
|
||||||
|
@ -6,11 +6,10 @@ import { validateConnectionTypes } from 'features/nodes/store/util/validateConne
|
|||||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||||
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
||||||
import type { O } from 'ts-toolbelt';
|
import type { O } from 'ts-toolbelt';
|
||||||
import { assert } from 'tsafe';
|
|
||||||
|
|
||||||
type Connection = O.NonNullable<NullableConnection>;
|
type Connection = O.NonNullable<NullableConnection>;
|
||||||
|
|
||||||
export type ValidateConnectionResult =
|
type ValidateConnectionResult =
|
||||||
| {
|
| {
|
||||||
isValid: true;
|
isValid: true;
|
||||||
messageTKey?: string;
|
messageTKey?: string;
|
||||||
@ -20,7 +19,7 @@ export type ValidateConnectionResult =
|
|||||||
messageTKey: string;
|
messageTKey: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ValidateConnectionFunc = (
|
type ValidateConnectionFunc = (
|
||||||
connection: Connection,
|
connection: Connection,
|
||||||
nodes: AnyNode[],
|
nodes: AnyNode[],
|
||||||
edges: Edge[],
|
edges: Edge[],
|
||||||
@ -29,21 +28,6 @@ export type ValidateConnectionFunc = (
|
|||||||
strict?: boolean
|
strict?: boolean
|
||||||
) => ValidateConnectionResult;
|
) => ValidateConnectionResult;
|
||||||
|
|
||||||
export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => {
|
|
||||||
if (isValid) {
|
|
||||||
return {
|
|
||||||
isValid,
|
|
||||||
messageTKey,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
assert(messageTKey !== undefined);
|
|
||||||
return {
|
|
||||||
isValid,
|
|
||||||
messageTKey,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const getEqualityPredicate =
|
const getEqualityPredicate =
|
||||||
(c: Connection) =>
|
(c: Connection) =>
|
||||||
(e: Edge): boolean => {
|
(e: Edge): boolean => {
|
||||||
|
Loading…
Reference in New Issue
Block a user