diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index c4e8da6eda..787a471373 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -6,6 +6,8 @@ import { isBoardFieldInputTemplate, isBooleanFieldInputInstance, isBooleanFieldInputTemplate, + isCLIPEmbedModelFieldInputInstance, + isCLIPEmbedModelFieldInputTemplate, isColorFieldInputInstance, isColorFieldInputTemplate, isControlNetModelFieldInputInstance, @@ -16,6 +18,8 @@ import { isFloatFieldInputTemplate, isFluxMainModelFieldInputInstance, isFluxMainModelFieldInputTemplate, + isFluxVAEModelFieldInputInstance, + isFluxVAEModelFieldInputTemplate, isImageFieldInputInstance, isImageFieldInputTemplate, isIntegerFieldInputInstance, @@ -49,10 +53,12 @@ import { memo } from 'react'; import BoardFieldInputComponent from './inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; +import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent'; import ColorFieldInputComponent from './inputs/ColorFieldInputComponent'; import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent'; import EnumFieldInputComponent from './inputs/EnumFieldInputComponent'; import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent'; +import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent'; import ImageFieldInputComponent from './inputs/ImageFieldInputComponent'; import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent'; import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent'; @@ -122,6 +128,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) { return ; } + if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) { + return ; + } + + if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) { + return ; + } if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) { return ; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/CLIPEmbedModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/CLIPEmbedModelFieldInputComponent.tsx new file mode 100644 index 0000000000..e3dc207420 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/CLIPEmbedModelFieldInputComponent.tsx @@ -0,0 +1,60 @@ +import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice'; +import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useClipEmbedModels } from 'services/api/hooks/modelsByType'; +import type { ClipEmbedModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const CLIPEmbedModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const { t } = useTranslation(); + const disabledTabs = useAppSelector((s) => s.config.disabledTabs); + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useClipEmbedModels(); + const _onChange = useCallback( + (value: ClipEmbedModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldCLIPEmbedValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + }); + + return ( + + + + + + + + ); +}; + +export default memo(CLIPEmbedModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxVAEModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxVAEModelFieldInputComponent.tsx new file mode 100644 index 0000000000..cc62d9153a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxVAEModelFieldInputComponent.tsx @@ -0,0 +1,60 @@ +import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice'; +import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useFluxVAEModels } from 'services/api/hooks/modelsByType'; +import type { VAEModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const FluxVAEModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const { t } = useTranslation(); + const disabledTabs = useAppSelector((s) => s.config.disabledTabs); + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useFluxVAEModels(); + const _onChange = useCallback( + (value: VAEModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldFluxVAEModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + }); + + return ( + + + + + + + + ); +}; + +export default memo(FluxVAEModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 6bcd5f276e..8306225764 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -6,11 +6,13 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { BoardFieldValue, BooleanFieldValue, + CLIPEmbedModelFieldValue, ColorFieldValue, ControlNetModelFieldValue, EnumFieldValue, FieldValue, FloatFieldValue, + FluxVAEModelFieldValue, ImageFieldValue, IntegerFieldValue, IPAdapterModelFieldValue, @@ -29,10 +31,12 @@ import type { import { zBoardFieldValue, zBooleanFieldValue, + zCLIPEmbedModelFieldValue, zColorFieldValue, zControlNetModelFieldValue, zEnumFieldValue, zFloatFieldValue, + zFluxVAEModelFieldValue, zImageFieldValue, zIntegerFieldValue, zIPAdapterModelFieldValue, @@ -346,6 +350,12 @@ export const nodesSlice = createSlice({ fieldT5EncoderValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zT5EncoderModelFieldValue); }, + fieldCLIPEmbedValueChanged: (state, action: FieldValueAction) => { + fieldValueReducer(state, action, zCLIPEmbedModelFieldValue); + }, + fieldFluxVAEModelValueChanged: (state, action: FieldValueAction) => { + fieldValueReducer(state, action, zFluxVAEModelFieldValue); + }, fieldEnumModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zEnumFieldValue); }, @@ -408,6 +418,8 @@ export const { fieldStringValueChanged, fieldVaeModelValueChanged, fieldT5EncoderValueChanged, + fieldCLIPEmbedValueChanged, + fieldFluxVAEModelValueChanged, nodeEditorReset, nodeIsIntermediateChanged, nodeIsOpenChanged, @@ -521,6 +533,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( fieldStringValueChanged, fieldVaeModelValueChanged, fieldT5EncoderValueChanged, + fieldCLIPEmbedValueChanged, + fieldFluxVAEModelValueChanged, nodesChanged, nodeIsIntermediateChanged, nodeIsOpenChanged, diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index ee0f61a0fe..a4ca41c44e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -151,6 +151,14 @@ const zT5EncoderModelFieldType = zFieldTypeBase.extend({ name: z.literal('T5EncoderModelField'), originalType: zStatelessFieldType.optional(), }); +const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({ + name: z.literal('CLIPEmbedModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zFluxVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('FluxVAEModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSchedulerFieldType = zFieldTypeBase.extend({ name: z.literal('SchedulerField'), originalType: zStatelessFieldType.optional(), @@ -175,6 +183,8 @@ const zStatefulFieldType = z.union([ zT2IAdapterModelFieldType, zSpandrelImageToImageModelFieldType, zT5EncoderModelFieldType, + zCLIPEmbedModelFieldType, + zFluxVAEModelFieldType, zColorFieldType, zSchedulerFieldType, ]); @@ -667,7 +677,53 @@ export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5Encod export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate => zT5EncoderModelFieldInputTemplate.safeParse(val).success; -// #endregio +// #endregion + +// #region FluxVAEModelField + +export const zFluxVAEModelFieldValue = zModelIdentifierField.optional(); +const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zFluxVAEModelFieldValue, +}); +const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFluxVAEModelFieldType, + originalType: zFieldType.optional(), + default: zFluxVAEModelFieldValue, +}); + +export type FluxVAEModelFieldValue = z.infer; + +export type FluxVAEModelFieldInputInstance = z.infer; +export type FluxVAEModelFieldInputTemplate = z.infer; +export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance => + zFluxVAEModelFieldInputInstance.safeParse(val).success; +export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate => + zFluxVAEModelFieldInputTemplate.safeParse(val).success; + +// #endregion + +// #region CLIPEmbedModelField + +export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional(); +const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zCLIPEmbedModelFieldValue, +}); +const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zCLIPEmbedModelFieldType, + originalType: zFieldType.optional(), + default: zCLIPEmbedModelFieldValue, +}); + +export type CLIPEmbedModelFieldValue = z.infer; + +export type CLIPEmbedModelFieldInputInstance = z.infer; +export type CLIPEmbedModelFieldInputTemplate = z.infer; +export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance => + zCLIPEmbedModelFieldInputInstance.safeParse(val).success; +export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate => + zCLIPEmbedModelFieldInputTemplate.safeParse(val).success; + +// #endregion // #region SchedulerField @@ -758,6 +814,8 @@ export const zStatefulFieldValue = z.union([ zT2IAdapterModelFieldValue, zSpandrelImageToImageModelFieldValue, zT5EncoderModelFieldValue, + zFluxVAEModelFieldValue, + zCLIPEmbedModelFieldValue, zColorFieldValue, zSchedulerFieldValue, ]); @@ -788,6 +846,8 @@ const zStatefulFieldInputInstance = z.union([ zT2IAdapterModelFieldInputInstance, zSpandrelImageToImageModelFieldInputInstance, zT5EncoderModelFieldInputInstance, + zFluxVAEModelFieldInputInstance, + zCLIPEmbedModelFieldInputInstance, zColorFieldInputInstance, zSchedulerFieldInputInstance, ]); @@ -819,6 +879,8 @@ const zStatefulFieldInputTemplate = z.union([ zT2IAdapterModelFieldInputTemplate, zSpandrelImageToImageModelFieldInputTemplate, zT5EncoderModelFieldInputTemplate, + zFluxVAEModelFieldInputTemplate, + zCLIPEmbedModelFieldInputTemplate, zColorFieldInputTemplate, zSchedulerFieldInputTemplate, zStatelessFieldInputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index 8afda4e2a7..45a9e28209 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -23,6 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record = VAEModelField: undefined, ControlNetModelField: undefined, T5EncoderModelField: undefined, + FluxVAEModelField: undefined, + CLIPEmbedModelField: undefined, }; export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => { diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 5149bd4d3a..d2aa49a1e3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error'; import type { BoardFieldInputTemplate, BooleanFieldInputTemplate, + CLIPEmbedModelFieldInputTemplate, ColorFieldInputTemplate, ControlNetModelFieldInputTemplate, EnumFieldInputTemplate, @@ -9,6 +10,7 @@ import type { FieldType, FloatFieldInputTemplate, FluxMainModelFieldInputTemplate, + FluxVAEModelFieldInputTemplate, ImageFieldInputTemplate, IntegerFieldInputTemplate, IPAdapterModelFieldInputTemplate, @@ -238,6 +240,34 @@ const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: CLIPEmbedModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: FluxVAEModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -423,6 +453,8 @@ export const TEMPLATE_BUILDER_MAP: Record { + return config.type === 'vae' && config.base === 'flux'; +}; + export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => { return config.type === 'controlnet'; };