Fixes to the T5XXL quantization script.

This commit is contained in:
Ryan Dick 2024-08-23 14:06:08 +00:00 committed by Brandon
parent 33c2fbd201
commit b9dd354e2b

View File

@ -1,17 +1,29 @@
from pathlib import Path
import accelerate
from safetensors.torch import load_file, save_model
from safetensors.torch import load_file, save_file
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.load_flux_model_bnb_nf4 import log_time
def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
def main():
# Load the FLUX transformer model onto the meta device.
model_path = Path(
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2"
# "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2"
"/data/misc/text_encoder_2"
)
with log_time("Intialize T5 on meta device"):
@ -34,10 +46,7 @@ def main():
with log_time("Load state dict into model"):
sd = load_file(model_int8_path)
missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"shared.weight"}
# load_model(model, model_int8_path)
load_state_dict_into_t5(model, sd)
with log_time("Move model to cuda"):
model = model.to("cuda")
@ -58,17 +67,19 @@ def main():
for file in files:
sd = load_file(file)
state_dict.update(sd)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
# The state dict is expected to have some extra keys, so we use `strict=False`.
model.load_state_dict(state_dict, strict=True, assign=True)
load_state_dict_into_t5(model, state_dict)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
# save_file(model.state_dict(), model_int8_path)
save_model(model, model_int8_path)
state_dict = model.state_dict()
state_dict.pop("encoder.embed_tokens.weight")
save_file(state_dict, model_int8_path)
# This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
# over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
# save_model(model, model_int8_path)
print(f"Successfully quantized and saved model to '{model_int8_path}'.")