mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): makeConnectionErrorSelector now creates a parameterized selector
This commit is contained in:
parent
6658897210
commit
9d127fee6b
@ -34,16 +34,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
);
|
||||
|
||||
const selectConnectionError = useMemo(
|
||||
() =>
|
||||
makeConnectionErrorSelector(
|
||||
templates,
|
||||
pendingConnection,
|
||||
nodeId,
|
||||
fieldName,
|
||||
kind === 'inputs' ? 'target' : 'source',
|
||||
fieldType
|
||||
),
|
||||
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
|
||||
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
|
||||
[templates, nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
const isConnected = useAppSelector(selectIsConnected);
|
||||
@ -58,7 +50,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
||||
);
|
||||
}, [fieldName, kind, nodeId, pendingConnection]);
|
||||
const connectionError = useAppSelector(selectConnectionError);
|
||||
const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection));
|
||||
|
||||
const shouldDim = useMemo(
|
||||
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),
|
||||
|
@ -1,7 +1,8 @@
|
||||
import graphlib from '@dagrejs/graphlib';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import i18n from 'i18next';
|
||||
@ -190,105 +191,108 @@ export const getCollectItemType = (
|
||||
*/
|
||||
export const makeConnectionErrorSelector = (
|
||||
templates: Templates,
|
||||
pendingConnection: PendingConnection | null,
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
handleType: HandleType,
|
||||
fieldType: FieldType
|
||||
) => {
|
||||
return createMemoizedSelector(selectNodesSlice, (nodesSlice) => {
|
||||
const { nodes, edges } = nodesSlice;
|
||||
return createMemoizedSelector(
|
||||
selectNodesSlice,
|
||||
(state: RootState, pendingConnection: PendingConnection | null) => pendingConnection,
|
||||
(nodesSlice: NodesState, pendingConnection: PendingConnection | null) => {
|
||||
const { nodes, edges } = nodesSlice;
|
||||
|
||||
if (!pendingConnection) {
|
||||
return i18n.t('nodes.noConnectionInProgress');
|
||||
}
|
||||
|
||||
const connectionNodeId = pendingConnection.node.id;
|
||||
const connectionFieldName = pendingConnection.fieldTemplate.name;
|
||||
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||
const connectionStartFieldType = pendingConnection.fieldTemplate.type;
|
||||
|
||||
if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
|
||||
return i18n.t('nodes.noConnectionData');
|
||||
}
|
||||
|
||||
const targetType = handleType === 'target' ? fieldType : connectionStartFieldType;
|
||||
const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType;
|
||||
|
||||
if (nodeId === connectionNodeId) {
|
||||
return i18n.t('nodes.cannotConnectToSelf');
|
||||
}
|
||||
|
||||
if (handleType === connectionHandleType) {
|
||||
if (handleType === 'source') {
|
||||
return i18n.t('nodes.cannotConnectOutputToOutput');
|
||||
if (!pendingConnection) {
|
||||
return i18n.t('nodes.noConnectionInProgress');
|
||||
}
|
||||
return i18n.t('nodes.cannotConnectInputToInput');
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId;
|
||||
const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName;
|
||||
const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId;
|
||||
const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName;
|
||||
const connectionNodeId = pendingConnection.node.id;
|
||||
const connectionFieldName = pendingConnection.fieldTemplate.name;
|
||||
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||
const connectionStartFieldType = pendingConnection.fieldTemplate.type;
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
edge.target === targetNodeId &&
|
||||
edge.targetHandle === targetFieldName &&
|
||||
edge.source === sourceNodeId &&
|
||||
edge.sourceHandle === sourceFieldName;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return i18n.t('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
|
||||
return i18n.t('nodes.noConnectionData');
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((node) => node.id === targetNodeId);
|
||||
assert(targetNode, `Target node not found: ${targetNodeId}`);
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
assert(targetTemplate, `Target template not found: ${targetNode.data.type}`);
|
||||
const targetType = handleType === 'target' ? fieldType : connectionStartFieldType;
|
||||
const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType;
|
||||
|
||||
if (targetTemplate.inputs[targetFieldName]?.input === 'direct') {
|
||||
return i18n.t('nodes.cannotConnectToDirectInput');
|
||||
}
|
||||
if (targetNode.data.type === 'collect' && targetFieldName === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
if (!areTypesEqual(sourceType, collectItemType)) {
|
||||
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
if (nodeId === connectionNodeId) {
|
||||
return i18n.t('nodes.cannotConnectToSelf');
|
||||
}
|
||||
|
||||
if (handleType === connectionHandleType) {
|
||||
if (handleType === 'source') {
|
||||
return i18n.t('nodes.cannotConnectOutputToOutput');
|
||||
}
|
||||
return i18n.t('nodes.cannotConnectInputToInput');
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId;
|
||||
const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName;
|
||||
const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId;
|
||||
const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName;
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
edge.target === targetNodeId &&
|
||||
edge.targetHandle === targetFieldName &&
|
||||
edge.source === sourceNodeId &&
|
||||
edge.sourceHandle === sourceFieldName;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return i18n.t('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((node) => node.id === targetNodeId);
|
||||
assert(targetNode, `Target node not found: ${targetNodeId}`);
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
assert(targetTemplate, `Target template not found: ${targetNode.data.type}`);
|
||||
|
||||
if (targetTemplate.inputs[targetFieldName]?.input === 'direct') {
|
||||
return i18n.t('nodes.cannotConnectToDirectInput');
|
||||
}
|
||||
if (targetNode.data.type === 'collect' && targetFieldName === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
if (!areTypesEqual(sourceType, collectItemType)) {
|
||||
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === targetNodeId && edge.targetHandle === targetFieldName;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType.name !== 'CollectionItemField'
|
||||
) {
|
||||
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
|
||||
return i18n.t('nodes.fieldTypesMustMatch');
|
||||
}
|
||||
|
||||
const hasCycles = getHasCycles(
|
||||
connectionHandleType === 'source' ? connectionNodeId : nodeId,
|
||||
connectionHandleType === 'source' ? nodeId : connectionNodeId,
|
||||
nodes,
|
||||
edges
|
||||
);
|
||||
|
||||
if (hasCycles) {
|
||||
return i18n.t('nodes.connectionWouldCreateCycle');
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === targetNodeId && edge.targetHandle === targetFieldName;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType.name !== 'CollectionItemField'
|
||||
) {
|
||||
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
|
||||
return i18n.t('nodes.fieldTypesMustMatch');
|
||||
}
|
||||
|
||||
const hasCycles = getHasCycles(
|
||||
connectionHandleType === 'source' ? connectionNodeId : nodeId,
|
||||
connectionHandleType === 'source' ? nodeId : connectionNodeId,
|
||||
nodes,
|
||||
edges
|
||||
);
|
||||
|
||||
if (hasCycles) {
|
||||
return i18n.t('nodes.connectionWouldCreateCycle');
|
||||
}
|
||||
|
||||
return;
|
||||
});
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
|
Loading…
Reference in New Issue
Block a user