diff --git a/invokeai/backend/quantization/quantize_t5_xxl_bnb_llm_int8.py b/invokeai/backend/quantization/quantize_t5_xxl_bnb_llm_int8.py index a77af4cb24..3a11e6129b 100644 --- a/invokeai/backend/quantization/quantize_t5_xxl_bnb_llm_int8.py +++ b/invokeai/backend/quantization/quantize_t5_xxl_bnb_llm_int8.py @@ -1,17 +1,29 @@ from pathlib import Path import accelerate -from safetensors.torch import load_file, save_model +from safetensors.torch import load_file, save_file from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 from invokeai.backend.quantization.load_flux_model_bnb_nf4 import log_time +def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict): + # There is a shared reference to a single weight tensor in the model. + # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should + # be present in the state_dict. + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True) + assert len(unexpected_keys) == 0 + assert set(missing_keys) == {"encoder.embed_tokens.weight"} + # Assert that the layers we expect to be shared are actually shared. + assert model.encoder.embed_tokens.weight is model.shared.weight + + def main(): # Load the FLUX transformer model onto the meta device. model_path = Path( - "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2" + # "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2" + "/data/misc/text_encoder_2" ) with log_time("Intialize T5 on meta device"): @@ -34,10 +46,7 @@ def main(): with log_time("Load state dict into model"): sd = load_file(model_int8_path) - missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False, assign=True) - assert len(unexpected_keys) == 0 - assert set(missing_keys) == {"shared.weight"} - # load_model(model, model_int8_path) + load_state_dict_into_t5(model, sd) with log_time("Move model to cuda"): model = model.to("cuda") @@ -58,17 +67,19 @@ def main(): for file in files: sd = load_file(file) state_dict.update(sd) - # TODO(ryand): Cast the state_dict to the appropriate dtype? - # The state dict is expected to have some extra keys, so we use `strict=False`. - model.load_state_dict(state_dict, strict=True, assign=True) + load_state_dict_into_t5(model, state_dict) with log_time("Move model to cuda and quantize"): model = model.to("cuda") with log_time("Save quantized model"): model_int8_path.parent.mkdir(parents=True, exist_ok=True) - # save_file(model.state_dict(), model_int8_path) - save_model(model, model_int8_path) + state_dict = model.state_dict() + state_dict.pop("encoder.embed_tokens.weight") + save_file(state_dict, model_int8_path) + # This handling of shared weights could also be achieved with save_model(...), but then we'd lose control + # over which keys are kept. And, the corresponding load_model(...) function does not support assign=True. + # save_model(model, model_int8_path) print(f"Successfully quantized and saved model to '{model_int8_path}'.")