From 5601858f4f9b72e3e1a42f4e39c535186b929766 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 22 Sep 2023 21:38:51 +1000 Subject: [PATCH] feat(ui): allow numbers to connect to strings Pydantic handles the casting so this is always safe. Also de-duplicate some validation logic code that was needlessly duplicated. --- invokeai/app/services/graph.py | 4 + .../nodes/hooks/useIsValidConnection.ts | 78 ++++--------------- .../util/makeIsConnectionValidSelector.ts | 69 +--------------- .../util/validateSourceAndTargetTypes.ts | 74 ++++++++++++++++++ 4 files changed, 95 insertions(+), 130 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 2a5fc4c441..9dccd14026 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -117,6 +117,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if from_type is int and to_type is float: return True + # allow int|float -> str, pydantic will cast for us + if (from_type is int or from_type is float) and to_type is str: + return True + # if not issubclass(from_type, to_type): if not is_union_subtype(from_type, to_type): return False diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index d1d10bb7e7..0439445c24 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -3,12 +3,7 @@ import graphlib from '@dagrejs/graphlib'; import { useAppSelector } from 'app/store/storeHooks'; import { useCallback } from 'react'; import { Connection, Edge, Node, useReactFlow } from 'reactflow'; -import { - COLLECTION_MAP, - COLLECTION_TYPES, - POLYMORPHIC_TO_SINGLE_MAP, - POLYMORPHIC_TYPES, -} from '../types/constants'; +import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes'; import { InvocationNodeData } from '../types/types'; /** @@ -23,11 +18,6 @@ export const useIsValidConnection = () => { ); const isValidConnection = useCallback( ({ source, sourceHandle, target, targetHandle }: Connection): boolean => { - if (!shouldValidateGraph) { - // manual override! - return true; - } - const edges = flow.getEdges(); const nodes = flow.getNodes(); // Connection must have valid targets @@ -52,6 +42,16 @@ export const useIsValidConnection = () => { return false; } + if (source === target) { + // Don't allow nodes to connect to themselves, even if validation is disabled + return false; + } + + if (!shouldValidateGraph) { + // manual override! + return true; + } + if ( edges .filter((edge) => { @@ -76,60 +76,8 @@ export const useIsValidConnection = () => { return false; } - /** - * Connection types must be the same for a connection, with exceptions: - * - CollectionItem can connect to any non-Collection - * - Non-Collections can connect to CollectionItem - * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type - * - Generic Collection can connect to any other Collection or Polymorphic - * - Any Collection can connect to a Generic Collection - */ - - if (sourceType !== targetType) { - const isCollectionItemToNonCollection = - sourceType === 'CollectionItem' && - !COLLECTION_TYPES.includes(targetType); - - const isNonCollectionToCollectionItem = - targetType === 'CollectionItem' && - !COLLECTION_TYPES.includes(sourceType) && - !POLYMORPHIC_TYPES.includes(sourceType); - - const isAnythingToPolymorphicOfSameBaseType = - POLYMORPHIC_TYPES.includes(targetType) && - (() => { - if (!POLYMORPHIC_TYPES.includes(targetType)) { - return false; - } - const baseType = - POLYMORPHIC_TO_SINGLE_MAP[ - targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP - ]; - - const collectionType = - COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP]; - - return sourceType === baseType || sourceType === collectionType; - })(); - - const isGenericCollectionToAnyCollectionOrPolymorphic = - sourceType === 'Collection' && - (COLLECTION_TYPES.includes(targetType) || - POLYMORPHIC_TYPES.includes(targetType)); - - const isCollectionToGenericCollection = - targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); - - const isIntToFloat = sourceType === 'integer' && targetType === 'float'; - - return ( - isCollectionItemToNonCollection || - isNonCollectionToCollectionItem || - isAnythingToPolymorphicOfSameBaseType || - isGenericCollectionToAnyCollectionOrPolymorphic || - isCollectionToGenericCollection || - isIntToFloat - ); + if (!validateSourceAndTargetTypes(sourceType, targetType)) { + return false; } // Graphs much be acyclic (no loops!) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index ac157bb476..1be2d579d8 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -1,15 +1,10 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection'; -import { - COLLECTION_MAP, - COLLECTION_TYPES, - POLYMORPHIC_TO_SINGLE_MAP, - POLYMORPHIC_TYPES, -} from 'features/nodes/types/constants'; import { FieldType } from 'features/nodes/types/types'; -import { HandleType } from 'reactflow'; import i18n from 'i18next'; +import { HandleType } from 'reactflow'; +import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` @@ -70,64 +65,8 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.inputMayOnlyHaveOneConnection'); } - /** - * Connection types must be the same for a connection, with exceptions: - * - CollectionItem can connect to any non-Collection - * - Non-Collections can connect to CollectionItem - * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type - * - Generic Collection can connect to any other Collection or Polymorphic - * - Any Collection can connect to a Generic Collection - */ - - if (sourceType !== targetType) { - const isCollectionItemToNonCollection = - sourceType === 'CollectionItem' && - !COLLECTION_TYPES.includes(targetType); - - const isNonCollectionToCollectionItem = - targetType === 'CollectionItem' && - !COLLECTION_TYPES.includes(sourceType) && - !POLYMORPHIC_TYPES.includes(sourceType); - - const isAnythingToPolymorphicOfSameBaseType = - POLYMORPHIC_TYPES.includes(targetType) && - (() => { - if (!POLYMORPHIC_TYPES.includes(targetType)) { - return false; - } - const baseType = - POLYMORPHIC_TO_SINGLE_MAP[ - targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP - ]; - - const collectionType = - COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP]; - - return sourceType === baseType || sourceType === collectionType; - })(); - - const isGenericCollectionToAnyCollectionOrPolymorphic = - sourceType === 'Collection' && - (COLLECTION_TYPES.includes(targetType) || - POLYMORPHIC_TYPES.includes(targetType)); - - const isCollectionToGenericCollection = - targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); - - const isIntToFloat = sourceType === 'integer' && targetType === 'float'; - - if ( - !( - isCollectionItemToNonCollection || - isNonCollectionToCollectionItem || - isAnythingToPolymorphicOfSameBaseType || - isGenericCollectionToAnyCollectionOrPolymorphic || - isCollectionToGenericCollection || - isIntToFloat - ) - ) { - return i18n.t('nodes.fieldTypesMustMatch'); - } + if (!validateSourceAndTargetTypes(sourceType, targetType)) { + return i18n.t('nodes.fieldTypesMustMatch'); } const isGraphAcyclic = getIsGraphAcyclic( diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts new file mode 100644 index 0000000000..4f0be3329a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts @@ -0,0 +1,74 @@ +import { + COLLECTION_MAP, + COLLECTION_TYPES, + POLYMORPHIC_TO_SINGLE_MAP, + POLYMORPHIC_TYPES, +} from 'features/nodes/types/constants'; +import { FieldType } from 'features/nodes/types/types'; + +export const validateSourceAndTargetTypes = ( + sourceType: FieldType, + targetType: FieldType +) => { + if (sourceType === targetType) { + return true; + } + + /** + * Connection types must be the same for a connection, with exceptions: + * - CollectionItem can connect to any non-Collection + * - Non-Collections can connect to CollectionItem + * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type + * - Generic Collection can connect to any other Collection or Polymorphic + * - Any Collection can connect to a Generic Collection + */ + + const isCollectionItemToNonCollection = + sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType); + + const isNonCollectionToCollectionItem = + targetType === 'CollectionItem' && + !COLLECTION_TYPES.includes(sourceType) && + !POLYMORPHIC_TYPES.includes(sourceType); + + const isAnythingToPolymorphicOfSameBaseType = + POLYMORPHIC_TYPES.includes(targetType) && + (() => { + if (!POLYMORPHIC_TYPES.includes(targetType)) { + return false; + } + const baseType = + POLYMORPHIC_TO_SINGLE_MAP[ + targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP + ]; + + const collectionType = + COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP]; + + return sourceType === baseType || sourceType === collectionType; + })(); + + const isGenericCollectionToAnyCollectionOrPolymorphic = + sourceType === 'Collection' && + (COLLECTION_TYPES.includes(targetType) || + POLYMORPHIC_TYPES.includes(targetType)); + + const isCollectionToGenericCollection = + targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + + const isIntToFloat = sourceType === 'integer' && targetType === 'float'; + + const isIntOrFloatToString = + (sourceType === 'integer' || sourceType === 'float') && + targetType === 'string'; + + return ( + isCollectionItemToNonCollection || + isNonCollectionToCollectionItem || + isAnythingToPolymorphicOfSameBaseType || + isGenericCollectionToAnyCollectionOrPolymorphic || + isCollectionToGenericCollection || + isIntToFloat || + isIntOrFloatToString + ); +};