mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): connection validation for collection items types
This commit is contained in:
@ -70,7 +70,7 @@ export const useConnection = () => {
|
||||
}
|
||||
const candidateTemplate = templates[candidateNode.data.type];
|
||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||
const connection = getFirstValidConnection(nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
||||
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, candidateNode, candidateTemplate);
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $pendingConnection, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
@ -15,6 +15,7 @@ type UseConnectionStateProps = {
|
||||
|
||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const templates = useStore($templates);
|
||||
const fieldType = useFieldType(nodeId, fieldName, kind);
|
||||
|
||||
const selectIsConnected = useMemo(
|
||||
@ -35,13 +36,14 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
const selectConnectionError = useMemo(
|
||||
() =>
|
||||
makeConnectionErrorSelector(
|
||||
templates,
|
||||
pendingConnection,
|
||||
nodeId,
|
||||
fieldName,
|
||||
kind === 'inputs' ? 'target' : 'source',
|
||||
fieldType
|
||||
),
|
||||
[pendingConnection, nodeId, fieldName, kind, fieldType]
|
||||
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
const isConnected = useAppSelector(selectIsConnected);
|
||||
|
@ -3,8 +3,10 @@ import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import type { Connection, Node } from 'reactflow';
|
||||
|
||||
@ -60,6 +62,14 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
return isEqual(sourceFieldTemplate.type, collectItemType);
|
||||
}
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
|
Reference in New Issue
Block a user