Setup flux model loading in the UI

This commit is contained in:
Brandon Rising
2024-08-12 14:04:23 -04:00
committed by Brandon
parent 1fa6bddc89
commit 5f59a828f9
22 changed files with 463 additions and 18 deletions

View File

@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
'sd-2': 'teal',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'invokeBlue',
};
const ModelBaseBadge = ({ base }: Props) => {

View File

@ -14,6 +14,8 @@ import {
isEnumFieldInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldInputInstance,
@ -48,6 +50,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@ -69,6 +72,7 @@ type InputFieldProps = {
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
window.console.log("Hit 0")
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
@ -145,6 +149,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;

View File

@ -0,0 +1,55 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
const FluxMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(FluxMainModelFieldInputComponent);

View File

@ -61,7 +61,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zModelType = z.enum([
'main',
'vae',
@ -76,6 +76,7 @@ const zModelType = z.enum([
]);
const zSubModelType = z.enum([
'unet',
'transformer',
'text_encoder',
'text_encoder_2',
'tokenizer',

View File

@ -31,6 +31,7 @@ export const MODEL_TYPES = [
'ControlNetModelField',
'LoRAModelField',
'MainModelField',
'FluxMainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
@ -61,6 +62,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LatentsField: 'pink.500',
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500',
@ -68,6 +70,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
UNetField: 'red.500',
TransformerField: 'red.500',
VAEField: 'blue.500',
VAEModelField: 'teal.500',
};

View File

@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
@ -158,6 +162,7 @@ const zStatefulFieldType = z.union([
zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
@ -447,6 +452,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region FluxMainModelField
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxMainModelFieldType,
originalType: zFieldType.optional(),
default: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFluxMainModelFieldType,
});
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
zFluxMainModelFieldInputInstance.safeParse(val).success;
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
zFluxMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SDXLRefinerModelField
/** @alias */ // tells knip to ignore this duplicate export
@ -693,6 +721,7 @@ export const zStatefulFieldValue = z.union([
zModelIdentifierFieldValue,
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
@ -720,6 +749,7 @@ const zStatefulFieldInputInstance = z.union([
zBoardFieldInputInstance,
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
@ -749,6 +779,7 @@ const zStatefulFieldInputTemplate = z.union([
zBoardFieldInputTemplate,
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
@ -779,6 +810,7 @@ const zStatefulFieldOutputTemplate = z.union([
zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,

View File

@ -114,6 +114,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
isCollection: false,
isCollectionOrScalar: false,
},
FluxMainModelField: {
name: 'FluxMainModelField',
isCollection: false,
isCollectionOrScalar: false,
},
SDXLMainModelField: {
name: 'SDXLMainModelField',
isCollection: false,

View File

@ -27,7 +27,7 @@ const zScheduler = z.enum([
'kdpm_2_a',
'lcm',
]);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zMainModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
@ -89,6 +89,7 @@ const zFieldTypeV1 = z.enum([
'ONNXModelField',
'Scheduler',
'SDXLMainModelField',
'FluxMainModelField',
'SDXLRefinerModelField',
'string',
'StringCollection',
@ -417,6 +418,11 @@ const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({
value: zMainOrOnnxModel.optional(),
});
const zFluxMainModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FluxMainModelField'),
value: zMainModel.optional(),
});
const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('SDXLRefinerModelField'),
value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model
@ -572,6 +578,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [
zMainModelInputFieldValue,
zSchedulerInputFieldValue,
zSDXLMainModelInputFieldValue,
zFluxMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue,

View File

@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zModelName = z.string().min(3);
export const zModelIdentifier = z.object({
model_name: zModelName,

View File

@ -203,6 +203,20 @@ const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
});
// #endregion
// #region FluxMainModelField
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
});
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
type: zFluxMainModelFieldType,
value: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
type: zFluxMainModelFieldType,
});
// #endregion
// #region SDXLRefinerModelField
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
@ -338,6 +352,7 @@ const zStatefulFieldType = z.union([
zBoardFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
@ -377,6 +392,7 @@ const zStatefulFieldInputInstance = z.union([
zBoardFieldInputInstance,
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
@ -401,6 +417,7 @@ const zStatefulFieldOutputInstance = z.union([
zBoardFieldOutputInstance,
zMainModelFieldOutputInstance,
zSDXLMainModelFieldOutputInstance,
zFluxMainModelFieldOutputInstance,
zSDXLRefinerModelFieldOutputInstance,
zVAEModelFieldOutputInstance,
zLoRAModelFieldOutputInstance,

View File

@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
MainModelField: undefined,
SchedulerField: 'euler',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,

View File

@ -8,6 +8,7 @@ import type {
FieldInputTemplate,
FieldType,
FloatFieldInputTemplate,
FluxMainModelFieldInputTemplate,
ImageFieldInputTemplate,
IntegerFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
@ -180,6 +181,20 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainMo
return template;
};
const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxMainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
@ -386,6 +401,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,

View File

@ -29,6 +29,7 @@ const MODEL_FIELD_TYPES = [
'ModelIdentifier',
'MainModelField',
'SDXLMainModelField',
'FluxMainModelField',
'SDXLRefinerModelField',
'VAEModelField',
'LoRAModelField',

View File

@ -9,6 +9,7 @@ export const MODEL_TYPE_MAP = {
'sd-2': 'Stable Diffusion 2.x',
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
flux: 'Flux',
};
/**
@ -20,6 +21,7 @@ export const MODEL_TYPE_SHORT_MAP = {
'sd-2': 'SD2.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
};
/**
@ -46,6 +48,10 @@ export const CLIP_SKIP_MAP = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
flux: {
maxClip: 0,
markers: [],
},
};
/**

View File

@ -5,6 +5,7 @@ import type { AnyModelConfig } from 'services/api/types';
import {
isControlNetModelConfig,
isControlNetOrT2IAdapterModelConfig,
isFluxMainModelModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
@ -35,6 +36,7 @@ const buildModelsHook =
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);

File diff suppressed because one or more lines are too long

View File

@ -118,6 +118,10 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'sdxl';
};
export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux';
};
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
};