mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): rework pendingConnection
This commit is contained in:
parent
4bda174eb9
commit
a80e3448f5
@ -9,6 +9,7 @@ import type { SelectInstance } from 'chakra-react-select';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import {
|
||||
$cursorPos,
|
||||
$edgePendingUpdate,
|
||||
$isAddNodePopoverOpen,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
@ -28,7 +29,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const createRegex = memoize(
|
||||
(inputValue: string) =>
|
||||
@ -68,16 +68,18 @@ const AddNodePopover = () => {
|
||||
|
||||
const filteredTemplates = useMemo(() => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
const templatesArray = map(templates);
|
||||
if (!pendingConnection) {
|
||||
return map(templates);
|
||||
return templatesArray;
|
||||
}
|
||||
|
||||
return filter(templates, (template) => {
|
||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
|
||||
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
|
||||
return some(fields, (field) => {
|
||||
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
|
||||
return some(candidateFields, (field) => {
|
||||
const sourceType =
|
||||
pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
const targetType =
|
||||
pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
return validateConnectionTypes(sourceType, targetType);
|
||||
});
|
||||
});
|
||||
@ -144,10 +146,25 @@ const AddNodePopover = () => {
|
||||
|
||||
// Auto-connect an edge if we just added a node and have a pending connection
|
||||
if (pendingConnection && isInvocationNode(node)) {
|
||||
const template = templates[node.data.type];
|
||||
assert(template, 'Template not found');
|
||||
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||
const { handleType } = pendingConnection;
|
||||
|
||||
const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
|
||||
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
|
||||
const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
|
||||
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
|
||||
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
|
||||
const connection = getFirstValidConnection(
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
targetHandle,
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
edgePendingUpdate
|
||||
);
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
|
@ -9,10 +9,10 @@ import {
|
||||
connectionMade,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isString } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { type OnConnect, type OnConnectEnd, type OnConnectStart, useUpdateNodeInternals } from 'reactflow';
|
||||
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||
import { useUpdateNodeInternals } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useConnection = () => {
|
||||
@ -21,21 +21,27 @@ export const useConnection = () => {
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
|
||||
const onConnectStart = useCallback<OnConnectStart>(
|
||||
(event, params) => {
|
||||
(event, { nodeId, handleId, handleType }) => {
|
||||
assert(nodeId && handleId && handleType, 'Invalid connection start event');
|
||||
const nodes = store.getState().nodes.present.nodes;
|
||||
const { nodeId, handleId, handleType } = params;
|
||||
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
|
||||
|
||||
const node = nodes.find((n) => n.id === nodeId);
|
||||
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
|
||||
if (!node) {
|
||||
return;
|
||||
}
|
||||
|
||||
const template = templates[node.data.type];
|
||||
assert(template, `Template not found for node type: ${node.data.type}`);
|
||||
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
|
||||
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
|
||||
$pendingConnection.set({
|
||||
node,
|
||||
template,
|
||||
fieldTemplate,
|
||||
});
|
||||
if (!template) {
|
||||
return;
|
||||
}
|
||||
|
||||
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
|
||||
const fieldTemplate = fieldTemplates[handleId];
|
||||
if (!fieldTemplate) {
|
||||
return;
|
||||
}
|
||||
|
||||
$pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate });
|
||||
},
|
||||
[store, templates]
|
||||
);
|
||||
@ -67,20 +73,20 @@ export const useConnection = () => {
|
||||
}
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
if (mouseOverNodeId) {
|
||||
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
|
||||
if (!candidateNode) {
|
||||
// The mouse is over a non-invocation node - bail
|
||||
return;
|
||||
}
|
||||
const candidateTemplate = templates[candidateNode.data.type];
|
||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||
const { handleType } = pendingConnection;
|
||||
const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId;
|
||||
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
|
||||
const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId;
|
||||
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
|
||||
|
||||
const connection = getFirstValidConnection(
|
||||
templates,
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
targetHandle,
|
||||
nodes,
|
||||
edges,
|
||||
pendingConnection,
|
||||
candidateNode,
|
||||
candidateTemplate,
|
||||
templates,
|
||||
edgePendingUpdate
|
||||
);
|
||||
if (connection) {
|
||||
|
@ -43,8 +43,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
pendingConnection.node.id === nodeId &&
|
||||
pendingConnection.fieldTemplate.name === fieldName &&
|
||||
pendingConnection.nodeId === nodeId &&
|
||||
pendingConnection.handleId === fieldName &&
|
||||
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
||||
);
|
||||
}, [fieldName, kind, nodeId, pendingConnection]);
|
||||
|
@ -6,19 +6,20 @@ import type {
|
||||
} from 'features/nodes/types/field';
|
||||
import type {
|
||||
AnyNode,
|
||||
InvocationNode,
|
||||
InvocationNodeEdge,
|
||||
InvocationTemplate,
|
||||
NodeExecutionState,
|
||||
} from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { HandleType } from 'reactflow';
|
||||
|
||||
export type Templates = Record<string, InvocationTemplate>;
|
||||
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
|
||||
|
||||
export type PendingConnection = {
|
||||
node: InvocationNode;
|
||||
template: InvocationTemplate;
|
||||
nodeId: string;
|
||||
handleId: string;
|
||||
handleType: HandleType;
|
||||
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
||||
};
|
||||
|
||||
|
@ -36,9 +36,7 @@ export const makeConnectionErrorSelector = (
|
||||
return i18n.t('nodes.noConnectionInProgress');
|
||||
}
|
||||
|
||||
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||
|
||||
if (handleType === connectionHandleType) {
|
||||
if (handleType === pendingConnection.handleType) {
|
||||
if (handleType === 'source') {
|
||||
return i18n.t('nodes.cannotConnectOutputToOutput');
|
||||
}
|
||||
@ -46,10 +44,10 @@ export const makeConnectionErrorSelector = (
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
const source = handleType === 'source' ? nodeId : pendingConnection.node.id;
|
||||
const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.fieldTemplate.name;
|
||||
const target = handleType === 'target' ? nodeId : pendingConnection.node.id;
|
||||
const targetHandle = handleType === 'target' ? fieldName : pendingConnection.fieldTemplate.name;
|
||||
const source = handleType === 'source' ? nodeId : pendingConnection.nodeId;
|
||||
const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.handleId;
|
||||
const target = handleType === 'target' ? nodeId : pendingConnection.nodeId;
|
||||
const targetHandle = handleType === 'target' ? fieldName : pendingConnection.handleId;
|
||||
|
||||
const validationResult = validateConnection(
|
||||
{
|
||||
|
@ -2,7 +2,7 @@ import type { Templates } from 'features/nodes/store/types';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||
import type { OpenAPIV3_1 } from 'openapi-types';
|
||||
import type { Edge, XYPosition } from 'reactflow';
|
||||
import type { Edge } from 'reactflow';
|
||||
|
||||
export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({
|
||||
source,
|
||||
@ -13,8 +13,6 @@ export const buildEdge = (source: string, sourceHandle: string, target: string,
|
||||
id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`,
|
||||
});
|
||||
|
||||
export const position: XYPosition = { x: 0, y: 0 };
|
||||
|
||||
export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template);
|
||||
|
||||
export const add: InvocationTemplate = {
|
||||
@ -176,7 +174,7 @@ export const collect: InvocationTemplate = {
|
||||
classification: 'stable',
|
||||
};
|
||||
|
||||
export const scheduler: InvocationTemplate = {
|
||||
const scheduler: InvocationTemplate = {
|
||||
title: 'Scheduler',
|
||||
type: 'scheduler',
|
||||
version: '1.0.0',
|
||||
|
@ -6,11 +6,10 @@ import { validateConnectionTypes } from 'features/nodes/store/util/validateConne
|
||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type Connection = O.NonNullable<NullableConnection>;
|
||||
|
||||
export type ValidateConnectionResult =
|
||||
type ValidateConnectionResult =
|
||||
| {
|
||||
isValid: true;
|
||||
messageTKey?: string;
|
||||
@ -20,7 +19,7 @@ export type ValidateConnectionResult =
|
||||
messageTKey: string;
|
||||
};
|
||||
|
||||
export type ValidateConnectionFunc = (
|
||||
type ValidateConnectionFunc = (
|
||||
connection: Connection,
|
||||
nodes: AnyNode[],
|
||||
edges: Edge[],
|
||||
@ -29,21 +28,6 @@ export type ValidateConnectionFunc = (
|
||||
strict?: boolean
|
||||
) => ValidateConnectionResult;
|
||||
|
||||
export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => {
|
||||
if (isValid) {
|
||||
return {
|
||||
isValid,
|
||||
messageTKey,
|
||||
};
|
||||
} else {
|
||||
assert(messageTKey !== undefined);
|
||||
return {
|
||||
isValid,
|
||||
messageTKey,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
const getEqualityPredicate =
|
||||
(c: Connection) =>
|
||||
(e: Edge): boolean => {
|
||||
|
Loading…
Reference in New Issue
Block a user