diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 9efcf2148f..91dfcb51a7 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum): # region Model Field Types MainModel = "MainModelField" + FluxMainModel = "FluxMainModelField" SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" ONNXModel = "ONNXModelField" @@ -126,12 +127,14 @@ class FieldDescriptions: noise = "Noise tensor" clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" unet = "UNet (scheduler, LoRAs)" + transformer = "Transformer" vae = "VAE" cond = "Conditioning tensor" controlnet_model = "ControlNet model to load" vae_model = "VAE model to load" lora_model = "LoRA model to load" main_model = "Main model (UNet, VAE, CLIP) to load" + flux_model = "Flux model (Transformer, VAE, CLIP) to load" sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 19829c47a4..7a577215f8 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -1,14 +1,13 @@ from pathlib import Path from typing import Literal +from pydantic import Field import accelerate import torch from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from einops import rearrange, repeat -from flux.model import Flux -from flux.modules.autoencoder import AutoEncoder -from flux.sampling import denoise, get_noise, get_schedule, unpack -from flux.util import configs as flux_configs +from diffusers.pipelines.flux.pipeline_flux import FluxPipeline +from invokeai.app.invocations.model import ModelIdentifierField +from optimum.quanto import qfloat8 from PIL import Image from safetensors.torch import load_file from transformers.models.auto import AutoModelForTextEncoding @@ -21,6 +20,7 @@ from invokeai.app.invocations.fields import ( InputField, WithBoard, WithMetadata, + UIType, ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext @@ -52,6 +52,11 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel): class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Text-to-image generation using a FLUX model.""" + flux_model: ModelIdentifierField = InputField( + description="The Flux model", + input=Input.Any, + ui_type=UIType.FluxMainModel + ) model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.") quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField( default="raw", description="The type of quantization to use for the transformer model." diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c0d067c0a7..dd12109269 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -60,6 +60,12 @@ class CLIPField(BaseModel): loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") + +class TransformerField(BaseModel): + transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel") + scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel") + + class VAEField(BaseModel): vae: ModelIdentifierField = Field(description="Info to load vae submodel") seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') @@ -122,6 +128,49 @@ class ModelIdentifierInvocation(BaseInvocation): return ModelIdentifierOutput(model=self.model) +@invocation_output("flux_model_loader_output") +class FluxModelLoaderOutput(BaseInvocationOutput): + """Flux base model loader output""" + + transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer") + clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1") + clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") + vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") + + +@invocation("flux_model_loader", title="Flux Main Model", tags=["model", "flux"], category="model", version="1.0.3") +class FluxModelLoaderInvocation(BaseInvocation): + """Loads a flux base model, outputting its submodels.""" + + model: ModelIdentifierField = InputField( + description=FieldDescriptions.flux_model, + ui_type=UIType.FluxMainModel, + input=Input.Direct, + ) + + def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput: + model_key = self.model.key + + # TODO: not found exceptions + if not context.models.exists(model_key): + raise Exception(f"Unknown model: {model_key}") + + transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer}) + scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler}) + tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer}) + text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder}) + tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2}) + text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2}) + vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE}) + + return FluxModelLoaderOutput( + transformer=TransformerField(transformer=transformer, scheduler=scheduler), + clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0), + clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0), + vae=VAEField(vae=vae), + ) + + @invocation( "main_model_loader", title="Main Model", diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 332ac6c8fa..29ef953666 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -52,6 +52,7 @@ class BaseModelType(str, Enum): StableDiffusion2 = "sd-2" StableDiffusionXL = "sdxl" StableDiffusionXLRefiner = "sdxl-refiner" + Flux = "flux" # Kandinsky2_1 = "kandinsky-2.1" @@ -74,6 +75,7 @@ class SubModelType(str, Enum): """Submodel type.""" UNet = "unet" + Transformer = "transformer" TextEncoder = "text_encoder" TextEncoder2 = "text_encoder_2" Tokenizer = "tokenizer" diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 1929b3f4fd..82053149ad 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -95,6 +95,7 @@ class ModelProbe(object): } CLASS2TYPE = { + "FluxPipeline": ModelType.Main, "StableDiffusionPipeline": ModelType.Main, "StableDiffusionInpaintPipeline": ModelType.Main, "StableDiffusionXLPipeline": ModelType.Main, @@ -626,6 +627,10 @@ class FolderProbeBase(ProbeBase): class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: + with open(f"{self.model_path}/model_index.json", "r") as file: + conf = json.load(file) + if "_class_name" in conf and conf.get("_class_name") == "FluxPipeline": + return BaseModelType.Flux with open(self.model_path / "unet" / "config.json", "r") as file: unet_conf = json.load(file) if unet_conf["cross_attention_dim"] == 768: diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx index bf07bad58c..2cf4e25354 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx @@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record = { 'sd-2': 'teal', sdxl: 'invokeBlue', 'sdxl-refiner': 'invokeBlue', + flux: 'invokeBlue', }; const ModelBaseBadge = ({ base }: Props) => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index d863def973..6ec51aba13 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -14,6 +14,8 @@ import { isEnumFieldInputTemplate, isFloatFieldInputInstance, isFloatFieldInputTemplate, + isFluxMainModelFieldInputInstance, + isFluxMainModelFieldInputTemplate, isImageFieldInputInstance, isImageFieldInputTemplate, isIntegerFieldInputInstance, @@ -48,6 +50,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; import ColorFieldInputComponent from './inputs/ColorFieldInputComponent'; import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent'; import EnumFieldInputComponent from './inputs/EnumFieldInputComponent'; +import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent'; import ImageFieldInputComponent from './inputs/ImageFieldInputComponent'; import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent'; import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent'; @@ -69,6 +72,7 @@ type InputFieldProps = { const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { const fieldInstance = useFieldInputInstance(nodeId, fieldName); const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); + window.console.log("Hit 0") if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) { return ; @@ -145,6 +149,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) { return ; } + if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) { + return ; + } if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) { return ; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxMainModelFieldInputComponent.tsx new file mode 100644 index 0000000000..3a0ddb211e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxMainModelFieldInputComponent.tsx @@ -0,0 +1,55 @@ +import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; +import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useFluxModels } from 'services/api/hooks/modelsByType'; +import type { MainModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const FluxMainModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useFluxModels(); + const _onChange = useCallback( + (value: MainModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldMainModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + }); + + return ( + + + + + + ); +}; + +export default memo(FluxMainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index c84b2dae62..894d257f28 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -61,7 +61,7 @@ export type SchedulerField = z.infer; // #endregion // #region Model-related schemas -const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']); const zModelType = z.enum([ 'main', 'vae', @@ -76,6 +76,7 @@ const zModelType = z.enum([ ]); const zSubModelType = z.enum([ 'unet', + 'transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 05697c384c..ca43f35b55 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -31,6 +31,7 @@ export const MODEL_TYPES = [ 'ControlNetModelField', 'LoRAModelField', 'MainModelField', + 'FluxMainModelField', 'SDXLMainModelField', 'SDXLRefinerModelField', 'VaeModelField', @@ -61,6 +62,7 @@ export const FIELD_COLORS: { [key: string]: string } = { LatentsField: 'pink.500', LoRAModelField: 'teal.500', MainModelField: 'teal.500', + FluxMainModelField: 'teal.500', SDXLMainModelField: 'teal.500', SDXLRefinerModelField: 'teal.500', SpandrelImageToImageModelField: 'teal.500', @@ -68,6 +70,7 @@ export const FIELD_COLORS: { [key: string]: string } = { T2IAdapterField: 'teal.500', T2IAdapterModelField: 'teal.500', UNetField: 'red.500', + TransformerField: 'red.500', VAEField: 'blue.500', VAEModelField: 'teal.500', }; diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 925bd40b9d..607a1005ac 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLMainModelField'), originalType: zStatelessFieldType.optional(), }); +const zFluxMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('FluxMainModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLRefinerModelField'), originalType: zStatelessFieldType.optional(), @@ -158,6 +162,7 @@ const zStatefulFieldType = z.union([ zModelIdentifierFieldType, zMainModelFieldType, zSDXLMainModelFieldType, + zFluxMainModelFieldType, zSDXLRefinerModelFieldType, zVAEModelFieldType, zLoRAModelFieldType, @@ -447,6 +452,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain zSDXLMainModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region FluxMainModelField + +const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. +const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zFluxMainModelFieldValue, +}); +const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFluxMainModelFieldType, + originalType: zFieldType.optional(), + default: zFluxMainModelFieldValue, +}); +const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zFluxMainModelFieldType, +}); +export type FluxMainModelFieldInputInstance = z.infer; +export type FluxMainModelFieldInputTemplate = z.infer; +export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance => + zFluxMainModelFieldInputInstance.safeParse(val).success; +export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate => + zFluxMainModelFieldInputTemplate.safeParse(val).success; + +// #endregion + // #region SDXLRefinerModelField /** @alias */ // tells knip to ignore this duplicate export @@ -693,6 +721,7 @@ export const zStatefulFieldValue = z.union([ zModelIdentifierFieldValue, zMainModelFieldValue, zSDXLMainModelFieldValue, + zFluxMainModelFieldValue, zSDXLRefinerModelFieldValue, zVAEModelFieldValue, zLoRAModelFieldValue, @@ -720,6 +749,7 @@ const zStatefulFieldInputInstance = z.union([ zBoardFieldInputInstance, zModelIdentifierFieldInputInstance, zMainModelFieldInputInstance, + zFluxMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance, zVAEModelFieldInputInstance, @@ -749,6 +779,7 @@ const zStatefulFieldInputTemplate = z.union([ zBoardFieldInputTemplate, zModelIdentifierFieldInputTemplate, zMainModelFieldInputTemplate, + zFluxMainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, zSDXLRefinerModelFieldInputTemplate, zVAEModelFieldInputTemplate, @@ -779,6 +810,7 @@ const zStatefulFieldOutputTemplate = z.union([ zBoardFieldOutputTemplate, zModelIdentifierFieldOutputTemplate, zMainModelFieldOutputTemplate, + zFluxMainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, zSDXLRefinerModelFieldOutputTemplate, zVAEModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts index f1d4e61300..719063cf68 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts @@ -114,6 +114,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: { isCollection: false, isCollectionOrScalar: false, }, + FluxMainModelField: { + name: 'FluxMainModelField', + isCollection: false, + isCollectionOrScalar: false, + }, SDXLMainModelField: { name: 'SDXLMainModelField', isCollection: false, diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts index c7a50b20e4..b4ec9cd94e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts @@ -27,7 +27,7 @@ const zScheduler = z.enum([ 'kdpm_2_a', 'lcm', ]); -const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']); const zMainModel = z.object({ model_name: z.string().min(1), base_model: zBaseModel, @@ -89,6 +89,7 @@ const zFieldTypeV1 = z.enum([ 'ONNXModelField', 'Scheduler', 'SDXLMainModelField', + 'FluxMainModelField', 'SDXLRefinerModelField', 'string', 'StringCollection', @@ -417,6 +418,11 @@ const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({ value: zMainOrOnnxModel.optional(), }); +const zFluxMainModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('FluxMainModelField'), + value: zMainModel.optional(), +}); + const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({ type: z.literal('SDXLRefinerModelField'), value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model @@ -572,6 +578,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [ zMainModelInputFieldValue, zSchedulerInputFieldValue, zSDXLMainModelInputFieldValue, + zFluxMainModelInputFieldValue, zSDXLRefinerModelInputFieldValue, zStringCollectionInputFieldValue, zStringPolymorphicInputFieldValue, diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/common.ts b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts index 8613076132..64d4db0451 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v2/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts @@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([ // #endregion // #region Model-related schemas -const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']); const zModelName = z.string().min(3); export const zModelIdentifier = z.object({ model_name: zModelName, diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts index 4b680d1de3..a02a998508 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts @@ -203,6 +203,20 @@ const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ }); // #endregion +// #region FluxMainModelField +const zFluxMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('FluxMainModelField'), +}); +const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. +const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zFluxMainModelFieldType, + value: zFluxMainModelFieldValue, +}); +const zFluxMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zFluxMainModelFieldType, +}); +// #endregion + // #region SDXLRefinerModelField const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLRefinerModelField'), @@ -338,6 +352,7 @@ const zStatefulFieldType = z.union([ zBoardFieldType, zMainModelFieldType, zSDXLMainModelFieldType, + zFluxMainModelFieldType, zSDXLRefinerModelFieldType, zVAEModelFieldType, zLoRAModelFieldType, @@ -377,6 +392,7 @@ const zStatefulFieldInputInstance = z.union([ zBoardFieldInputInstance, zMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, + zFluxMainModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance, zVAEModelFieldInputInstance, zLoRAModelFieldInputInstance, @@ -401,6 +417,7 @@ const zStatefulFieldOutputInstance = z.union([ zBoardFieldOutputInstance, zMainModelFieldOutputInstance, zSDXLMainModelFieldOutputInstance, + zFluxMainModelFieldOutputInstance, zSDXLRefinerModelFieldOutputInstance, zVAEModelFieldOutputInstance, zLoRAModelFieldOutputInstance, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index a5a2d89f03..e8784a1163 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = MainModelField: undefined, SchedulerField: 'euler', SDXLMainModelField: undefined, + FluxMainModelField: undefined, SDXLRefinerModelField: undefined, StringField: '', T2IAdapterModelField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 8478415cd1..f4f3ef85af 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -8,6 +8,7 @@ import type { FieldInputTemplate, FieldType, FloatFieldInputTemplate, + FluxMainModelFieldInputTemplate, ImageFieldInputTemplate, IntegerFieldInputTemplate, IPAdapterModelFieldInputTemplate, @@ -180,6 +181,20 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: FluxMainModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -386,6 +401,7 @@ export const TEMPLATE_BUILDER_MAP: Record { + return config.type === 'main' && config.base === 'flux'; +}; + export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2'); };