diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 18c99fafc1..0e40636280 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -112,6 +112,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if to_type in get_args(from_type): return True + # allow int -> float, pydantic will cast for us + if from_type is int and to_type is float: + return True + # if not issubclass(from_type, to_type): if not is_union_subtype(from_type, to_type): return False diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 1c372e551d..d1d10bb7e7 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -120,12 +120,15 @@ export const useIsValidConnection = () => { const isCollectionToGenericCollection = targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + const isIntToFloat = sourceType === 'integer' && targetType === 'float'; + return ( isCollectionItemToNonCollection || isNonCollectionToCollectionItem || isAnythingToPolymorphicOfSameBaseType || isGenericCollectionToAnyCollectionOrPolymorphic || - isCollectionToGenericCollection + isCollectionToGenericCollection || + isIntToFloat ); } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 0c5fee509c..5cb6d557e8 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -113,13 +113,16 @@ export const makeConnectionErrorSelector = ( const isCollectionToGenericCollection = targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + const isIntToFloat = sourceType === 'integer' && targetType === 'float'; + if ( !( isCollectionItemToNonCollection || isNonCollectionToCollectionItem || isAnythingToPolymorphicOfSameBaseType || isGenericCollectionToAnyCollectionOrPolymorphic || - isCollectionToGenericCollection + isCollectionToGenericCollection || + isIntToFloat ) ) { return 'Field types must match';