mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes to the T5XXL quantization script.
This commit is contained in:
parent
6d838fa997
commit
86e49c423c
@ -1,17 +1,29 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
from safetensors.torch import load_file, save_model
|
from safetensors.torch import load_file, save_file
|
||||||
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel
|
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel
|
||||||
|
|
||||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||||
from invokeai.backend.quantization.load_flux_model_bnb_nf4 import log_time
|
from invokeai.backend.quantization.load_flux_model_bnb_nf4 import log_time
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
|
||||||
|
# There is a shared reference to a single weight tensor in the model.
|
||||||
|
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
|
||||||
|
# be present in the state_dict.
|
||||||
|
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
|
||||||
|
assert len(unexpected_keys) == 0
|
||||||
|
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
|
||||||
|
# Assert that the layers we expect to be shared are actually shared.
|
||||||
|
assert model.encoder.embed_tokens.weight is model.shared.weight
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Load the FLUX transformer model onto the meta device.
|
# Load the FLUX transformer model onto the meta device.
|
||||||
model_path = Path(
|
model_path = Path(
|
||||||
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2"
|
# "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/text_encoder_2"
|
||||||
|
"/data/misc/text_encoder_2"
|
||||||
)
|
)
|
||||||
|
|
||||||
with log_time("Intialize T5 on meta device"):
|
with log_time("Intialize T5 on meta device"):
|
||||||
@ -34,10 +46,7 @@ def main():
|
|||||||
|
|
||||||
with log_time("Load state dict into model"):
|
with log_time("Load state dict into model"):
|
||||||
sd = load_file(model_int8_path)
|
sd = load_file(model_int8_path)
|
||||||
missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False, assign=True)
|
load_state_dict_into_t5(model, sd)
|
||||||
assert len(unexpected_keys) == 0
|
|
||||||
assert set(missing_keys) == {"shared.weight"}
|
|
||||||
# load_model(model, model_int8_path)
|
|
||||||
|
|
||||||
with log_time("Move model to cuda"):
|
with log_time("Move model to cuda"):
|
||||||
model = model.to("cuda")
|
model = model.to("cuda")
|
||||||
@ -58,17 +67,19 @@ def main():
|
|||||||
for file in files:
|
for file in files:
|
||||||
sd = load_file(file)
|
sd = load_file(file)
|
||||||
state_dict.update(sd)
|
state_dict.update(sd)
|
||||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
load_state_dict_into_t5(model, state_dict)
|
||||||
# The state dict is expected to have some extra keys, so we use `strict=False`.
|
|
||||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
|
||||||
|
|
||||||
with log_time("Move model to cuda and quantize"):
|
with log_time("Move model to cuda and quantize"):
|
||||||
model = model.to("cuda")
|
model = model.to("cuda")
|
||||||
|
|
||||||
with log_time("Save quantized model"):
|
with log_time("Save quantized model"):
|
||||||
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
|
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
# save_file(model.state_dict(), model_int8_path)
|
state_dict = model.state_dict()
|
||||||
save_model(model, model_int8_path)
|
state_dict.pop("encoder.embed_tokens.weight")
|
||||||
|
save_file(state_dict, model_int8_path)
|
||||||
|
# This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
|
||||||
|
# over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
|
||||||
|
# save_model(model, model_int8_path)
|
||||||
|
|
||||||
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
|
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user