From e78b36a9f7cbbd68ad58a4a15e61954bc27964c3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:05:25 +1000 Subject: [PATCH] feat(ui): render input components for polymorphic fields Polymorphic fields now render the appropriate input component for their base type. For example, float polymorphics will render the number input box. You no longer need to specify ui_type to force it to display. TODO: The UI *may* break if a list is provided as the default value for a polymorphic field. --- .../Invocation/fields/InputFieldRenderer.tsx | 24 +++++++++++++++---- .../fields/inputs/BooleanInputField.tsx | 7 +++++- .../fields/inputs/ImageInputField.tsx | 7 +++++- .../fields/inputs/NumberInputField.tsx | 14 +++++++++-- .../fields/inputs/StringInputField.tsx | 7 +++++- .../hooks/useAnyOrDirectInputFieldNames.ts | 11 ++++++++- .../hooks/useConnectionInputFieldNames.ts | 11 ++++++++- .../web/src/features/nodes/types/constants.ts | 22 +++++++++++++++++ .../web/src/features/nodes/types/types.ts | 10 ++++---- 9 files changed, 97 insertions(+), 16 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 10cde5bf49..549c284c0c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -30,7 +30,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return Output field in input: {field?.type}; } - if (field?.type === 'string' && fieldTemplate?.type === 'string') { + if ( + (field?.type === 'string' && fieldTemplate?.type === 'string') || + (field?.type === 'StringPolymorphic' && + fieldTemplate?.type === 'StringPolymorphic') + ) { return ( { ); } - if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') { + if ( + (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') || + (field?.type === 'BooleanPolymorphic' && + fieldTemplate?.type === 'BooleanPolymorphic') + ) { return ( { if ( (field?.type === 'integer' && fieldTemplate?.type === 'integer') || - (field?.type === 'float' && fieldTemplate?.type === 'float') + (field?.type === 'float' && fieldTemplate?.type === 'float') || + (field?.type === 'FloatPolymorphic' && + fieldTemplate?.type === 'FloatPolymorphic') || + (field?.type === 'IntegerPolymorphic' && + fieldTemplate?.type === 'IntegerPolymorphic') ) { return ( { ); } - if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') { + if ( + (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') || + (field?.type === 'ImagePolymorphic' && + fieldTemplate?.type === 'ImagePolymorphic') + ) { return ( + props: FieldComponentProps< + BooleanInputFieldValue | BooleanPolymorphicInputFieldValue, + BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate + > ) => { const { nodeId, field } = props; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx index 7f96675792..6099593c2a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx @@ -12,6 +12,8 @@ import { FieldComponentProps, ImageInputFieldTemplate, ImageInputFieldValue, + ImagePolymorphicInputFieldTemplate, + ImagePolymorphicInputFieldValue, } from 'features/nodes/types/types'; import { memo, useCallback, useMemo } from 'react'; import { FaUndo } from 'react-icons/fa'; @@ -19,7 +21,10 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/types'; const ImageInputFieldComponent = ( - props: FieldComponentProps + props: FieldComponentProps< + ImageInputFieldValue | ImagePolymorphicInputFieldValue, + ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate + > ) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx index 61387d751b..2afbec1df8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx @@ -12,15 +12,25 @@ import { FieldComponentProps, FloatInputFieldTemplate, FloatInputFieldValue, + FloatPolymorphicInputFieldTemplate, + FloatPolymorphicInputFieldValue, IntegerInputFieldTemplate, IntegerInputFieldValue, + IntegerPolymorphicInputFieldTemplate, + IntegerPolymorphicInputFieldValue, } from 'features/nodes/types/types'; import { memo, useEffect, useMemo, useState } from 'react'; const NumberInputFieldComponent = ( props: FieldComponentProps< - IntegerInputFieldValue | FloatInputFieldValue, - IntegerInputFieldTemplate | FloatInputFieldTemplate + | IntegerInputFieldValue + | IntegerPolymorphicInputFieldValue + | FloatInputFieldValue + | FloatPolymorphicInputFieldValue, + | IntegerInputFieldTemplate + | IntegerPolymorphicInputFieldTemplate + | FloatInputFieldTemplate + | FloatPolymorphicInputFieldTemplate > ) => { const { nodeId, field, fieldTemplate } = props; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx index c82b8f612c..720722030b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx @@ -6,11 +6,16 @@ import { StringInputFieldTemplate, StringInputFieldValue, FieldComponentProps, + StringPolymorphicInputFieldValue, + StringPolymorphicInputFieldTemplate, } from 'features/nodes/types/types'; import { ChangeEvent, memo, useCallback } from 'react'; const StringInputFieldComponent = ( - props: FieldComponentProps + props: FieldComponentProps< + StringInputFieldValue | StringPolymorphicInputFieldValue, + StringInputFieldTemplate | StringPolymorphicInputFieldTemplate + > ) => { const { nodeId, field, fieldTemplate } = props; const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index 5822d2ac53..36f2e8a62c 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -5,6 +5,10 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import { useMemo } from 'react'; import { isInvocationNode } from '../types/types'; +import { + POLYMORPHIC_TYPES, + TYPES_WITH_INPUT_COMPONENTS, +} from '../types/constants'; export const useAnyOrDirectInputFieldNames = (nodeId: string) => { const selector = useMemo( @@ -21,7 +25,12 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => { return []; } return map(nodeTemplate.inputs) - .filter((field) => ['any', 'direct'].includes(field.input)) + .filter( + (field) => + (['any', 'direct'].includes(field.input) || + POLYMORPHIC_TYPES.includes(field.type)) && + TYPES_WITH_INPUT_COMPONENTS.includes(field.type) + ) .filter((field) => !field.ui_hidden) .sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0)) .map((field) => field.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 93125a3499..eea874cc87 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -4,6 +4,10 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import { useMemo } from 'react'; +import { + POLYMORPHIC_TYPES, + TYPES_WITH_INPUT_COMPONENTS, +} from '../types/constants'; import { isInvocationNode } from '../types/types'; export const useConnectionInputFieldNames = (nodeId: string) => { @@ -21,7 +25,12 @@ export const useConnectionInputFieldNames = (nodeId: string) => { return []; } return map(nodeTemplate.inputs) - .filter((field) => field.input === 'connection') + .filter( + (field) => + (field.input === 'connection' && + !POLYMORPHIC_TYPES.includes(field.type)) || + !TYPES_WITH_INPUT_COMPONENTS.includes(field.type) + ) .filter((field) => !field.ui_hidden) .sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0)) .map((field) => field.name) diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 34e494677f..a54f84d3f0 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -96,6 +96,28 @@ export const POLYMORPHIC_TO_SINGLE_MAP = { ColorPolymorphic: 'ColorField', }; +export const TYPES_WITH_INPUT_COMPONENTS = [ + 'string', + 'StringPolymorphic', + 'boolean', + 'BooleanPolymorphic', + 'integer', + 'float', + 'FloatPolymorphic', + 'IntegerPolymorphic', + 'enum', + 'ImageField', + 'ImagePolymorphic', + 'MainModelField', + 'SDXLRefinerModelField', + 'VaeModelField', + 'LoRAModelField', + 'ControlNetModelField', + 'ColorField', + 'SDXLMainModelField', + 'Scheduler', +]; + export const isPolymorphicItemType = ( itemType: string | undefined ): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP => diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 361de26ea5..2a3e5a762b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -220,7 +220,7 @@ export type IntegerCollectionInputFieldValue = z.infer< export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('IntegerPolymorphic'), - value: z.union([z.number().int(), z.array(z.number().int())]).optional(), + value: z.number().int().optional(), }); export type IntegerPolymorphicInputFieldValue = z.infer< typeof zIntegerPolymorphicInputFieldValue @@ -242,7 +242,7 @@ export type FloatCollectionInputFieldValue = z.infer< export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('FloatPolymorphic'), - value: z.union([z.number(), z.array(z.number())]).optional(), + value: z.number().optional(), }); export type FloatPolymorphicInputFieldValue = z.infer< typeof zFloatPolymorphicInputFieldValue @@ -264,7 +264,7 @@ export type StringCollectionInputFieldValue = z.infer< export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('StringPolymorphic'), - value: z.union([z.string(), z.array(z.string())]).optional(), + value: z.string().optional(), }); export type StringPolymorphicInputFieldValue = z.infer< typeof zStringPolymorphicInputFieldValue @@ -286,7 +286,7 @@ export type BooleanCollectionInputFieldValue = z.infer< export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('BooleanPolymorphic'), - value: z.union([z.boolean(), z.array(z.boolean())]).optional(), + value: z.boolean().optional(), }); export type BooleanPolymorphicInputFieldValue = z.infer< typeof zBooleanPolymorphicInputFieldValue @@ -496,7 +496,7 @@ export type ImageInputFieldValue = z.infer; export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('ImagePolymorphic'), - value: z.union([zImageField, z.array(zImageField)]).optional(), + value: zImageField.optional(), }); export type ImagePolymorphicInputFieldValue = z.infer< typeof zImagePolymorphicInputFieldValue