diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 10cde5bf49..549c284c0c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -30,7 +30,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return Output field in input: {field?.type}; } - if (field?.type === 'string' && fieldTemplate?.type === 'string') { + if ( + (field?.type === 'string' && fieldTemplate?.type === 'string') || + (field?.type === 'StringPolymorphic' && + fieldTemplate?.type === 'StringPolymorphic') + ) { return ( { ); } - if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') { + if ( + (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') || + (field?.type === 'BooleanPolymorphic' && + fieldTemplate?.type === 'BooleanPolymorphic') + ) { return ( { if ( (field?.type === 'integer' && fieldTemplate?.type === 'integer') || - (field?.type === 'float' && fieldTemplate?.type === 'float') + (field?.type === 'float' && fieldTemplate?.type === 'float') || + (field?.type === 'FloatPolymorphic' && + fieldTemplate?.type === 'FloatPolymorphic') || + (field?.type === 'IntegerPolymorphic' && + fieldTemplate?.type === 'IntegerPolymorphic') ) { return ( { ); } - if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') { + if ( + (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') || + (field?.type === 'ImagePolymorphic' && + fieldTemplate?.type === 'ImagePolymorphic') + ) { return ( + props: FieldComponentProps< + BooleanInputFieldValue | BooleanPolymorphicInputFieldValue, + BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate + > ) => { const { nodeId, field } = props; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx index 7f96675792..6099593c2a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx @@ -12,6 +12,8 @@ import { FieldComponentProps, ImageInputFieldTemplate, ImageInputFieldValue, + ImagePolymorphicInputFieldTemplate, + ImagePolymorphicInputFieldValue, } from 'features/nodes/types/types'; import { memo, useCallback, useMemo } from 'react'; import { FaUndo } from 'react-icons/fa'; @@ -19,7 +21,10 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/types'; const ImageInputFieldComponent = ( - props: FieldComponentProps + props: FieldComponentProps< + ImageInputFieldValue | ImagePolymorphicInputFieldValue, + ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate + > ) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx index 61387d751b..2afbec1df8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx @@ -12,15 +12,25 @@ import { FieldComponentProps, FloatInputFieldTemplate, FloatInputFieldValue, + FloatPolymorphicInputFieldTemplate, + FloatPolymorphicInputFieldValue, IntegerInputFieldTemplate, IntegerInputFieldValue, + IntegerPolymorphicInputFieldTemplate, + IntegerPolymorphicInputFieldValue, } from 'features/nodes/types/types'; import { memo, useEffect, useMemo, useState } from 'react'; const NumberInputFieldComponent = ( props: FieldComponentProps< - IntegerInputFieldValue | FloatInputFieldValue, - IntegerInputFieldTemplate | FloatInputFieldTemplate + | IntegerInputFieldValue + | IntegerPolymorphicInputFieldValue + | FloatInputFieldValue + | FloatPolymorphicInputFieldValue, + | IntegerInputFieldTemplate + | IntegerPolymorphicInputFieldTemplate + | FloatInputFieldTemplate + | FloatPolymorphicInputFieldTemplate > ) => { const { nodeId, field, fieldTemplate } = props; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx index c82b8f612c..720722030b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx @@ -6,11 +6,16 @@ import { StringInputFieldTemplate, StringInputFieldValue, FieldComponentProps, + StringPolymorphicInputFieldValue, + StringPolymorphicInputFieldTemplate, } from 'features/nodes/types/types'; import { ChangeEvent, memo, useCallback } from 'react'; const StringInputFieldComponent = ( - props: FieldComponentProps + props: FieldComponentProps< + StringInputFieldValue | StringPolymorphicInputFieldValue, + StringInputFieldTemplate | StringPolymorphicInputFieldTemplate + > ) => { const { nodeId, field, fieldTemplate } = props; const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index 5822d2ac53..36f2e8a62c 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -5,6 +5,10 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import { useMemo } from 'react'; import { isInvocationNode } from '../types/types'; +import { + POLYMORPHIC_TYPES, + TYPES_WITH_INPUT_COMPONENTS, +} from '../types/constants'; export const useAnyOrDirectInputFieldNames = (nodeId: string) => { const selector = useMemo( @@ -21,7 +25,12 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => { return []; } return map(nodeTemplate.inputs) - .filter((field) => ['any', 'direct'].includes(field.input)) + .filter( + (field) => + (['any', 'direct'].includes(field.input) || + POLYMORPHIC_TYPES.includes(field.type)) && + TYPES_WITH_INPUT_COMPONENTS.includes(field.type) + ) .filter((field) => !field.ui_hidden) .sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0)) .map((field) => field.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 93125a3499..eea874cc87 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -4,6 +4,10 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import { useMemo } from 'react'; +import { + POLYMORPHIC_TYPES, + TYPES_WITH_INPUT_COMPONENTS, +} from '../types/constants'; import { isInvocationNode } from '../types/types'; export const useConnectionInputFieldNames = (nodeId: string) => { @@ -21,7 +25,12 @@ export const useConnectionInputFieldNames = (nodeId: string) => { return []; } return map(nodeTemplate.inputs) - .filter((field) => field.input === 'connection') + .filter( + (field) => + (field.input === 'connection' && + !POLYMORPHIC_TYPES.includes(field.type)) || + !TYPES_WITH_INPUT_COMPONENTS.includes(field.type) + ) .filter((field) => !field.ui_hidden) .sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0)) .map((field) => field.name) diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 34e494677f..a54f84d3f0 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -96,6 +96,28 @@ export const POLYMORPHIC_TO_SINGLE_MAP = { ColorPolymorphic: 'ColorField', }; +export const TYPES_WITH_INPUT_COMPONENTS = [ + 'string', + 'StringPolymorphic', + 'boolean', + 'BooleanPolymorphic', + 'integer', + 'float', + 'FloatPolymorphic', + 'IntegerPolymorphic', + 'enum', + 'ImageField', + 'ImagePolymorphic', + 'MainModelField', + 'SDXLRefinerModelField', + 'VaeModelField', + 'LoRAModelField', + 'ControlNetModelField', + 'ColorField', + 'SDXLMainModelField', + 'Scheduler', +]; + export const isPolymorphicItemType = ( itemType: string | undefined ): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP => diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 361de26ea5..2a3e5a762b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -220,7 +220,7 @@ export type IntegerCollectionInputFieldValue = z.infer< export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('IntegerPolymorphic'), - value: z.union([z.number().int(), z.array(z.number().int())]).optional(), + value: z.number().int().optional(), }); export type IntegerPolymorphicInputFieldValue = z.infer< typeof zIntegerPolymorphicInputFieldValue @@ -242,7 +242,7 @@ export type FloatCollectionInputFieldValue = z.infer< export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('FloatPolymorphic'), - value: z.union([z.number(), z.array(z.number())]).optional(), + value: z.number().optional(), }); export type FloatPolymorphicInputFieldValue = z.infer< typeof zFloatPolymorphicInputFieldValue @@ -264,7 +264,7 @@ export type StringCollectionInputFieldValue = z.infer< export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('StringPolymorphic'), - value: z.union([z.string(), z.array(z.string())]).optional(), + value: z.string().optional(), }); export type StringPolymorphicInputFieldValue = z.infer< typeof zStringPolymorphicInputFieldValue @@ -286,7 +286,7 @@ export type BooleanCollectionInputFieldValue = z.infer< export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('BooleanPolymorphic'), - value: z.union([z.boolean(), z.array(z.boolean())]).optional(), + value: z.boolean().optional(), }); export type BooleanPolymorphicInputFieldValue = z.infer< typeof zBooleanPolymorphicInputFieldValue @@ -496,7 +496,7 @@ export type ImageInputFieldValue = z.infer; export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('ImagePolymorphic'), - value: z.union([zImageField, z.array(zImageField)]).optional(), + value: zImageField.optional(), }); export type ImagePolymorphicInputFieldValue = z.infer< typeof zImagePolymorphicInputFieldValue