From 203542c7a8de9f0fc2df10853b8aa44cfd3f148f Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 21 Aug 2024 19:03:09 +0000
Subject: [PATCH] Update load_flux_model_bnb_llm_int8.py to work with a
 single-file FLUX transformer checkpoint.

---
 .../load_flux_model_bnb_llm_int8.py          | 41 +++++++++----------
 1 file changed, 19 insertions(+), 22 deletions(-)

diff --git a/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py b/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py
index a24370967c..47ce0f56b1 100644
--- a/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py
+++ b/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py
@@ -1,7 +1,8 @@
 from pathlib import Path
 
 import accelerate
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from flux.model import Flux
+from flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file
 
 from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
@@ -11,30 +12,32 @@ from invokeai.backend.quantization.load_flux_model_bnb_nf4 import log_time
 def main():
     # Load the FLUX transformer model onto the meta device.
     model_path = Path(
-        "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
+        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
     )
 
-    with log_time("Initialize FLUX transformer on meta device"):
-        model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
+    with log_time("Intialize FLUX transformer on meta device"):
+        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+        params = flux_configs["flux-schnell"].params
+
+        # Initialize the model on the "meta" device.
         with accelerate.init_empty_weights():
-            empty_model = FluxTransformer2DModel.from_config(model_config)
-        assert isinstance(empty_model, FluxTransformer2DModel)
+            model = Flux(params)
 
     # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
     # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
     modules_to_not_convert: set[str] = set()
 
-    model_int8_path = model_path / "bnb_llm_int8"
+    model_int8_path = model_path.parent / "bnb_llm_int8.safetensors"
     if model_int8_path.exists():
         # The quantized model already exists, load it and return it.
         print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
 
         # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
         with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
-            model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
+            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
 
         with log_time("Load state dict into model"):
-            sd = load_file(model_int8_path / "model.safetensors")
+            sd = load_file(model_int8_path)
             model.load_state_dict(sd, strict=True, assign=True)
 
         with log_time("Move model to cuda"):
@@ -47,29 +50,23 @@ def main():
         print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
 
         with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
-            model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
+            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
 
         with log_time("Load state dict into model"):
-            # Load sharded state dict.
-            files = list(model_path.glob("*.safetensors"))
-            state_dict = {}
-            for file in files:
-                sd = load_file(file)
-                state_dict.update(sd)
-
+            state_dict = load_file(model_path)
+            # TODO(ryand): Cast the state_dict to the appropriate dtype?
             model.load_state_dict(state_dict, strict=True, assign=True)
 
         with log_time("Move model to cuda and quantize"):
            model = model.to("cuda")
 
         with log_time("Save quantized model"):
-            model_int8_path.mkdir(parents=True, exist_ok=True)
-            output_path = model_int8_path / "model.safetensors"
-            save_file(model.state_dict(), output_path)
+            model_int8_path.parent.mkdir(parents=True, exist_ok=True)
+            save_file(model.state_dict(), model_int8_path)
 
-        print(f"Successfully quantized and saved model to '{output_path}'.")
+        print(f"Successfully quantized and saved model to '{model_int8_path}'.")
 
-    assert isinstance(model, FluxTransformer2DModel)
+    assert isinstance(model, Flux)
     return model
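
Note: the TODO(ryand) in the first hunk (determining whether the single-file checkpoint is a schnell or a dev model before choosing a config) is left open by this patch. Below is a minimal sketch of one way it could be resolved. It assumes, without verifying here, that flux.util.configs exposes both a "flux-dev" and a "flux-schnell" entry and that only the dev transformer stores guidance-embedding weights under a "guidance_in." key prefix; infer_flux_config_name is a hypothetical helper, not part of the patch.

    from pathlib import Path

    from safetensors import safe_open


    def infer_flux_config_name(checkpoint_path: Path) -> str:
        # Inspect only the checkpoint's key names; no tensor data is loaded.
        with safe_open(checkpoint_path, framework="pt") as f:
            has_guidance = any(k.startswith("guidance_in.") for k in f.keys())
        # Assumption: dev checkpoints include guidance-embedding weights, schnell checkpoints do not.
        return "flux-dev" if has_guidance else "flux-schnell"

With a helper like this, the hard-coded flux_configs["flux-schnell"].params lookup could become flux_configs[infer_flux_config_name(model_path)].params.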