From 65bb46bcca0e2c1cbc3df6a399593178bced8649 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Mon, 26 Aug 2024 15:42:42 -0400
Subject: [PATCH] Rename params for flux and flux vae, add comments explaining
 use of the config_path in model config

---
 invokeai/backend/flux/util.py                    | 114 ++++++------
 .../model_manager/load/model_loaders/flux.py     |   8 +-
 invokeai/backend/model_manager/probe.py          |  14 ++-
 .../scripts/load_flux_model_bnb_llm_int8.py      |   6 +-
 .../scripts/load_flux_model_bnb_nf4.py           |   6 +-
 5 files changed, 63 insertions(+), 85 deletions(-)

diff --git a/invokeai/backend/flux/util.py b/invokeai/backend/flux/util.py
index 748c79e11d..703b032fa3 100644
--- a/invokeai/backend/flux/util.py
+++ b/invokeai/backend/flux/util.py
@@ -25,82 +25,48 @@ max_seq_lengths: Dict[str, Literal[256, 512]] = {
 }
 
 
-ae_params = AutoEncoderParams(
-    resolution=256,
-    in_channels=3,
-    ch=128,
-    out_ch=3,
-    ch_mult=[1, 2, 4, 4],
-    num_res_blocks=2,
-    z_channels=16,
-    scale_factor=0.3611,
-    shift_factor=0.1159,
-)
+ae_params = {
+    "flux": AutoEncoderParams(
+        resolution=256,
+        in_channels=3,
+        ch=128,
+        out_ch=3,
+        ch_mult=[1, 2, 4, 4],
+        num_res_blocks=2,
+        z_channels=16,
+        scale_factor=0.3611,
+        shift_factor=0.1159,
+    )
+}
 
 
-configs = {
-    "flux-dev": ModelSpec(
-        repo_id="black-forest-labs/FLUX.1-dev",
-        repo_flow="flux1-dev.safetensors",
-        repo_ae="ae.safetensors",
-        ckpt_path=os.getenv("FLUX_DEV"),
-        params=FluxParams(
-            in_channels=64,
-            vec_in_dim=768,
-            context_in_dim=4096,
-            hidden_size=3072,
-            mlp_ratio=4.0,
-            num_heads=24,
-            depth=19,
-            depth_single_blocks=38,
-            axes_dim=[16, 56, 56],
-            theta=10_000,
-            qkv_bias=True,
-            guidance_embed=True,
-        ),
-        ae_path=os.getenv("AE"),
-        ae_params=AutoEncoderParams(
-            resolution=256,
-            in_channels=3,
-            ch=128,
-            out_ch=3,
-            ch_mult=[1, 2, 4, 4],
-            num_res_blocks=2,
-            z_channels=16,
-            scale_factor=0.3611,
-            shift_factor=0.1159,
-        ),
+params = {
+    "flux-dev": FluxParams(
+        in_channels=64,
+        vec_in_dim=768,
+        context_in_dim=4096,
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
     ),
-    "flux-schnell": ModelSpec(
-        repo_id="black-forest-labs/FLUX.1-schnell",
-        repo_flow="flux1-schnell.safetensors",
-        repo_ae="ae.safetensors",
-        ckpt_path=os.getenv("FLUX_SCHNELL"),
-        params=FluxParams(
-            in_channels=64,
-            vec_in_dim=768,
-            context_in_dim=4096,
-            hidden_size=3072,
-            mlp_ratio=4.0,
-            num_heads=24,
-            depth=19,
-            depth_single_blocks=38,
-            axes_dim=[16, 56, 56],
-            theta=10_000,
-            qkv_bias=True,
-            guidance_embed=False,
-        ),
-        ae_path=os.getenv("AE"),
-        ae_params=AutoEncoderParams(
-            resolution=256,
-            in_channels=3,
-            ch=128,
-            out_ch=3,
-            ch_mult=[1, 2, 4, 4],
-            num_res_blocks=2,
-            z_channels=16,
-            scale_factor=0.3611,
-            shift_factor=0.1159,
-        ),
+    "flux-schnell": FluxParams(
+        in_channels=64,
+        vec_in_dim=768,
+        context_in_dim=4096,
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=False,
     ),
 }
diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index 53119f6df0..0316de6044 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -12,7 +12,7 @@ from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CL
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
-from invokeai.backend.flux.util import ae_params, configs
+from invokeai.backend.flux.util import ae_params, params
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -59,7 +59,7 @@ class FluxVAELoader(ModelLoader):
         model_path = Path(config.path)
 
         with SilenceWarnings():
-            model = AutoEncoder(ae_params)
+            model = AutoEncoder(ae_params[config.config_path])
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             model.to(dtype=self._torch_dtype)
@@ -188,7 +188,7 @@ class FluxCheckpointModel(ModelLoader):
         model_path = Path(config.path)
 
         with SilenceWarnings():
-            model = Flux(configs[config.config_path].params)
+            model = Flux(params[config.config_path])
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             return model
@@ -227,7 +227,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
 
         with SilenceWarnings():
             with accelerate.init_empty_weights():
-                model = Flux(configs[config.config_path].params)
+                model = Flux(params[config.config_path])
                 model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
                 sd = load_file(model_path)
                 model.load_state_dict(sd, assign=True)
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index 0ad537a5f3..029366e357 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -329,8 +329,16 @@ class ModelProbe(object):
                 checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
                 state_dict = checkpoint.get("state_dict") or checkpoint
                 if "guidance_in.out_layer.weight" in state_dict:
+                    # For flux, this is a key in invokeai.backend.flux.util.params.
+                    # Because model type and format are the discriminators for model configs,
+                    # this key is used rather than adding separate model types and formats for flux.
+                    # If that changes in the future, please update this accordingly.
                     config_file = "flux-dev"
                 else:
+                    # For flux, this is a key in invokeai.backend.flux.util.params.
+                    # Because model type and format are the discriminators for model configs,
+                    # this key is used rather than adding separate model types and formats for flux.
+                    # If that changes in the future, please update this accordingly.
                     config_file = "flux-schnell"
             else:
                 config_file = LEGACY_CONFIGS[base_type][variant_type]
@@ -345,7 +353,11 @@
             )
         elif model_type is ModelType.VAE:
             config_file = (
-                "flux/flux1-vae.yaml"
+                # For flux, this is a key in invokeai.backend.flux.util.ae_params.
+                # Because model type and format are the discriminators for model configs,
+                # this key is used rather than adding separate model types and formats for flux.
+                # If that changes in the future, please update this accordingly.
+                "flux"
                 if base_type is BaseModelType.Flux
                 else "stable-diffusion/v1-inference.yaml"
                 if base_type is BaseModelType.StableDiffusion1
diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py
index 51c787d8ef..804336e000 100644
--- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py
+++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py
@@ -4,7 +4,7 @@ import accelerate
 from safetensors.torch import load_file, save_file
 
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.util import configs as flux_configs
+from invokeai.backend.flux.util import params
 from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
 from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
 
@@ -22,11 +22,11 @@ def main():
 
     with log_time("Intialize FLUX transformer on meta device"):
         # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-        params = flux_configs["flux-schnell"].params
+        p = params["flux-schnell"]
 
         # Initialize the model on the "meta" device.
         with accelerate.init_empty_weights():
-            model = Flux(params)
+            model = Flux(p)
 
     # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
     # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py
index 5415407a2b..f1621dbc6d 100644
--- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py
+++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py
@@ -7,7 +7,7 @@ import torch
 from safetensors.torch import load_file, save_file
 
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.util import configs as flux_configs
+from invokeai.backend.flux.util import params
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 
 
@@ -35,11 +35,11 @@ def main():
     # inference_dtype = torch.bfloat16
     with log_time("Intialize FLUX transformer on meta device"):
         # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-        params = flux_configs["flux-schnell"].params
+        p = params["flux-schnell"]
 
         # Initialize the model on the "meta" device.
        with accelerate.init_empty_weights():
-            model = Flux(params)
+            model = Flux(p)
 
     # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
     # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
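
Usage sketch (not part of the patch): after this rename, loaders index the new dictionaries by the config_path recorded by the probe ("flux-dev" or "flux-schnell" for the transformer, "flux" for the VAE). The helper functions below are illustrative only and are not code from the repository; only the imports, keys, and constructor calls are taken from the patch above.

    # Illustrative sketch, assuming the renamed lookup tables in invokeai.backend.flux.util.
    from invokeai.backend.flux.model import Flux
    from invokeai.backend.flux.modules.autoencoder import AutoEncoder
    from invokeai.backend.flux.util import ae_params, params


    def build_flux_transformer(config_path: str) -> Flux:
        # config_path comes from the stored model config, e.g. "flux-dev" or "flux-schnell".
        return Flux(params[config_path])


    def build_flux_vae(config_path: str = "flux") -> AutoEncoder:
        # The VAE table currently holds a single entry keyed by "flux".
        return AutoEncoder(ae_params[config_path])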