Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Fixes to the T5XXL quantization script.
Commit: b9dd354e2b
Parent: 33c2fbd201
@@ -1,17 +1,29 @@
 from pathlib import Path

 import accelerate
-from safetensors.torch import load_file, save_model
+from safetensors.torch import load_file, save_file
 from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel

 from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
 from invokeai.backend.quantization.load_flux_model_bnb_nf4 import log_time


+def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
+    # There is a shared reference to a single weight tensor in the model.
+    # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
+    # be present in the state_dict.
+    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
+    assert len(unexpected_keys) == 0
+    assert set(missing_keys) == {"encoder.embed_tokens.weight"}
+    # Assert that the layers we expect to be shared are actually shared.
+    assert model.encoder.embed_tokens.weight is model.shared.weight
+
+
 def main():
     # Load the FLUX transformer model onto the meta device.
     model_path = Path(
-        "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2"
+        # "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2"
+        "/data/misc/text_encoder_2"
     )

     with log_time("Intialize T5 on meta device"):
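For background (a reader's note, not part of the commit): Hugging Face's T5 registers a single nn.Embedding as both `shared` and `encoder.embed_tokens`, so the state dict carries two keys that alias one storage, and `safetensors.torch.save_file` refuses to serialize aliased tensors. A minimal standalone sketch, assuming a tiny T5Config as a stand-in for the real T5-XXL checkpoint:

from safetensors.torch import save_file
from transformers import T5Config, T5EncoderModel

# Tiny hypothetical config; the commit's script operates on T5-XXL.
config = T5Config(d_model=64, d_ff=128, num_layers=2, num_heads=2, vocab_size=128)
model = T5EncoderModel(config)

# "encoder.embed_tokens" and "shared" are the same nn.Embedding object,
# so both state-dict keys point at the same underlying tensor.
assert model.encoder.embed_tokens.weight is model.shared.weight

try:
    save_file(model.state_dict(), "t5_tiny.safetensors")
except RuntimeError as e:
    # safetensors rejects state dicts containing tensors that share memory.
    print(f"save_file rejects tied weights: {e}")

`save_model(...)` would instead deduplicate the aliased names itself, choosing which key survives; that is the loss of control the commit's save-side comment refers to.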
@@ -34,10 +46,7 @@ def main():

     with log_time("Load state dict into model"):
         sd = load_file(model_int8_path)
-        missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False, assign=True)
-        assert len(unexpected_keys) == 0
-        assert set(missing_keys) == {"shared.weight"}
-        # load_model(model, model_int8_path)
+        load_state_dict_into_t5(model, sd)

     with log_time("Move model to cuda"):
         model = model.to("cuda")
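Another aside, a minimal sketch with a plain nn.Linear rather than the commit's code: `assign=True` matters here because the script initializes T5 on the meta device, and a copy-based `load_state_dict` cannot write into meta tensors, while `assign=True` swaps the loaded tensors in directly:

import torch

with torch.device("meta"):
    layer = torch.nn.Linear(4, 4)  # parameters exist only as shapes, no storage

sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
# layer.load_state_dict(sd)             # fails: cannot copy into meta tensors
layer.load_state_dict(sd, assign=True)  # replaces meta params with real tensors
assert layer.weight.device.type == "cpu"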
@@ -58,17 +67,19 @@ def main():
         for file in files:
             sd = load_file(file)
             state_dict.update(sd)
         # TODO(ryand): Cast the state_dict to the appropriate dtype?
-        # The state dict is expected to have some extra keys, so we use `strict=False`.
-        model.load_state_dict(state_dict, strict=True, assign=True)
+        load_state_dict_into_t5(model, state_dict)

     with log_time("Move model to cuda and quantize"):
         model = model.to("cuda")

     with log_time("Save quantized model"):
         model_int8_path.parent.mkdir(parents=True, exist_ok=True)
-        # save_file(model.state_dict(), model_int8_path)
-        save_model(model, model_int8_path)
+        state_dict = model.state_dict()
+        state_dict.pop("encoder.embed_tokens.weight")
+        save_file(state_dict, model_int8_path)
+        # This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
+        # over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
+        # save_model(model, model_int8_path)

     print(f"Successfully quantized and saved model to '{model_int8_path}'.")
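A possible round-trip check (again a standalone sketch with a tiny hypothetical config, not the commit's code) that the chosen contract is self-consistent: the file carries only "shared.weight", and reloading with `strict=False, assign=True` reports exactly the one expected missing key while the weight tie survives:

from safetensors.torch import load_file, save_file
from transformers import T5Config, T5EncoderModel

config = T5Config(d_model=64, d_ff=128, num_layers=2, num_heads=2, vocab_size=128)
model = T5EncoderModel(config)

state_dict = model.state_dict()
state_dict.pop("encoder.embed_tokens.weight")  # serialize only "shared.weight"
save_file(state_dict, "t5_tiny_int8.safetensors")

reloaded = T5EncoderModel(config)
sd = load_file("t5_tiny_int8.safetensors")
missing, unexpected = reloaded.load_state_dict(sd, strict=False, assign=True)
assert unexpected == [] and set(missing) == {"encoder.embed_tokens.weight"}
# The embedding module is shared, so assigning "shared.weight" also updates
# "encoder.embed_tokens.weight"; the missing key is expected and harmless.
assert reloaded.encoder.embed_tokens.weight is reloaded.shared.weight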