mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add script for quantizing a T5 model.
This commit is contained in:
parent 5063be92bf
commit 33c2fbd201
@@ -0,0 +1,80 @@
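"""Quantize the T5 text encoder of FLUX.1-schnell to LLM.int8() (via bitsandbytes) and cache the
result as a safetensors file."""
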
from pathlib import Path

import accelerate
from safetensors.torch import load_file, save_model
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel

from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.load_flux_model_bnb_nf4 import log_time


def main():
    # Initialize the T5 text encoder (FLUX.1-schnell's text_encoder_2) on the meta device.
    model_path = Path(
        "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2"
    )

    with log_time("Initialize T5 on meta device"):
        model_config = AutoConfig.from_pretrained(model_path)
        with accelerate.init_empty_weights():
            model = AutoModelForTextEncoding.from_config(model_config)
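    # NOTE: `accelerate.init_empty_weights()` places every parameter on the meta device, so no weight
    # memory is allocated here; real tensors are attached later via `load_state_dict(..., assign=True)`.
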
    # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
    # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
    modules_to_not_convert: set[str] = set()

    model_int8_path = model_path / "bnb_llm_int8.safetensors"
    if model_int8_path.exists():
        # The quantized model already exists; load it and return it.
        print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")

        # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            sd = load_file(model_int8_path)
            missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False, assign=True)
            assert len(unexpected_keys) == 0
            assert set(missing_keys) == {"shared.weight"}
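            # "shared.weight" is expected to be missing: T5 ties it to "encoder.embed_tokens.weight",
            # and safetensors' `save_model(...)` deduplicates tied tensors when writing the file.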
            # load_model(model, model_int8_path)

        with log_time("Move model to cuda"):
            model = model.to("cuda")

        print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")

    else:
        # The quantized model does not exist; quantize the model and save it.
        print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")

        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            # Load the sharded state dict.
            files = list(model_path.glob("*.safetensors"))
            state_dict = {}
            for file in files:
                sd = load_file(file)
                state_dict.update(sd)
            # TODO(ryand): Cast the state_dict to the appropriate dtype?
            # The state dict is expected to have some extra keys, so we use `strict=False`.
            model.load_state_dict(state_dict, strict=False, assign=True)

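        # For LLM.int8() layers, the actual fp16 -> int8 weight quantization happens as the weights are
        # moved onto the CUDA device below; up to this point the model still holds unquantized weights.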
        with log_time("Move model to cuda and quantize"):
            model = model.to("cuda")

        with log_time("Save quantized model"):
            model_int8_path.parent.mkdir(parents=True, exist_ok=True)
            # save_file(model.state_dict(), model_int8_path)
            save_model(model, model_int8_path)

        print(f"Successfully quantized and saved model to '{model_int8_path}'.")

    assert isinstance(model, T5EncoderModel)
    return model


if __name__ == "__main__":
    main()
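
For reference, a minimal usage sketch (not part of the commit). It assumes the script has already been run once so that `bnb_llm_int8.safetensors` exists; a second call to `main()` then takes the fast "load pre-quantized" path.

import torch

model = main()  # second run: returns a T5EncoderModel on cuda with LLM.int8() linear layers

# Rough footprint check: sum the sizes of all tensors in the state dict (int8 weights have element_size() == 1).
num_bytes = sum(t.numel() * t.element_size() for t in model.state_dict().values())
print(f"Quantized T5 footprint: {num_bytes / 1e9:.2f} GB")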