feat(ui): create new field for t5 encoder models in nodes

This commit is contained in:
Mary Hipp 2024-08-21 14:35:39 -04:00
parent ec6981a860
commit fe636bb6ca
8 changed files with 133 additions and 15 deletions

View File

@ -784,6 +784,7 @@
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder", "simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source", "source": "Source",
"starterModels": "Starter Models", "starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"syncModels": "Sync Models", "syncModels": "Sync Models",
"textualInversions": "Textual Inversions", "textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases", "triggerPhrases": "Trigger Phrases",

View File

@ -40,6 +40,8 @@ import {
isStringFieldInputTemplate, isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance, isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate, isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
isT5EncoderModelFieldInputTemplate,
isVAEModelFieldInputInstance, isVAEModelFieldInputInstance,
isVAEModelFieldInputTemplate, isVAEModelFieldInputTemplate,
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
@ -62,6 +64,7 @@ import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputCo
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent'; import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent'; import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent'; import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent'; import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = { type InputFieldProps = {
@ -116,6 +119,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
} }
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) { if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
} }

View File

@ -0,0 +1,60 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5Encoder8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate>;
const T5EncoderModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
(value: T5Encoder8bModelConfig | T5EncoderModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldT5EncoderValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(T5EncoderModelFieldInputComponent);

View File

@ -23,6 +23,7 @@ import type {
StatefulFieldValue, StatefulFieldValue,
StringFieldValue, StringFieldValue,
T2IAdapterModelFieldValue, T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue, VAEModelFieldValue,
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import { import {
@ -44,6 +45,7 @@ import {
zStatefulFieldValue, zStatefulFieldValue,
zStringFieldValue, zStringFieldValue,
zT2IAdapterModelFieldValue, zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue, zVAEModelFieldValue,
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
@ -341,6 +343,9 @@ export const nodesSlice = createSlice({
) => { ) => {
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue); fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
}, },
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => { fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue); fieldValueReducer(state, action, zEnumFieldValue);
}, },
@ -402,6 +407,7 @@ export const {
fieldSchedulerValueChanged, fieldSchedulerValueChanged,
fieldStringValueChanged, fieldStringValueChanged,
fieldVaeModelValueChanged, fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
nodeEditorReset, nodeEditorReset,
nodeIsIntermediateChanged, nodeIsIntermediateChanged,
nodeIsOpenChanged, nodeIsOpenChanged,
@ -514,6 +520,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldSchedulerValueChanged, fieldSchedulerValueChanged,
fieldStringValueChanged, fieldStringValueChanged,
fieldVaeModelValueChanged, fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
nodesChanged, nodesChanged,
nodeIsIntermediateChanged, nodeIsIntermediateChanged,
nodeIsOpenChanged, nodeIsOpenChanged,

View File

@ -73,6 +73,7 @@ const zModelType = z.enum([
'onnx', 'onnx',
'clip_vision', 'clip_vision',
'spandrel_image_to_image', 'spandrel_image_to_image',
't5_encoder',
]); ]);
const zSubModelType = z.enum([ const zSubModelType = z.enum([
'unet', 'unet',

View File

@ -147,6 +147,10 @@ const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
name: z.literal('SpandrelImageToImageModelField'), name: z.literal('SpandrelImageToImageModelField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
}); });
const zT5EncoderModelFieldType = zFieldTypeBase.extend({
name: z.literal('T5EncoderModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({ const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'), name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
@ -170,6 +174,7 @@ const zStatefulFieldType = z.union([
zIPAdapterModelFieldType, zIPAdapterModelFieldType,
zT2IAdapterModelFieldType, zT2IAdapterModelFieldType,
zSpandrelImageToImageModelFieldType, zSpandrelImageToImageModelFieldType,
zT5EncoderModelFieldType,
zColorFieldType, zColorFieldType,
zSchedulerFieldType, zSchedulerFieldType,
]); ]);
@ -641,6 +646,29 @@ export const isSpandrelImageToImageModelFieldInputTemplate = (
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success; zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
// #endregion // #endregion
// #region T5EncoderModelField
export const zT5EncoderModelFieldValue = zModelIdentifierField.optional();
const zT5EncoderModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zT5EncoderModelFieldValue,
});
const zT5EncoderModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zT5EncoderModelFieldType,
originalType: zFieldType.optional(),
default: zT5EncoderModelFieldValue,
});
export type T5EncoderModelFieldValue = z.infer<typeof zT5EncoderModelFieldValue>;
export type T5EncoderModelFieldInputInstance = z.infer<typeof zT5EncoderModelFieldInputInstance>;
export type T5EncoderModelFieldInputTemplate = z.infer<typeof zT5EncoderModelFieldInputTemplate>;
export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5EncoderModelFieldInputInstance =>
zT5EncoderModelFieldInputInstance.safeParse(val).success;
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
// #endregio
// #region SchedulerField // #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional(); export const zSchedulerFieldValue = zSchedulerField.optional();
@ -729,6 +757,7 @@ export const zStatefulFieldValue = z.union([
zIPAdapterModelFieldValue, zIPAdapterModelFieldValue,
zT2IAdapterModelFieldValue, zT2IAdapterModelFieldValue,
zSpandrelImageToImageModelFieldValue, zSpandrelImageToImageModelFieldValue,
zT5EncoderModelFieldValue,
zColorFieldValue, zColorFieldValue,
zSchedulerFieldValue, zSchedulerFieldValue,
]); ]);
@ -758,6 +787,7 @@ const zStatefulFieldInputInstance = z.union([
zIPAdapterModelFieldInputInstance, zIPAdapterModelFieldInputInstance,
zT2IAdapterModelFieldInputInstance, zT2IAdapterModelFieldInputInstance,
zSpandrelImageToImageModelFieldInputInstance, zSpandrelImageToImageModelFieldInputInstance,
zT5EncoderModelFieldInputInstance,
zColorFieldInputInstance, zColorFieldInputInstance,
zSchedulerFieldInputInstance, zSchedulerFieldInputInstance,
]); ]);
@ -788,6 +818,7 @@ const zStatefulFieldInputTemplate = z.union([
zIPAdapterModelFieldInputTemplate, zIPAdapterModelFieldInputTemplate,
zT2IAdapterModelFieldInputTemplate, zT2IAdapterModelFieldInputTemplate,
zSpandrelImageToImageModelFieldInputTemplate, zSpandrelImageToImageModelFieldInputTemplate,
zT5EncoderModelFieldInputTemplate,
zColorFieldInputTemplate, zColorFieldInputTemplate,
zSchedulerFieldInputTemplate, zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate, zStatelessFieldInputTemplate,

View File

@ -23,6 +23,7 @@ import type {
StatelessFieldInputTemplate, StatelessFieldInputTemplate,
StringFieldInputTemplate, StringFieldInputTemplate,
T2IAdapterModelFieldInputTemplate, T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate, VAEModelFieldInputTemplate,
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import { isStatefulFieldType } from 'features/nodes/types/field'; import { isStatefulFieldType } from 'features/nodes/types/field';
@ -223,6 +224,20 @@ const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldIn
return template; return template;
}; };
const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5EncoderModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: T5EncoderModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({ const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
schemaObject, schemaObject,
baseField, baseField,
@ -407,6 +422,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate, T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate, SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate, VAEModelField: buildVAEModelFieldInputTemplate,
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
} as const; } as const;
export const buildFieldInputTemplate = ( export const buildFieldInputTemplate = (

View File

@ -5742,15 +5742,10 @@ export type components = {
* @default true * @default true
*/ */
use_cache?: boolean; use_cache?: boolean;
/** @description Flux model (Transformer, VAE, CLIP) to load */ /** @description Flux model (Transformer) to load */
model: components["schemas"]["ModelIdentifierField"]; model: components["schemas"]["ModelIdentifierField"];
/** /** @description T5 tokenizer and text encoder */
* T5 Encoder t5_encoder: components["schemas"]["ModelIdentifierField"];
* @description The T5 Encoder model to use.
* @default null
* @enum {string}
*/
t5_encoder?: "base" | "8b_quantized";
/** /**
* type * type
* @default flux_model_loader * @default flux_model_loader
@ -5833,12 +5828,12 @@ export type components = {
*/ */
t5_encoder?: components["schemas"]["T5EncoderField"]; t5_encoder?: components["schemas"]["T5EncoderField"];
/** /**
* Max Seq Len * T5 Max Seq Len
* @description Max sequence length for the desired flux model * @description Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models.
* @default null * @default null
* @enum {integer} * @enum {integer}
*/ */
max_seq_len?: 256 | 512; t5_max_seq_len?: 256 | 512;
/** /**
* Positive Prompt * Positive Prompt
* @description Positive prompt for text-to-image generation. * @description Positive prompt for text-to-image generation.
@ -5887,7 +5882,7 @@ export type components = {
use_cache?: boolean; use_cache?: boolean;
/** /**
* Transformer * Transformer
* @description UNet (scheduler, LoRAs) * @description Flux model (Transformer) to load
* @default null * @default null
*/ */
transformer?: components["schemas"]["TransformerField"]; transformer?: components["schemas"]["TransformerField"];
@ -5915,13 +5910,13 @@ export type components = {
height?: number; height?: number;
/** /**
* Num Steps * Num Steps
* @description Number of diffusion steps. * @description Number of diffusion steps. Recommend values are schnell: 4, dev: 50.
* @default 4 * @default 4
*/ */
num_steps?: number; num_steps?: number;
/** /**
* Guidance * Guidance
* @description The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. * @description The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.
* @default 4 * @default 4
*/ */
guidance?: number; guidance?: number;
@ -15074,7 +15069,7 @@ export type components = {
* used, and the type will be ignored. They are included here for backwards compatibility. * used, and the type will be ignored. They are included here for backwards compatibility.
* @enum {string} * @enum {string}
*/ */
UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
/** UNetField */ /** UNetField */
UNetField: { UNetField: {
/** @description Info to load unet submodel */ /** @description Info to load unet submodel */