diff --git a/invokeai/backend/quantization/quantize_t5_xxl_bnb_llm_int8.py b/invokeai/backend/quantization/quantize_t5_xxl_bnb_llm_int8.py new file mode 100644 index 0000000000..a77af4cb24 --- /dev/null +++ b/invokeai/backend/quantization/quantize_t5_xxl_bnb_llm_int8.py @@ -0,0 +1,80 @@ +from pathlib import Path + +import accelerate +from safetensors.torch import load_file, save_model +from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel + +from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 +from invokeai.backend.quantization.load_flux_model_bnb_nf4 import log_time + + +def main(): + # Load the FLUX transformer model onto the meta device. + model_path = Path( + "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2" + ) + + with log_time("Intialize T5 on meta device"): + model_config = AutoConfig.from_pretrained(model_path) + with accelerate.init_empty_weights(): + model = AutoModelForTextEncoding.from_config(model_config) + + # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate + # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize. + modules_to_not_convert: set[str] = set() + + model_int8_path = model_path / "bnb_llm_int8.safetensors" + if model_int8_path.exists(): + # The quantized model already exists, load it and return it. + print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...") + + # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device). + with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights(): + model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert) + + with log_time("Load state dict into model"): + sd = load_file(model_int8_path) + missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False, assign=True) + assert len(unexpected_keys) == 0 + assert set(missing_keys) == {"shared.weight"} + # load_model(model, model_int8_path) + + with log_time("Move model to cuda"): + model = model.to("cuda") + + print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.") + + else: + # The quantized model does not exist, quantize the model and save it. + print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...") + + with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights(): + model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert) + + with log_time("Load state dict into model"): + # Load sharded state dict. + files = list(model_path.glob("*.safetensors")) + state_dict = {} + for file in files: + sd = load_file(file) + state_dict.update(sd) + # TODO(ryand): Cast the state_dict to the appropriate dtype? + # The state dict is expected to have some extra keys, so we use `strict=False`. + model.load_state_dict(state_dict, strict=True, assign=True) + + with log_time("Move model to cuda and quantize"): + model = model.to("cuda") + + with log_time("Save quantized model"): + model_int8_path.parent.mkdir(parents=True, exist_ok=True) + # save_file(model.state_dict(), model_int8_path) + save_model(model, model_int8_path) + + print(f"Successfully quantized and saved model to '{model_int8_path}'.") + + assert isinstance(model, T5EncoderModel) + return model + + +if __name__ == "__main__": + main()