diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 756686b548..d68f8eaa97 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
+from invokeai.backend.flux.util import max_seq_lengths
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
@@ -188,17 +189,13 @@ class FluxModelLoaderInvocation(BaseInvocation):
         vae = self._get_model(context, SubModelType.VAE)
         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)
-        legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)
 
         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
-            max_seq_len=flux_conf["max_seq_len"],
+            max_seq_len=max_seq_lengths[transformer_config.config_path],
         )
 
     def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
diff --git a/invokeai/backend/flux/util.py b/invokeai/backend/flux/util.py
index 112d7111de..40e0554dcd 100644
--- a/invokeai/backend/flux/util.py
+++ b/invokeai/backend/flux/util.py
@@ -2,6 +2,7 @@
 import os
 from dataclasses import dataclass
+from typing import Dict, Literal
 
 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
 
@@ -18,6 +19,25 @@ class ModelSpec:
     repo_ae: str | None
 
 
+max_seq_lengths: Dict[str, Literal[256, 512]] = {
+    "flux-dev": 512,
+    "flux-schnell": 256,
+}
+
+
+ae_params = AutoEncoderParams(
+    resolution=256,
+    in_channels=3,
+    ch=128,
+    out_ch=3,
+    ch_mult=[1, 2, 4, 4],
+    num_res_blocks=2,
+    z_channels=16,
+    scale_factor=0.3611,
+    shift_factor=0.1159,
+)
+
+
 configs = {
     "flux-dev": ModelSpec(
         repo_id="black-forest-labs/FLUX.1-dev",
diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index 79613b7602..063367f30d 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -1,19 +1,18 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
-from dataclasses import fields
 from pathlib import Path
-from typing import Any, Optional
+from typing import Optional
 
 import accelerate
 import torch
-import yaml
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.flux.model import Flux, FluxParams
-from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs, ae_params
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -58,17 +57,9 @@ class FluxVAELoader(ModelLoader):
         if not isinstance(config, VAECheckpointConfig):
             raise ValueError("Only VAECheckpointConfig models are currently supported here.")
         model_path = Path(config.path)
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)
-
-        dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = AutoEncoderParams(**filtered_data)
 
         with SilenceWarnings():
-            model = AutoEncoder(params)
+            model = AutoEncoder(ae_params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             model.to(dtype=self._torch_dtype)
@@ -182,14 +173,10 @@ class FluxCheckpointModel(ModelLoader):
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
             raise ValueError("Only CheckpointConfigBase models are currently supported here.")
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)
 
         match submodel_type:
             case SubModelType.Transformer:
-                return self._load_from_singlefile(config, flux_conf)
+                return self._load_from_singlefile(config)
 
         raise ValueError(
             f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -198,16 +185,12 @@ class FluxCheckpointModel(ModelLoader):
     def _load_from_singlefile(
         self,
         config: AnyModelConfig,
-        flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainCheckpointConfig)
         model_path = Path(config.path)
-        dataclass_fields = {f.name for f in fields(FluxParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = FluxParams(**filtered_data)
 
         with SilenceWarnings():
-            model = Flux(params)
+            model = Flux(configs[config.config_path].params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             return model
@@ -224,14 +207,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
             raise ValueError("Only CheckpointConfigBase models are currently supported here.")
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)
 
         match submodel_type:
             case SubModelType.Transformer:
-                return self._load_from_singlefile(config, flux_conf)
+                return self._load_from_singlefile(config)
 
         raise ValueError(
             f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -240,7 +219,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     def _load_from_singlefile(
         self,
         config: AnyModelConfig,
-        flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
         if not bnb_available:
             raise ImportError(
                 "The bnb modules are not available. Please install bitsandbytes if available on your platform."
             )
         model_path = Path(config.path)
-        dataclass_fields = {f.name for f in fields(FluxParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = FluxParams(**filtered_data)
 
         with SilenceWarnings():
             with accelerate.init_empty_weights():
-                model = Flux(params)
+                model = Flux(configs[config.config_path].params)
             model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index e552b1cf1e..0ad537a5f3 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -329,9 +329,9 @@ class ModelProbe(object):
             checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
             state_dict = checkpoint.get("state_dict") or checkpoint
             if "guidance_in.out_layer.weight" in state_dict:
-                config_file = "flux/flux1-dev.yaml"
+                config_file = "flux-dev"
             else:
-                config_file = "flux/flux1-schnell.yaml"
+                config_file = "flux-schnell"
         else:
             config_file = LEGACY_CONFIGS[base_type][variant_type]
             if isinstance(config_file, dict):  # need another tier for sd-2.x models
diff --git a/invokeai/configs/flux/flux1-dev.yaml b/invokeai/configs/flux/flux1-dev.yaml
deleted file mode 100644
index 40a5b26a97..0000000000
--- a/invokeai/configs/flux/flux1-dev.yaml
+++ /dev/null
@@ -1,19 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-dev"
-repo_ae: "ae.safetensors"
-max_seq_len: 512
-params:
-  in_channels: 64
-  vec_in_dim: 768
-  context_in_dim: 4096
-  hidden_size: 3072
-  mlp_ratio: 4.0
-  num_heads: 24
-  depth: 19
-  depth_single_blocks: 38
-  axes_dim:
-    - 16
-    - 56
-    - 56
-  theta: 10_000
-  qkv_bias: True
-  guidance_embed: True
diff --git a/invokeai/configs/flux/flux1-schnell.yaml b/invokeai/configs/flux/flux1-schnell.yaml
deleted file mode 100644
index 2e9208c2c4..0000000000
--- a/invokeai/configs/flux/flux1-schnell.yaml
+++ /dev/null
@@ -1,19 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-schnell"
-repo_ae: "ae.safetensors"
-max_seq_len: 256
-params:
-  in_channels: 64
-  vec_in_dim: 768
-  context_in_dim: 4096
-  hidden_size: 3072
-  mlp_ratio: 4.0
-  num_heads: 24
-  depth: 19
-  depth_single_blocks: 38
-  axes_dim:
-    - 16
-    - 56
-    - 56
-  theta: 10_000
-  qkv_bias: True
-  guidance_embed: False
diff --git a/invokeai/configs/flux/flux1-vae.yaml b/invokeai/configs/flux/flux1-vae.yaml
deleted file mode 100644
index 2949378a2b..0000000000
--- a/invokeai/configs/flux/flux1-vae.yaml
+++ /dev/null
@@ -1,16 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-schnell"
-repo_path: "ae.safetensors"
-params:
-  resolution: 256
-  in_channels: 3
-  ch: 128
-  out_ch: 3
-  ch_mult:
-    - 1
-    - 2
-    - 4
-    - 4
-  num_res_blocks: 2
-  z_channels: 16
-  scale_factor: 0.3611
-  shift_factor: 0.1159
\ No newline at end of file
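Reviewer note on the new lookup path: the probe now stores a bare key (`"flux-dev"` or `"flux-schnell"`) in `config_file`/`config.config_path` instead of a legacy YAML path, and every consumer resolves model parameters through in-code tables (`max_seq_lengths`, `configs`, `ae_params`). Below is a minimal, self-contained sketch of the probe-to-loader handshake; the stub state dicts are hypothetical stand-ins for real checkpoints, but the key names and sequence lengths match this diff:

```python
from typing import Dict, Literal

# Mirrors invokeai.backend.flux.util.max_seq_lengths as added in this diff.
max_seq_lengths: Dict[str, Literal[256, 512]] = {
    "flux-dev": 512,
    "flux-schnell": 256,
}


def probe_flux_config_path(state_dict: dict) -> str:
    """Mirrors the ModelProbe check: only FLUX dev checkpoints carry guidance-embedding weights."""
    if "guidance_in.out_layer.weight" in state_dict:
        return "flux-dev"
    return "flux-schnell"


# Hypothetical stub state dicts standing in for real checkpoint tensors.
dev_sd = {"guidance_in.out_layer.weight": object()}
schnell_sd = {"img_in.weight": object()}

# The key written by the probe doubles as the lookup key everywhere downstream.
assert probe_flux_config_path(dev_sd) == "flux-dev"
assert max_seq_lengths[probe_flux_config_path(dev_sd)] == 512
assert max_seq_lengths[probe_flux_config_path(schnell_sd)] == 256
```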
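For context on the deleted code path: each loader used to parse a legacy YAML file and filter its `params` block down to the dataclass's known field names before instantiating `FluxParams`/`AutoEncoderParams`. With the instances now defined directly in `invokeai/backend/flux/util.py`, that filtering step disappears. A hedged before/after sketch; `LocalAEParams` is a hypothetical stand-in for `AutoEncoderParams`, and `stale_key` is an invented extra entry showing what the old filter guarded against:

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class LocalAEParams:  # hypothetical stand-in for AutoEncoderParams
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: List[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


# Old path: load YAML, then defensively drop keys the dataclass doesn't know.
yaml_params = {
    "resolution": 256, "in_channels": 3, "ch": 128, "out_ch": 3,
    "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "z_channels": 16,
    "scale_factor": 0.3611, "shift_factor": 0.1159,
    "stale_key": None,  # invented: the kind of extra key the filter tolerated
}
known = {f.name for f in fields(LocalAEParams)}
old_style = LocalAEParams(**{k: v for k, v in yaml_params.items() if k in known})

# New path: the instance is simply defined in code (cf. ae_params in flux/util.py),
# using the same values as the deleted flux1-vae.yaml.
new_style = LocalAEParams(
    resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4],
    num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159,
)
assert old_style == new_style
```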