feat(ui): render input components for polymorphic fields

Polymorphic fields now render the appropriate input component for their base type.

For example, float polymorphics will render the number input box.

You no longer need to specify ui_type to force it to display.

TODO: The UI *may* break if a list is provided as the default value for a polymorphic field.
This commit is contained in:
psychedelicious 2023-09-15 11:05:25 +10:00
parent 144ede031e
commit e78b36a9f7
9 changed files with 97 additions and 16 deletions

View File

@ -30,7 +30,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <Box p={2}>Output field in input: {field?.type}</Box>; return <Box p={2}>Output field in input: {field?.type}</Box>;
} }
if (field?.type === 'string' && fieldTemplate?.type === 'string') { if (
(field?.type === 'string' && fieldTemplate?.type === 'string') ||
(field?.type === 'StringPolymorphic' &&
fieldTemplate?.type === 'StringPolymorphic')
) {
return ( return (
<StringInputField <StringInputField
nodeId={nodeId} nodeId={nodeId}
@ -40,7 +44,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
); );
} }
if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') { if (
(field?.type === 'boolean' && fieldTemplate?.type === 'boolean') ||
(field?.type === 'BooleanPolymorphic' &&
fieldTemplate?.type === 'BooleanPolymorphic')
) {
return ( return (
<BooleanInputField <BooleanInputField
nodeId={nodeId} nodeId={nodeId}
@ -52,7 +60,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if ( if (
(field?.type === 'integer' && fieldTemplate?.type === 'integer') || (field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
(field?.type === 'float' && fieldTemplate?.type === 'float') (field?.type === 'float' && fieldTemplate?.type === 'float') ||
(field?.type === 'FloatPolymorphic' &&
fieldTemplate?.type === 'FloatPolymorphic') ||
(field?.type === 'IntegerPolymorphic' &&
fieldTemplate?.type === 'IntegerPolymorphic')
) { ) {
return ( return (
<NumberInputField <NumberInputField
@ -73,7 +85,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
); );
} }
if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') { if (
(field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') ||
(field?.type === 'ImagePolymorphic' &&
fieldTemplate?.type === 'ImagePolymorphic')
) {
return ( return (
<ImageInputField <ImageInputField
nodeId={nodeId} nodeId={nodeId}

View File

@ -4,12 +4,17 @@ import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
BooleanInputFieldTemplate, BooleanInputFieldTemplate,
BooleanInputFieldValue, BooleanInputFieldValue,
BooleanPolymorphicInputFieldTemplate,
BooleanPolymorphicInputFieldValue,
FieldComponentProps, FieldComponentProps,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
const BooleanInputFieldComponent = ( const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate> props: FieldComponentProps<
BooleanInputFieldValue | BooleanPolymorphicInputFieldValue,
BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate
>
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;

View File

@ -12,6 +12,8 @@ import {
FieldComponentProps, FieldComponentProps,
ImageInputFieldTemplate, ImageInputFieldTemplate,
ImageInputFieldValue, ImageInputFieldValue,
ImagePolymorphicInputFieldTemplate,
ImagePolymorphicInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { FaUndo } from 'react-icons/fa'; import { FaUndo } from 'react-icons/fa';
@ -19,7 +21,10 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types'; import { PostUploadAction } from 'services/api/types';
const ImageInputFieldComponent = ( const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate> props: FieldComponentProps<
ImageInputFieldValue | ImagePolymorphicInputFieldValue,
ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate
>
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -12,15 +12,25 @@ import {
FieldComponentProps, FieldComponentProps,
FloatInputFieldTemplate, FloatInputFieldTemplate,
FloatInputFieldValue, FloatInputFieldValue,
FloatPolymorphicInputFieldTemplate,
FloatPolymorphicInputFieldValue,
IntegerInputFieldTemplate, IntegerInputFieldTemplate,
IntegerInputFieldValue, IntegerInputFieldValue,
IntegerPolymorphicInputFieldTemplate,
IntegerPolymorphicInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo, useEffect, useMemo, useState } from 'react'; import { memo, useEffect, useMemo, useState } from 'react';
const NumberInputFieldComponent = ( const NumberInputFieldComponent = (
props: FieldComponentProps< props: FieldComponentProps<
IntegerInputFieldValue | FloatInputFieldValue, | IntegerInputFieldValue
IntegerInputFieldTemplate | FloatInputFieldTemplate | IntegerPolymorphicInputFieldValue
| FloatInputFieldValue
| FloatPolymorphicInputFieldValue,
| IntegerInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| FloatInputFieldTemplate
| FloatPolymorphicInputFieldTemplate
> >
) => { ) => {
const { nodeId, field, fieldTemplate } = props; const { nodeId, field, fieldTemplate } = props;

View File

@ -6,11 +6,16 @@ import {
StringInputFieldTemplate, StringInputFieldTemplate,
StringInputFieldValue, StringInputFieldValue,
FieldComponentProps, FieldComponentProps,
StringPolymorphicInputFieldValue,
StringPolymorphicInputFieldTemplate,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
const StringInputFieldComponent = ( const StringInputFieldComponent = (
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate> props: FieldComponentProps<
StringInputFieldValue | StringPolymorphicInputFieldValue,
StringInputFieldTemplate | StringPolymorphicInputFieldTemplate
>
) => { ) => {
const { nodeId, field, fieldTemplate } = props; const { nodeId, field, fieldTemplate } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -5,6 +5,10 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { isInvocationNode } from '../types/types'; import { isInvocationNode } from '../types/types';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
export const useAnyOrDirectInputFieldNames = (nodeId: string) => { export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
@ -21,7 +25,12 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
return []; return [];
} }
return map(nodeTemplate.inputs) return map(nodeTemplate.inputs)
.filter((field) => ['any', 'direct'].includes(field.input)) .filter(
(field) =>
(['any', 'direct'].includes(field.input) ||
POLYMORPHIC_TYPES.includes(field.type)) &&
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
)
.filter((field) => !field.ui_hidden) .filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0)) .sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name) .map((field) => field.name)

View File

@ -4,6 +4,10 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { isInvocationNode } from '../types/types'; import { isInvocationNode } from '../types/types';
export const useConnectionInputFieldNames = (nodeId: string) => { export const useConnectionInputFieldNames = (nodeId: string) => {
@ -21,7 +25,12 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
return []; return [];
} }
return map(nodeTemplate.inputs) return map(nodeTemplate.inputs)
.filter((field) => field.input === 'connection') .filter(
(field) =>
(field.input === 'connection' &&
!POLYMORPHIC_TYPES.includes(field.type)) ||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
)
.filter((field) => !field.ui_hidden) .filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0)) .sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name) .map((field) => field.name)

View File

@ -96,6 +96,28 @@ export const POLYMORPHIC_TO_SINGLE_MAP = {
ColorPolymorphic: 'ColorField', ColorPolymorphic: 'ColorField',
}; };
export const TYPES_WITH_INPUT_COMPONENTS = [
'string',
'StringPolymorphic',
'boolean',
'BooleanPolymorphic',
'integer',
'float',
'FloatPolymorphic',
'IntegerPolymorphic',
'enum',
'ImageField',
'ImagePolymorphic',
'MainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'LoRAModelField',
'ControlNetModelField',
'ColorField',
'SDXLMainModelField',
'Scheduler',
];
export const isPolymorphicItemType = ( export const isPolymorphicItemType = (
itemType: string | undefined itemType: string | undefined
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP => ): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>

View File

@ -220,7 +220,7 @@ export type IntegerCollectionInputFieldValue = z.infer<
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({ export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IntegerPolymorphic'), type: z.literal('IntegerPolymorphic'),
value: z.union([z.number().int(), z.array(z.number().int())]).optional(), value: z.number().int().optional(),
}); });
export type IntegerPolymorphicInputFieldValue = z.infer< export type IntegerPolymorphicInputFieldValue = z.infer<
typeof zIntegerPolymorphicInputFieldValue typeof zIntegerPolymorphicInputFieldValue
@ -242,7 +242,7 @@ export type FloatCollectionInputFieldValue = z.infer<
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({ export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FloatPolymorphic'), type: z.literal('FloatPolymorphic'),
value: z.union([z.number(), z.array(z.number())]).optional(), value: z.number().optional(),
}); });
export type FloatPolymorphicInputFieldValue = z.infer< export type FloatPolymorphicInputFieldValue = z.infer<
typeof zFloatPolymorphicInputFieldValue typeof zFloatPolymorphicInputFieldValue
@ -264,7 +264,7 @@ export type StringCollectionInputFieldValue = z.infer<
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({ export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('StringPolymorphic'), type: z.literal('StringPolymorphic'),
value: z.union([z.string(), z.array(z.string())]).optional(), value: z.string().optional(),
}); });
export type StringPolymorphicInputFieldValue = z.infer< export type StringPolymorphicInputFieldValue = z.infer<
typeof zStringPolymorphicInputFieldValue typeof zStringPolymorphicInputFieldValue
@ -286,7 +286,7 @@ export type BooleanCollectionInputFieldValue = z.infer<
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({ export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('BooleanPolymorphic'), type: z.literal('BooleanPolymorphic'),
value: z.union([z.boolean(), z.array(z.boolean())]).optional(), value: z.boolean().optional(),
}); });
export type BooleanPolymorphicInputFieldValue = z.infer< export type BooleanPolymorphicInputFieldValue = z.infer<
typeof zBooleanPolymorphicInputFieldValue typeof zBooleanPolymorphicInputFieldValue
@ -496,7 +496,7 @@ export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({ export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ImagePolymorphic'), type: z.literal('ImagePolymorphic'),
value: z.union([zImageField, z.array(zImageField)]).optional(), value: zImageField.optional(),
}); });
export type ImagePolymorphicInputFieldValue = z.infer< export type ImagePolymorphicInputFieldValue = z.infer<
typeof zImagePolymorphicInputFieldValue typeof zImagePolymorphicInputFieldValue