Setup flux model loading in the UI

This commit is contained in:
Brandon Rising 2024-08-12 14:04:23 -04:00 committed by Brandon
parent 1fa6bddc89
commit 5f59a828f9
22 changed files with 463 additions and 18 deletions

View File

@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types # region Model Field Types
MainModel = "MainModelField" MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SDXLMainModel = "SDXLMainModelField" SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField" SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField" ONNXModel = "ONNXModelField"
@ -126,12 +127,14 @@ class FieldDescriptions:
noise = "Noise tensor" noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)" unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE" vae = "VAE"
cond = "Conditioning tensor" cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load" controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load" vae_model = "VAE model to load"
lora_model = "LoRA model to load" lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load" main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"

View File

@ -1,14 +1,13 @@
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
from pydantic import Field
import accelerate import accelerate
import torch import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from einops import rearrange, repeat from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from flux.model import Flux from invokeai.app.invocations.model import ModelIdentifierField
from flux.modules.autoencoder import AutoEncoder from optimum.quanto import qfloat8
from flux.sampling import denoise, get_noise, get_schedule, unpack
from flux.util import configs as flux_configs
from PIL import Image from PIL import Image
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers.models.auto import AutoModelForTextEncoding from transformers.models.auto import AutoModelForTextEncoding
@ -21,6 +20,7 @@ from invokeai.app.invocations.fields import (
InputField, InputField,
WithBoard, WithBoard,
WithMetadata, WithMetadata,
UIType,
) )
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
@ -52,6 +52,11 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model.""" """Text-to-image generation using a FLUX model."""
flux_model: ModelIdentifierField = InputField(
description="The Flux model",
input=Input.Any,
ui_type=UIType.FluxMainModel
)
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.") model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField( quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
default="raw", description="The type of quantization to use for the transformer model." default="raw", description="The type of quantization to use for the transformer model."

View File

@ -60,6 +60,12 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
class VAEField(BaseModel): class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel") vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -122,6 +128,49 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model) return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("flux_model_loader", title="Flux Main Model", tags=["model", "flux"], category="model", version="1.0.3")
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
model_key = self.model.key
# TODO: not found exceptions
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
)
@invocation( @invocation(
"main_model_loader", "main_model_loader",
title="Main Model", title="Main Model",

View File

@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2" StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl" StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner" StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
# Kandinsky2_1 = "kandinsky-2.1" # Kandinsky2_1 = "kandinsky-2.1"
@ -74,6 +75,7 @@ class SubModelType(str, Enum):
"""Submodel type.""" """Submodel type."""
UNet = "unet" UNet = "unet"
Transformer = "transformer"
TextEncoder = "text_encoder" TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2" TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer" Tokenizer = "tokenizer"

View File

@ -95,6 +95,7 @@ class ModelProbe(object):
} }
CLASS2TYPE = { CLASS2TYPE = {
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main, "StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main, "StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main, "StableDiffusionXLPipeline": ModelType.Main,
@ -626,6 +627,10 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase): class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
with open(f"{self.model_path}/model_index.json", "r") as file:
conf = json.load(file)
if "_class_name" in conf and conf.get("_class_name") == "FluxPipeline":
return BaseModelType.Flux
with open(self.model_path / "unet" / "config.json", "r") as file: with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file) unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768: if unet_conf["cross_attention_dim"] == 768:

View File

@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
'sd-2': 'teal', 'sd-2': 'teal',
sdxl: 'invokeBlue', sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue', 'sdxl-refiner': 'invokeBlue',
flux: 'invokeBlue',
}; };
const ModelBaseBadge = ({ base }: Props) => { const ModelBaseBadge = ({ base }: Props) => {

View File

@ -14,6 +14,8 @@ import {
isEnumFieldInputTemplate, isEnumFieldInputTemplate,
isFloatFieldInputInstance, isFloatFieldInputInstance,
isFloatFieldInputTemplate, isFloatFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isImageFieldInputInstance, isImageFieldInputInstance,
isImageFieldInputTemplate, isImageFieldInputTemplate,
isIntegerFieldInputInstance, isIntegerFieldInputInstance,
@ -48,6 +50,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent'; import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent'; import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent'; import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent'; import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent'; import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent'; import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@ -69,6 +72,7 @@ type InputFieldProps = {
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName); const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
window.console.log("Hit 0")
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) { if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
@ -145,6 +149,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) { if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
} }
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) { if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;

View File

@ -0,0 +1,55 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
const FluxMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(FluxMainModelFieldInputComponent);

View File

@ -61,7 +61,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion // #endregion
// #region Model-related schemas // #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zModelType = z.enum([ const zModelType = z.enum([
'main', 'main',
'vae', 'vae',
@ -76,6 +76,7 @@ const zModelType = z.enum([
]); ]);
const zSubModelType = z.enum([ const zSubModelType = z.enum([
'unet', 'unet',
'transformer',
'text_encoder', 'text_encoder',
'text_encoder_2', 'text_encoder_2',
'tokenizer', 'tokenizer',

View File

@ -31,6 +31,7 @@ export const MODEL_TYPES = [
'ControlNetModelField', 'ControlNetModelField',
'LoRAModelField', 'LoRAModelField',
'MainModelField', 'MainModelField',
'FluxMainModelField',
'SDXLMainModelField', 'SDXLMainModelField',
'SDXLRefinerModelField', 'SDXLRefinerModelField',
'VaeModelField', 'VaeModelField',
@ -61,6 +62,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LatentsField: 'pink.500', LatentsField: 'pink.500',
LoRAModelField: 'teal.500', LoRAModelField: 'teal.500',
MainModelField: 'teal.500', MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
SDXLMainModelField: 'teal.500', SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500', SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500', SpandrelImageToImageModelField: 'teal.500',
@ -68,6 +70,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
T2IAdapterField: 'teal.500', T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500', T2IAdapterModelField: 'teal.500',
UNetField: 'red.500', UNetField: 'red.500',
TransformerField: 'red.500',
VAEField: 'blue.500', VAEField: 'blue.500',
VAEModelField: 'teal.500', VAEModelField: 'teal.500',
}; };

View File

@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'), name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
}); });
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'), name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
@ -158,6 +162,7 @@ const zStatefulFieldType = z.union([
zModelIdentifierFieldType, zModelIdentifierFieldType,
zMainModelFieldType, zMainModelFieldType,
zSDXLMainModelFieldType, zSDXLMainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType, zSDXLRefinerModelFieldType,
zVAEModelFieldType, zVAEModelFieldType,
zLoRAModelFieldType, zLoRAModelFieldType,
@ -447,6 +452,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
zSDXLMainModelFieldInputTemplate.safeParse(val).success; zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion // #endregion
// #region FluxMainModelField
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxMainModelFieldType,
originalType: zFieldType.optional(),
default: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFluxMainModelFieldType,
});
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
zFluxMainModelFieldInputInstance.safeParse(val).success;
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
zFluxMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SDXLRefinerModelField // #region SDXLRefinerModelField
/** @alias */ // tells knip to ignore this duplicate export /** @alias */ // tells knip to ignore this duplicate export
@ -693,6 +721,7 @@ export const zStatefulFieldValue = z.union([
zModelIdentifierFieldValue, zModelIdentifierFieldValue,
zMainModelFieldValue, zMainModelFieldValue,
zSDXLMainModelFieldValue, zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSDXLRefinerModelFieldValue, zSDXLRefinerModelFieldValue,
zVAEModelFieldValue, zVAEModelFieldValue,
zLoRAModelFieldValue, zLoRAModelFieldValue,
@ -720,6 +749,7 @@ const zStatefulFieldInputInstance = z.union([
zBoardFieldInputInstance, zBoardFieldInputInstance,
zModelIdentifierFieldInputInstance, zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance, zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance, zVAEModelFieldInputInstance,
@ -749,6 +779,7 @@ const zStatefulFieldInputTemplate = z.union([
zBoardFieldInputTemplate, zBoardFieldInputTemplate,
zModelIdentifierFieldInputTemplate, zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate, zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate, zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate, zVAEModelFieldInputTemplate,
@ -779,6 +810,7 @@ const zStatefulFieldOutputTemplate = z.union([
zBoardFieldOutputTemplate, zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate, zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate, zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate, zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate, zVAEModelFieldOutputTemplate,

View File

@ -114,6 +114,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
isCollection: false, isCollection: false,
isCollectionOrScalar: false, isCollectionOrScalar: false,
}, },
FluxMainModelField: {
name: 'FluxMainModelField',
isCollection: false,
isCollectionOrScalar: false,
},
SDXLMainModelField: { SDXLMainModelField: {
name: 'SDXLMainModelField', name: 'SDXLMainModelField',
isCollection: false, isCollection: false,

View File

@ -27,7 +27,7 @@ const zScheduler = z.enum([
'kdpm_2_a', 'kdpm_2_a',
'lcm', 'lcm',
]); ]);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zMainModel = z.object({ const zMainModel = z.object({
model_name: z.string().min(1), model_name: z.string().min(1),
base_model: zBaseModel, base_model: zBaseModel,
@ -89,6 +89,7 @@ const zFieldTypeV1 = z.enum([
'ONNXModelField', 'ONNXModelField',
'Scheduler', 'Scheduler',
'SDXLMainModelField', 'SDXLMainModelField',
'FluxMainModelField',
'SDXLRefinerModelField', 'SDXLRefinerModelField',
'string', 'string',
'StringCollection', 'StringCollection',
@ -417,6 +418,11 @@ const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({
value: zMainOrOnnxModel.optional(), value: zMainOrOnnxModel.optional(),
}); });
const zFluxMainModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FluxMainModelField'),
value: zMainModel.optional(),
});
const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({ const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('SDXLRefinerModelField'), type: z.literal('SDXLRefinerModelField'),
value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model
@ -572,6 +578,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [
zMainModelInputFieldValue, zMainModelInputFieldValue,
zSchedulerInputFieldValue, zSchedulerInputFieldValue,
zSDXLMainModelInputFieldValue, zSDXLMainModelInputFieldValue,
zFluxMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue, zSDXLRefinerModelInputFieldValue,
zStringCollectionInputFieldValue, zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue, zStringPolymorphicInputFieldValue,

View File

@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
// #endregion // #endregion
// #region Model-related schemas // #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zModelName = z.string().min(3); const zModelName = z.string().min(3);
export const zModelIdentifier = z.object({ export const zModelIdentifier = z.object({
model_name: zModelName, model_name: zModelName,

View File

@ -203,6 +203,20 @@ const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
}); });
// #endregion // #endregion
// #region FluxMainModelField
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
});
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
type: zFluxMainModelFieldType,
value: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
type: zFluxMainModelFieldType,
});
// #endregion
// #region SDXLRefinerModelField // #region SDXLRefinerModelField
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'), name: z.literal('SDXLRefinerModelField'),
@ -338,6 +352,7 @@ const zStatefulFieldType = z.union([
zBoardFieldType, zBoardFieldType,
zMainModelFieldType, zMainModelFieldType,
zSDXLMainModelFieldType, zSDXLMainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType, zSDXLRefinerModelFieldType,
zVAEModelFieldType, zVAEModelFieldType,
zLoRAModelFieldType, zLoRAModelFieldType,
@ -377,6 +392,7 @@ const zStatefulFieldInputInstance = z.union([
zBoardFieldInputInstance, zBoardFieldInputInstance,
zMainModelFieldInputInstance, zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance, zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance, zLoRAModelFieldInputInstance,
@ -401,6 +417,7 @@ const zStatefulFieldOutputInstance = z.union([
zBoardFieldOutputInstance, zBoardFieldOutputInstance,
zMainModelFieldOutputInstance, zMainModelFieldOutputInstance,
zSDXLMainModelFieldOutputInstance, zSDXLMainModelFieldOutputInstance,
zFluxMainModelFieldOutputInstance,
zSDXLRefinerModelFieldOutputInstance, zSDXLRefinerModelFieldOutputInstance,
zVAEModelFieldOutputInstance, zVAEModelFieldOutputInstance,
zLoRAModelFieldOutputInstance, zLoRAModelFieldOutputInstance,

View File

@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
MainModelField: undefined, MainModelField: undefined,
SchedulerField: 'euler', SchedulerField: 'euler',
SDXLMainModelField: undefined, SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SDXLRefinerModelField: undefined, SDXLRefinerModelField: undefined,
StringField: '', StringField: '',
T2IAdapterModelField: undefined, T2IAdapterModelField: undefined,

View File

@ -8,6 +8,7 @@ import type {
FieldInputTemplate, FieldInputTemplate,
FieldType, FieldType,
FloatFieldInputTemplate, FloatFieldInputTemplate,
FluxMainModelFieldInputTemplate,
ImageFieldInputTemplate, ImageFieldInputTemplate,
IntegerFieldInputTemplate, IntegerFieldInputTemplate,
IPAdapterModelFieldInputTemplate, IPAdapterModelFieldInputTemplate,
@ -180,6 +181,20 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainMo
return template; return template;
}; };
const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxMainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject, schemaObject,
baseField, baseField,
@ -386,6 +401,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
MainModelField: buildMainModelFieldInputTemplate, MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate, SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate, SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate, SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate, StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate, T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,

View File

@ -29,6 +29,7 @@ const MODEL_FIELD_TYPES = [
'ModelIdentifier', 'ModelIdentifier',
'MainModelField', 'MainModelField',
'SDXLMainModelField', 'SDXLMainModelField',
'FluxMainModelField',
'SDXLRefinerModelField', 'SDXLRefinerModelField',
'VAEModelField', 'VAEModelField',
'LoRAModelField', 'LoRAModelField',

View File

@ -9,6 +9,7 @@ export const MODEL_TYPE_MAP = {
'sd-2': 'Stable Diffusion 2.x', 'sd-2': 'Stable Diffusion 2.x',
sdxl: 'Stable Diffusion XL', sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner', 'sdxl-refiner': 'Stable Diffusion XL Refiner',
flux: 'Flux',
}; };
/** /**
@ -20,6 +21,7 @@ export const MODEL_TYPE_SHORT_MAP = {
'sd-2': 'SD2.X', 'sd-2': 'SD2.X',
sdxl: 'SDXL', sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR', 'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
}; };
/** /**
@ -46,6 +48,10 @@ export const CLIP_SKIP_MAP = {
maxClip: 24, maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
}, },
flux: {
maxClip: 0,
markers: [],
},
}; };
/** /**

View File

@ -5,6 +5,7 @@ import type { AnyModelConfig } from 'services/api/types';
import { import {
isControlNetModelConfig, isControlNetModelConfig,
isControlNetOrT2IAdapterModelConfig, isControlNetOrT2IAdapterModelConfig,
isFluxMainModelModelConfig,
isIPAdapterModelConfig, isIPAdapterModelConfig,
isLoRAModelConfig, isLoRAModelConfig,
isNonRefinerMainModelConfig, isNonRefinerMainModelConfig,
@ -35,6 +36,7 @@ const buildModelsHook =
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig); export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig); export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig); export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig); export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig); export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig); export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);

File diff suppressed because one or more lines are too long

View File

@ -118,6 +118,10 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'sdxl'; return config.type === 'main' && config.base === 'sdxl';
}; };
export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux';
};
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2'); return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
}; };