From 269388c9f422d1908f69ac33cc45301d710688f2 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Wed, 21 Aug 2024 14:35:39 -0400 Subject: [PATCH] feat(ui): create new field for t5 encoder models in nodes --- invokeai/frontend/web/public/locales/en.json | 1 + .../Invocation/fields/InputFieldRenderer.tsx | 7 +++ .../T5EncoderModelFieldInputComponent.tsx | 60 +++++++++++++++++++ .../src/features/nodes/store/nodesSlice.ts | 7 +++ .../web/src/features/nodes/types/common.ts | 1 + .../web/src/features/nodes/types/field.ts | 31 ++++++++++ .../util/schema/buildFieldInputTemplate.ts | 16 +++++ .../frontend/web/src/services/api/schema.ts | 25 ++++---- 8 files changed, 133 insertions(+), 15 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T5EncoderModelFieldInputComponent.tsx diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1737bd4f29..a9ece94b96 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -784,6 +784,7 @@ "simpleModelPlaceholder": "URL or path to a local file or diffusers folder", "source": "Source", "starterModels": "Starter Models", + "starterModelsInModelManager": "Starter Models can be found in Model Manager", "syncModels": "Sync Models", "textualInversions": "Textual Inversions", "triggerPhrases": "Trigger Phrases", diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index ba09ce6840..c4e8da6eda 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -40,6 +40,8 @@ import { isStringFieldInputTemplate, isT2IAdapterModelFieldInputInstance, isT2IAdapterModelFieldInputTemplate, + isT5EncoderModelFieldInputInstance, + isT5EncoderModelFieldInputTemplate, isVAEModelFieldInputInstance, isVAEModelFieldInputTemplate, } from 'features/nodes/types/field'; @@ -62,6 +64,7 @@ import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputCo import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent'; import StringFieldInputComponent from './inputs/StringFieldInputComponent'; import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent'; +import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent'; import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent'; type InputFieldProps = { @@ -116,6 +119,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } + if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) { + return ; + } + if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T5EncoderModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T5EncoderModelFieldInputComponent.tsx new file mode 100644 index 0000000000..d92163c9c3 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T5EncoderModelFieldInputComponent.tsx @@ -0,0 +1,60 @@ +import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice'; +import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useT5EncoderModels } from 'services/api/hooks/modelsByType'; +import type { T5Encoder8bModelConfig, T5EncoderModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const T5EncoderModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const { t } = useTranslation(); + const disabledTabs = useAppSelector((s) => s.config.disabledTabs); + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useT5EncoderModels(); + const _onChange = useCallback( + (value: T5Encoder8bModelConfig | T5EncoderModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldT5EncoderValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + }); + + return ( + + + + + + + + ); +}; + +export default memo(T5EncoderModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index f9214c1572..6bcd5f276e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -23,6 +23,7 @@ import type { StatefulFieldValue, StringFieldValue, T2IAdapterModelFieldValue, + T5EncoderModelFieldValue, VAEModelFieldValue, } from 'features/nodes/types/field'; import { @@ -44,6 +45,7 @@ import { zStatefulFieldValue, zStringFieldValue, zT2IAdapterModelFieldValue, + zT5EncoderModelFieldValue, zVAEModelFieldValue, } from 'features/nodes/types/field'; import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; @@ -341,6 +343,9 @@ export const nodesSlice = createSlice({ ) => { fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue); }, + fieldT5EncoderValueChanged: (state, action: FieldValueAction) => { + fieldValueReducer(state, action, zT5EncoderModelFieldValue); + }, fieldEnumModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zEnumFieldValue); }, @@ -402,6 +407,7 @@ export const { fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, + fieldT5EncoderValueChanged, nodeEditorReset, nodeIsIntermediateChanged, nodeIsOpenChanged, @@ -514,6 +520,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, + fieldT5EncoderValueChanged, nodesChanged, nodeIsIntermediateChanged, nodeIsOpenChanged, diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 894d257f28..e806271345 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -73,6 +73,7 @@ const zModelType = z.enum([ 'onnx', 'clip_vision', 'spandrel_image_to_image', + 't5_encoder', ]); const zSubModelType = z.enum([ 'unet', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 607a1005ac..ee0f61a0fe 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -147,6 +147,10 @@ const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({ name: z.literal('SpandrelImageToImageModelField'), originalType: zStatelessFieldType.optional(), }); +const zT5EncoderModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T5EncoderModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSchedulerFieldType = zFieldTypeBase.extend({ name: z.literal('SchedulerField'), originalType: zStatelessFieldType.optional(), @@ -170,6 +174,7 @@ const zStatefulFieldType = z.union([ zIPAdapterModelFieldType, zT2IAdapterModelFieldType, zSpandrelImageToImageModelFieldType, + zT5EncoderModelFieldType, zColorFieldType, zSchedulerFieldType, ]); @@ -641,6 +646,29 @@ export const isSpandrelImageToImageModelFieldInputTemplate = ( zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region T5EncoderModelField + +export const zT5EncoderModelFieldValue = zModelIdentifierField.optional(); +const zT5EncoderModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zT5EncoderModelFieldValue, +}); +const zT5EncoderModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zT5EncoderModelFieldType, + originalType: zFieldType.optional(), + default: zT5EncoderModelFieldValue, +}); + +export type T5EncoderModelFieldValue = z.infer; + +export type T5EncoderModelFieldInputInstance = z.infer; +export type T5EncoderModelFieldInputTemplate = z.infer; +export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5EncoderModelFieldInputInstance => + zT5EncoderModelFieldInputInstance.safeParse(val).success; +export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate => + zT5EncoderModelFieldInputTemplate.safeParse(val).success; + +// #endregio + // #region SchedulerField export const zSchedulerFieldValue = zSchedulerField.optional(); @@ -729,6 +757,7 @@ export const zStatefulFieldValue = z.union([ zIPAdapterModelFieldValue, zT2IAdapterModelFieldValue, zSpandrelImageToImageModelFieldValue, + zT5EncoderModelFieldValue, zColorFieldValue, zSchedulerFieldValue, ]); @@ -758,6 +787,7 @@ const zStatefulFieldInputInstance = z.union([ zIPAdapterModelFieldInputInstance, zT2IAdapterModelFieldInputInstance, zSpandrelImageToImageModelFieldInputInstance, + zT5EncoderModelFieldInputInstance, zColorFieldInputInstance, zSchedulerFieldInputInstance, ]); @@ -788,6 +818,7 @@ const zStatefulFieldInputTemplate = z.union([ zIPAdapterModelFieldInputTemplate, zT2IAdapterModelFieldInputTemplate, zSpandrelImageToImageModelFieldInputTemplate, + zT5EncoderModelFieldInputTemplate, zColorFieldInputTemplate, zSchedulerFieldInputTemplate, zStatelessFieldInputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index f4f3ef85af..5149bd4d3a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -23,6 +23,7 @@ import type { StatelessFieldInputTemplate, StringFieldInputTemplate, T2IAdapterModelFieldInputTemplate, + T5EncoderModelFieldInputTemplate, VAEModelFieldInputTemplate, } from 'features/nodes/types/field'; import { isStatefulFieldType } from 'features/nodes/types/field'; @@ -223,6 +224,20 @@ const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: T5EncoderModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -407,6 +422,7 @@ export const TEMPLATE_BUILDER_MAP: Record