mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use new validateConnection
This commit is contained in:
parent
6ad01d824d
commit
fc31dddbf7
@ -2,11 +2,9 @@ import { useStore } from '@nanostores/react';
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/connectionValidation.js';
|
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
import { useFieldType } from './useFieldType.ts';
|
|
||||||
|
|
||||||
type UseConnectionStateProps = {
|
type UseConnectionStateProps = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
fieldName: string;
|
fieldName: string;
|
||||||
@ -16,7 +14,6 @@ type UseConnectionStateProps = {
|
|||||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||||
const pendingConnection = useStore($pendingConnection);
|
const pendingConnection = useStore($pendingConnection);
|
||||||
const templates = useStore($templates);
|
const templates = useStore($templates);
|
||||||
const fieldType = useFieldType(nodeId, fieldName, kind);
|
|
||||||
|
|
||||||
const selectIsConnected = useMemo(
|
const selectIsConnected = useMemo(
|
||||||
() =>
|
() =>
|
||||||
@ -34,8 +31,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
|||||||
);
|
);
|
||||||
|
|
||||||
const selectConnectionError = useMemo(
|
const selectConnectionError = useMemo(
|
||||||
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
|
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
|
||||||
[templates, nodeId, fieldName, kind, fieldType]
|
[templates, nodeId, fieldName, kind]
|
||||||
);
|
);
|
||||||
|
|
||||||
const isConnected = useAppSelector(selectIsConnected);
|
const isConnected = useAppSelector(selectIsConnected);
|
||||||
|
@ -2,13 +2,9 @@
|
|||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||||
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
|
||||||
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
|
||||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
|
||||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import type { Connection, Node } from 'reactflow';
|
import type { Connection } from 'reactflow';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||||
@ -26,74 +22,20 @@ export const useIsValidConnection = () => {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (source === target) {
|
const { nodes, edges } = store.getState().nodes.present;
|
||||||
// Don't allow nodes to connect to themselves, even if validation is disabled
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const state = store.getState();
|
const validationResult = validateConnection(
|
||||||
const { nodes, edges } = state.nodes.present;
|
{ source, sourceHandle, target, targetHandle },
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
templates,
|
||||||
|
null,
|
||||||
|
shouldValidateGraph
|
||||||
|
);
|
||||||
|
|
||||||
// Find the source and target nodes
|
return validationResult.isValid;
|
||||||
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
|
|
||||||
const targetNode = nodes.find((node) => node.id === target) as Node<InvocationNodeData>;
|
|
||||||
const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle];
|
|
||||||
const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle];
|
|
||||||
|
|
||||||
// Conditional guards against undefined nodes/handles
|
|
||||||
if (!(sourceFieldTemplate && targetFieldTemplate)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (targetFieldTemplate.input === 'direct') {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!shouldValidateGraph) {
|
|
||||||
// manual override!
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
edges.find((edge) => {
|
|
||||||
edge.target === target &&
|
|
||||||
edge.targetHandle === targetHandle &&
|
|
||||||
edge.source === source &&
|
|
||||||
edge.sourceHandle === sourceHandle;
|
|
||||||
})
|
|
||||||
) {
|
|
||||||
// We already have a connection from this source to this target
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
|
|
||||||
// Collect nodes shouldn't mix and match field types
|
|
||||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
|
||||||
if (collectItemType) {
|
|
||||||
return areTypesEqual(sourceFieldTemplate.type, collectItemType);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connection is invalid if target already has a connection
|
|
||||||
if (
|
|
||||||
edges.find((edge) => {
|
|
||||||
return edge.target === target && edge.targetHandle === targetHandle;
|
|
||||||
}) &&
|
|
||||||
// except CollectionItem inputs can have multiples
|
|
||||||
targetFieldTemplate.type.name !== 'CollectionItemField'
|
|
||||||
) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Must use the originalType here if it exists
|
|
||||||
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Graphs much be acyclic (no loops!)
|
|
||||||
return !getHasCycles(source, target, nodes, edges);
|
|
||||||
},
|
},
|
||||||
[shouldValidateGraph, templates, store]
|
[templates, shouldValidateGraph, store]
|
||||||
);
|
);
|
||||||
|
|
||||||
return isValidConnection;
|
return isValidConnection;
|
||||||
|
@ -1,134 +0,0 @@
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|
||||||
import type { RootState } from 'app/store/store';
|
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
|
||||||
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
|
||||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
|
||||||
import type { FieldType } from 'features/nodes/types/field';
|
|
||||||
import i18n from 'i18next';
|
|
||||||
import type { HandleType } from 'reactflow';
|
|
||||||
import { assert } from 'tsafe';
|
|
||||||
|
|
||||||
import { areTypesEqual } from './areTypesEqual';
|
|
||||||
import { getCollectItemType } from './getCollectItemType';
|
|
||||||
import { getHasCycles } from './getHasCycles';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a selector that validates a pending connection.
|
|
||||||
*
|
|
||||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
|
||||||
* TODO: Figure out how to do this without duplicating all the logic
|
|
||||||
*
|
|
||||||
* @param templates The invocation templates
|
|
||||||
* @param pendingConnection The current pending connection (if there is one)
|
|
||||||
* @param nodeId The id of the node for which the selector is being created
|
|
||||||
* @param fieldName The name of the field for which the selector is being created
|
|
||||||
* @param handleType The type of the handle for which the selector is being created
|
|
||||||
* @param fieldType The type of the field for which the selector is being created
|
|
||||||
* @returns
|
|
||||||
*/
|
|
||||||
export const makeConnectionErrorSelector = (
|
|
||||||
templates: Templates,
|
|
||||||
nodeId: string,
|
|
||||||
fieldName: string,
|
|
||||||
handleType: HandleType,
|
|
||||||
fieldType: FieldType
|
|
||||||
) => {
|
|
||||||
return createMemoizedSelector(
|
|
||||||
selectNodesSlice,
|
|
||||||
(state: RootState, pendingConnection: PendingConnection | null) => pendingConnection,
|
|
||||||
(nodesSlice: NodesState, pendingConnection: PendingConnection | null) => {
|
|
||||||
const { nodes, edges } = nodesSlice;
|
|
||||||
|
|
||||||
if (!pendingConnection) {
|
|
||||||
return i18n.t('nodes.noConnectionInProgress');
|
|
||||||
}
|
|
||||||
|
|
||||||
const connectionNodeId = pendingConnection.node.id;
|
|
||||||
const connectionFieldName = pendingConnection.fieldTemplate.name;
|
|
||||||
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
|
||||||
const connectionStartFieldType = pendingConnection.fieldTemplate.type;
|
|
||||||
|
|
||||||
if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
|
|
||||||
return i18n.t('nodes.noConnectionData');
|
|
||||||
}
|
|
||||||
|
|
||||||
const targetType = handleType === 'target' ? fieldType : connectionStartFieldType;
|
|
||||||
const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType;
|
|
||||||
|
|
||||||
if (nodeId === connectionNodeId) {
|
|
||||||
return i18n.t('nodes.cannotConnectToSelf');
|
|
||||||
}
|
|
||||||
|
|
||||||
if (handleType === connectionHandleType) {
|
|
||||||
if (handleType === 'source') {
|
|
||||||
return i18n.t('nodes.cannotConnectOutputToOutput');
|
|
||||||
}
|
|
||||||
return i18n.t('nodes.cannotConnectInputToInput');
|
|
||||||
}
|
|
||||||
|
|
||||||
// we have to figure out which is the target and which is the source
|
|
||||||
const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId;
|
|
||||||
const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName;
|
|
||||||
const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId;
|
|
||||||
const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName;
|
|
||||||
|
|
||||||
if (
|
|
||||||
edges.find((edge) => {
|
|
||||||
edge.target === targetNodeId &&
|
|
||||||
edge.targetHandle === targetFieldName &&
|
|
||||||
edge.source === sourceNodeId &&
|
|
||||||
edge.sourceHandle === sourceFieldName;
|
|
||||||
})
|
|
||||||
) {
|
|
||||||
// We already have a connection from this source to this target
|
|
||||||
return i18n.t('nodes.cannotDuplicateConnection');
|
|
||||||
}
|
|
||||||
|
|
||||||
const targetNode = nodes.find((node) => node.id === targetNodeId);
|
|
||||||
assert(targetNode, `Target node not found: ${targetNodeId}`);
|
|
||||||
const targetTemplate = templates[targetNode.data.type];
|
|
||||||
assert(targetTemplate, `Target template not found: ${targetNode.data.type}`);
|
|
||||||
|
|
||||||
if (targetTemplate.inputs[targetFieldName]?.input === 'direct') {
|
|
||||||
return i18n.t('nodes.cannotConnectToDirectInput');
|
|
||||||
}
|
|
||||||
|
|
||||||
if (targetNode.data.type === 'collect' && targetFieldName === 'item') {
|
|
||||||
// Collect nodes shouldn't mix and match field types
|
|
||||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
|
||||||
if (collectItemType) {
|
|
||||||
if (!areTypesEqual(sourceType, collectItemType)) {
|
|
||||||
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
edges.find((edge) => {
|
|
||||||
return edge.target === targetNodeId && edge.targetHandle === targetFieldName;
|
|
||||||
}) &&
|
|
||||||
// except CollectionItem inputs can have multiples
|
|
||||||
targetType.name !== 'CollectionItemField'
|
|
||||||
) {
|
|
||||||
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!validateConnectionTypes(sourceType, targetType)) {
|
|
||||||
return i18n.t('nodes.fieldTypesMustMatch');
|
|
||||||
}
|
|
||||||
|
|
||||||
const hasCycles = getHasCycles(
|
|
||||||
connectionHandleType === 'source' ? connectionNodeId : nodeId,
|
|
||||||
connectionHandleType === 'source' ? nodeId : connectionNodeId,
|
|
||||||
nodes,
|
|
||||||
edges
|
|
||||||
);
|
|
||||||
|
|
||||||
if (hasCycles) {
|
|
||||||
return i18n.t('nodes.connectionWouldCreateCycle');
|
|
||||||
}
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
);
|
|
||||||
};
|
|
@ -0,0 +1,72 @@
|
|||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
|
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
||||||
|
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||||
|
import i18n from 'i18next';
|
||||||
|
import type { HandleType } from 'reactflow';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a selector that validates a pending connection.
|
||||||
|
*
|
||||||
|
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||||
|
* TODO: Figure out how to do this without duplicating all the logic
|
||||||
|
*
|
||||||
|
* @param templates The invocation templates
|
||||||
|
* @param nodeId The id of the node for which the selector is being created
|
||||||
|
* @param fieldName The name of the field for which the selector is being created
|
||||||
|
* @param handleType The type of the handle for which the selector is being created
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
export const makeConnectionErrorSelector = (
|
||||||
|
templates: Templates,
|
||||||
|
nodeId: string,
|
||||||
|
fieldName: string,
|
||||||
|
handleType: HandleType
|
||||||
|
) => {
|
||||||
|
return createMemoizedSelector(
|
||||||
|
selectNodesSlice,
|
||||||
|
(state: RootState, pendingConnection: PendingConnection | null) => pendingConnection,
|
||||||
|
(nodesSlice: NodesState, pendingConnection: PendingConnection | null) => {
|
||||||
|
const { nodes, edges } = nodesSlice;
|
||||||
|
|
||||||
|
if (!pendingConnection) {
|
||||||
|
return i18n.t('nodes.noConnectionInProgress');
|
||||||
|
}
|
||||||
|
|
||||||
|
const connectionNodeId = pendingConnection.node.id;
|
||||||
|
const connectionFieldName = pendingConnection.fieldTemplate.name;
|
||||||
|
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||||
|
|
||||||
|
if (handleType === connectionHandleType) {
|
||||||
|
if (handleType === 'source') {
|
||||||
|
return i18n.t('nodes.cannotConnectOutputToOutput');
|
||||||
|
}
|
||||||
|
return i18n.t('nodes.cannotConnectInputToInput');
|
||||||
|
}
|
||||||
|
|
||||||
|
// we have to figure out which is the target and which is the source
|
||||||
|
const source = handleType === 'source' ? nodeId : connectionNodeId;
|
||||||
|
const sourceHandle = handleType === 'source' ? fieldName : connectionFieldName;
|
||||||
|
const target = handleType === 'target' ? nodeId : connectionNodeId;
|
||||||
|
const targetHandle = handleType === 'target' ? fieldName : connectionFieldName;
|
||||||
|
|
||||||
|
const validationResult = validateConnection(
|
||||||
|
{
|
||||||
|
source,
|
||||||
|
sourceHandle,
|
||||||
|
target,
|
||||||
|
targetHandle,
|
||||||
|
},
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
templates,
|
||||||
|
null
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!validationResult.isValid) {
|
||||||
|
return i18n.t(validationResult.messageTKey);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
};
|
Loading…
Reference in New Issue
Block a user