mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use new validateConnection
This commit is contained in:
@ -2,11 +2,9 @@ import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/connectionValidation.js';
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useFieldType } from './useFieldType.ts';
|
||||
|
||||
type UseConnectionStateProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
@ -16,7 +14,6 @@ type UseConnectionStateProps = {
|
||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const templates = useStore($templates);
|
||||
const fieldType = useFieldType(nodeId, fieldName, kind);
|
||||
|
||||
const selectIsConnected = useMemo(
|
||||
() =>
|
||||
@ -34,8 +31,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
);
|
||||
|
||||
const selectConnectionError = useMemo(
|
||||
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
|
||||
[templates, nodeId, fieldName, kind, fieldType]
|
||||
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
|
||||
[templates, nodeId, fieldName, kind]
|
||||
);
|
||||
|
||||
const isConnected = useAppSelector(selectIsConnected);
|
||||
|
@ -2,13 +2,9 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import { useCallback } from 'react';
|
||||
import type { Connection, Node } from 'reactflow';
|
||||
import type { Connection } from 'reactflow';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||
@ -26,74 +22,20 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (source === target) {
|
||||
// Don't allow nodes to connect to themselves, even if validation is disabled
|
||||
return false;
|
||||
}
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
|
||||
const state = store.getState();
|
||||
const { nodes, edges } = state.nodes.present;
|
||||
const validationResult = validateConnection(
|
||||
{ source, sourceHandle, target, targetHandle },
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
null,
|
||||
shouldValidateGraph
|
||||
);
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
|
||||
const targetNode = nodes.find((node) => node.id === target) as Node<InvocationNodeData>;
|
||||
const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle];
|
||||
const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle];
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (!(sourceFieldTemplate && targetFieldTemplate)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (targetFieldTemplate.input === 'direct') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!shouldValidateGraph) {
|
||||
// manual override!
|
||||
return true;
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
edge.target === target &&
|
||||
edge.targetHandle === targetHandle &&
|
||||
edge.source === source &&
|
||||
edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return false;
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
return areTypesEqual(sourceFieldTemplate.type, collectItemType);
|
||||
}
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetFieldTemplate.type.name !== 'CollectionItemField'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Must use the originalType here if it exists
|
||||
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Graphs much be acyclic (no loops!)
|
||||
return !getHasCycles(source, target, nodes, edges);
|
||||
return validationResult.isValid;
|
||||
},
|
||||
[shouldValidateGraph, templates, store]
|
||||
[templates, shouldValidateGraph, store]
|
||||
);
|
||||
|
||||
return isValidConnection;
|
||||
|
Reference in New Issue
Block a user