mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Setup flux model loading in the UI
This commit is contained in:
@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
'sd-2': 'teal',
|
||||
sdxl: 'invokeBlue',
|
||||
'sdxl-refiner': 'invokeBlue',
|
||||
flux: 'invokeBlue',
|
||||
};
|
||||
|
||||
const ModelBaseBadge = ({ base }: Props) => {
|
||||
|
@ -14,6 +14,8 @@ import {
|
||||
isEnumFieldInputTemplate,
|
||||
isFloatFieldInputInstance,
|
||||
isFloatFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
isFluxMainModelFieldInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
@ -48,6 +50,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||
@ -69,6 +72,7 @@ type InputFieldProps = {
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
window.console.log("Hit 0")
|
||||
|
||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
@ -145,6 +149,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
||||
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
|
@ -0,0 +1,55 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
|
||||
|
||||
const FluxMainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxModels();
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxMainModelFieldInputComponent);
|
@ -61,7 +61,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
|
||||
const zModelType = z.enum([
|
||||
'main',
|
||||
'vae',
|
||||
@ -76,6 +76,7 @@ const zModelType = z.enum([
|
||||
]);
|
||||
const zSubModelType = z.enum([
|
||||
'unet',
|
||||
'transformer',
|
||||
'text_encoder',
|
||||
'text_encoder_2',
|
||||
'tokenizer',
|
||||
|
@ -31,6 +31,7 @@ export const MODEL_TYPES = [
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'FluxMainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
@ -61,6 +62,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
LatentsField: 'pink.500',
|
||||
LoRAModelField: 'teal.500',
|
||||
MainModelField: 'teal.500',
|
||||
FluxMainModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
||||
SpandrelImageToImageModelField: 'teal.500',
|
||||
@ -68,6 +70,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
T2IAdapterField: 'teal.500',
|
||||
T2IAdapterModelField: 'teal.500',
|
||||
UNetField: 'red.500',
|
||||
TransformerField: 'red.500',
|
||||
VAEField: 'blue.500',
|
||||
VAEModelField: 'teal.500',
|
||||
};
|
||||
|
@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLMainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxMainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@ -158,6 +162,7 @@ const zStatefulFieldType = z.union([
|
||||
zModelIdentifierFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zFluxMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
@ -447,6 +452,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
|
||||
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region FluxMainModelField
|
||||
|
||||
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFluxMainModelFieldValue,
|
||||
});
|
||||
const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFluxMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFluxMainModelFieldValue,
|
||||
});
|
||||
const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zFluxMainModelFieldType,
|
||||
});
|
||||
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
|
||||
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
|
||||
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
|
||||
zFluxMainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
|
||||
zFluxMainModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region SDXLRefinerModelField
|
||||
|
||||
/** @alias */ // tells knip to ignore this duplicate export
|
||||
@ -693,6 +721,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zModelIdentifierFieldValue,
|
||||
zMainModelFieldValue,
|
||||
zSDXLMainModelFieldValue,
|
||||
zFluxMainModelFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
@ -720,6 +749,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zBoardFieldInputInstance,
|
||||
zModelIdentifierFieldInputInstance,
|
||||
zMainModelFieldInputInstance,
|
||||
zFluxMainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
@ -749,6 +779,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zBoardFieldInputTemplate,
|
||||
zModelIdentifierFieldInputTemplate,
|
||||
zMainModelFieldInputTemplate,
|
||||
zFluxMainModelFieldInputTemplate,
|
||||
zSDXLMainModelFieldInputTemplate,
|
||||
zSDXLRefinerModelFieldInputTemplate,
|
||||
zVAEModelFieldInputTemplate,
|
||||
@ -779,6 +810,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zBoardFieldOutputTemplate,
|
||||
zModelIdentifierFieldOutputTemplate,
|
||||
zMainModelFieldOutputTemplate,
|
||||
zFluxMainModelFieldOutputTemplate,
|
||||
zSDXLMainModelFieldOutputTemplate,
|
||||
zSDXLRefinerModelFieldOutputTemplate,
|
||||
zVAEModelFieldOutputTemplate,
|
||||
|
@ -114,6 +114,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
FluxMainModelField: {
|
||||
name: 'FluxMainModelField',
|
||||
isCollection: false,
|
||||
isCollectionOrScalar: false,
|
||||
},
|
||||
SDXLMainModelField: {
|
||||
name: 'SDXLMainModelField',
|
||||
isCollection: false,
|
||||
|
@ -27,7 +27,7 @@ const zScheduler = z.enum([
|
||||
'kdpm_2_a',
|
||||
'lcm',
|
||||
]);
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
|
||||
const zMainModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
@ -89,6 +89,7 @@ const zFieldTypeV1 = z.enum([
|
||||
'ONNXModelField',
|
||||
'Scheduler',
|
||||
'SDXLMainModelField',
|
||||
'FluxMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'string',
|
||||
'StringCollection',
|
||||
@ -417,6 +418,11 @@ const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({
|
||||
value: zMainOrOnnxModel.optional(),
|
||||
});
|
||||
|
||||
const zFluxMainModelInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('FluxMainModelField'),
|
||||
value: zMainModel.optional(),
|
||||
});
|
||||
|
||||
const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('SDXLRefinerModelField'),
|
||||
value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model
|
||||
@ -572,6 +578,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zMainModelInputFieldValue,
|
||||
zSchedulerInputFieldValue,
|
||||
zSDXLMainModelInputFieldValue,
|
||||
zFluxMainModelInputFieldValue,
|
||||
zSDXLRefinerModelInputFieldValue,
|
||||
zStringCollectionInputFieldValue,
|
||||
zStringPolymorphicInputFieldValue,
|
||||
|
@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
|
||||
const zModelName = z.string().min(3);
|
||||
export const zModelIdentifier = z.object({
|
||||
model_name: zModelName,
|
||||
|
@ -203,6 +203,20 @@ const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
});
|
||||
// #endregion
|
||||
|
||||
// #region FluxMainModelField
|
||||
const zFluxMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxMainModelField'),
|
||||
});
|
||||
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zFluxMainModelFieldType,
|
||||
value: zFluxMainModelFieldValue,
|
||||
});
|
||||
const zFluxMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zFluxMainModelFieldType,
|
||||
});
|
||||
// #endregion
|
||||
|
||||
// #region SDXLRefinerModelField
|
||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
@ -338,6 +352,7 @@ const zStatefulFieldType = z.union([
|
||||
zBoardFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zFluxMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
@ -377,6 +392,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zBoardFieldInputInstance,
|
||||
zMainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
zFluxMainModelFieldInputInstance,
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
zLoRAModelFieldInputInstance,
|
||||
@ -401,6 +417,7 @@ const zStatefulFieldOutputInstance = z.union([
|
||||
zBoardFieldOutputInstance,
|
||||
zMainModelFieldOutputInstance,
|
||||
zSDXLMainModelFieldOutputInstance,
|
||||
zFluxMainModelFieldOutputInstance,
|
||||
zSDXLRefinerModelFieldOutputInstance,
|
||||
zVAEModelFieldOutputInstance,
|
||||
zLoRAModelFieldOutputInstance,
|
||||
|
@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
MainModelField: undefined,
|
||||
SchedulerField: 'euler',
|
||||
SDXLMainModelField: undefined,
|
||||
FluxMainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
StringField: '',
|
||||
T2IAdapterModelField: undefined,
|
||||
|
@ -8,6 +8,7 @@ import type {
|
||||
FieldInputTemplate,
|
||||
FieldType,
|
||||
FloatFieldInputTemplate,
|
||||
FluxMainModelFieldInputTemplate,
|
||||
ImageFieldInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
@ -180,6 +181,20 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainMo
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FluxMainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -386,6 +401,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
MainModelField: buildMainModelFieldInputTemplate,
|
||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
||||
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||
|
@ -29,6 +29,7 @@ const MODEL_FIELD_TYPES = [
|
||||
'ModelIdentifier',
|
||||
'MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'FluxMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VAEModelField',
|
||||
'LoRAModelField',
|
||||
|
@ -9,6 +9,7 @@ export const MODEL_TYPE_MAP = {
|
||||
'sd-2': 'Stable Diffusion 2.x',
|
||||
sdxl: 'Stable Diffusion XL',
|
||||
'sdxl-refiner': 'Stable Diffusion XL Refiner',
|
||||
flux: 'Flux',
|
||||
};
|
||||
|
||||
/**
|
||||
@ -20,6 +21,7 @@ export const MODEL_TYPE_SHORT_MAP = {
|
||||
'sd-2': 'SD2.X',
|
||||
sdxl: 'SDXL',
|
||||
'sdxl-refiner': 'SDXLR',
|
||||
flux: 'FLUX',
|
||||
};
|
||||
|
||||
/**
|
||||
@ -46,6 +48,10 @@ export const CLIP_SKIP_MAP = {
|
||||
maxClip: 24,
|
||||
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||
},
|
||||
flux: {
|
||||
maxClip: 0,
|
||||
markers: [],
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -5,6 +5,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isControlNetModelConfig,
|
||||
isControlNetOrT2IAdapterModelConfig,
|
||||
isFluxMainModelModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
@ -35,6 +36,7 @@ const buildModelsHook =
|
||||
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
|
||||
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
||||
|
File diff suppressed because one or more lines are too long
@ -118,6 +118,10 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
|
||||
return config.type === 'main' && config.base === 'sdxl';
|
||||
};
|
||||
|
||||
export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base === 'flux';
|
||||
};
|
||||
|
||||
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
|
||||
};
|
||||
|
Reference in New Issue
Block a user