feat(ui): create new field for t5 encoder models in nodes

This commit is contained in:
Mary Hipp 2024-08-21 14:35:39 -04:00
parent ec6981a860
commit fe636bb6ca
8 changed files with 133 additions and 15 deletions

View File

@ -784,6 +784,7 @@
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",

View File

@ -40,6 +40,8 @@ import {
isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
isT5EncoderModelFieldInputTemplate,
isVAEModelFieldInputInstance,
isVAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
@ -62,6 +64,7 @@ import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputCo
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = {
@ -116,6 +119,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@ -0,0 +1,60 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5Encoder8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate>;
const T5EncoderModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
(value: T5Encoder8bModelConfig | T5EncoderModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldT5EncoderValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(T5EncoderModelFieldInputComponent);

View File

@ -23,6 +23,7 @@ import type {
StatefulFieldValue,
StringFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
} from 'features/nodes/types/field';
import {
@ -44,6 +45,7 @@ import {
zStatefulFieldValue,
zStringFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
@ -341,6 +343,9 @@ export const nodesSlice = createSlice({
) => {
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
},
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
@ -402,6 +407,7 @@ export const {
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
@ -514,6 +520,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
nodesChanged,
nodeIsIntermediateChanged,
nodeIsOpenChanged,

View File

@ -73,6 +73,7 @@ const zModelType = z.enum([
'onnx',
'clip_vision',
'spandrel_image_to_image',
't5_encoder',
]);
const zSubModelType = z.enum([
'unet',

View File

@ -147,6 +147,10 @@ const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
name: z.literal('SpandrelImageToImageModelField'),
originalType: zStatelessFieldType.optional(),
});
const zT5EncoderModelFieldType = zFieldTypeBase.extend({
name: z.literal('T5EncoderModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
@ -170,6 +174,7 @@ const zStatefulFieldType = z.union([
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
zSpandrelImageToImageModelFieldType,
zT5EncoderModelFieldType,
zColorFieldType,
zSchedulerFieldType,
]);
@ -641,6 +646,29 @@ export const isSpandrelImageToImageModelFieldInputTemplate = (
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region T5EncoderModelField
export const zT5EncoderModelFieldValue = zModelIdentifierField.optional();
const zT5EncoderModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zT5EncoderModelFieldValue,
});
const zT5EncoderModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zT5EncoderModelFieldType,
originalType: zFieldType.optional(),
default: zT5EncoderModelFieldValue,
});
export type T5EncoderModelFieldValue = z.infer<typeof zT5EncoderModelFieldValue>;
export type T5EncoderModelFieldInputInstance = z.infer<typeof zT5EncoderModelFieldInputInstance>;
export type T5EncoderModelFieldInputTemplate = z.infer<typeof zT5EncoderModelFieldInputTemplate>;
export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5EncoderModelFieldInputInstance =>
zT5EncoderModelFieldInputInstance.safeParse(val).success;
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
// #endregio
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
@ -729,6 +757,7 @@ export const zStatefulFieldValue = z.union([
zIPAdapterModelFieldValue,
zT2IAdapterModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zT5EncoderModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@ -758,6 +787,7 @@ const zStatefulFieldInputInstance = z.union([
zIPAdapterModelFieldInputInstance,
zT2IAdapterModelFieldInputInstance,
zSpandrelImageToImageModelFieldInputInstance,
zT5EncoderModelFieldInputInstance,
zColorFieldInputInstance,
zSchedulerFieldInputInstance,
]);
@ -788,6 +818,7 @@ const zStatefulFieldInputTemplate = z.union([
zIPAdapterModelFieldInputTemplate,
zT2IAdapterModelFieldInputTemplate,
zSpandrelImageToImageModelFieldInputTemplate,
zT5EncoderModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

View File

@ -23,6 +23,7 @@ import type {
StatelessFieldInputTemplate,
StringFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { isStatefulFieldType } from 'features/nodes/types/field';
@ -223,6 +224,20 @@ const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldIn
return template;
};
const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5EncoderModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: T5EncoderModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
schemaObject,
baseField,
@ -407,6 +422,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate,
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
} as const;
export const buildFieldInputTemplate = (

View File

@ -5742,15 +5742,10 @@ export type components = {
* @default true
*/
use_cache?: boolean;
/** @description Flux model (Transformer, VAE, CLIP) to load */
/** @description Flux model (Transformer) to load */
model: components["schemas"]["ModelIdentifierField"];
/**
* T5 Encoder
* @description The T5 Encoder model to use.
* @default null
* @enum {string}
*/
t5_encoder?: "base" | "8b_quantized";
/** @description T5 tokenizer and text encoder */
t5_encoder: components["schemas"]["ModelIdentifierField"];
/**
* type
* @default flux_model_loader
@ -5833,12 +5828,12 @@ export type components = {
*/
t5_encoder?: components["schemas"]["T5EncoderField"];
/**
* Max Seq Len
* @description Max sequence length for the desired flux model
* T5 Max Seq Len
* @description Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models.
* @default null
* @enum {integer}
*/
max_seq_len?: 256 | 512;
t5_max_seq_len?: 256 | 512;
/**
* Positive Prompt
* @description Positive prompt for text-to-image generation.
@ -5887,7 +5882,7 @@ export type components = {
use_cache?: boolean;
/**
* Transformer
* @description UNet (scheduler, LoRAs)
* @description Flux model (Transformer) to load
* @default null
*/
transformer?: components["schemas"]["TransformerField"];
@ -5915,13 +5910,13 @@ export type components = {
height?: number;
/**
* Num Steps
* @description Number of diffusion steps.
* @description Number of diffusion steps. Recommend values are schnell: 4, dev: 50.
* @default 4
*/
num_steps?: number;
/**
* Guidance
* @description The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images.
* @description The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.
* @default 4
*/
guidance?: number;
@ -15074,7 +15069,7 @@ export type components = {
* used, and the type will be ignored. They are included here for backwards compatibility.
* @enum {string}
*/
UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
/** UNetField */
UNetField: {
/** @description Info to load unet submodel */