LLM.int8() quantization is working, but still some rough edges to solve.

This commit is contained in:
Ryan Dick
2024-08-15 19:34:34 +00:00
committed by Brandon
parent 99b0f79784
commit f01f56a98e
7 changed files with 221 additions and 77 deletions

View File

@ -1,6 +1,5 @@
from typing import Any, Optional, Set, Type
import accelerate
import bitsandbytes as bnb
import torch
@ -460,27 +459,6 @@ def _convert_linear_layers(
_convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname)
def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules: Set[str], prefix: str = "") -> None:
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinear8bitLt(
child.in_features,
child.out_features,
bias=has_bias,
has_fp16_weights=False,
# device=device,
)
replacement.weight.data = child.weight.data
if has_bias:
replacement.bias.data = child.bias.data
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_llm_8bit(child, ignore_modules, prefix=fullname)
# def _replace_linear_layers(
# model: torch.nn.Module,
# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"],
@ -537,21 +515,3 @@ def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules:
# # Remove the last key for recursion
# current_key_name.pop(-1)
# return model, has_been_replaced
def get_parameter_device(parameter: torch.nn.Module):
return next(parameter.parameters()).device
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str]):
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
model_device = get_parameter_device(model)
if model_device.type != "meta":
# Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the
# meta device, so we enforce it for now.
raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.")
with accelerate.init_empty_weights():
_convert_linear_layers_to_llm_8bit(module=model, ignore_modules=modules_to_not_convert)
return model

View File

@ -0,0 +1,102 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None)
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
_weight_format = state_dict.pop(prefix + "weight_format", None)
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert len(state_dict) == 0
if scb is not None:
# We are loading a pre-quantized state dict.
self.weight = bnb.nn.Int8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
# Note: After quantization, CB is the same as weight.
CB=weight,
SCB=scb,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = bnb.nn.Int8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
CB=None,
SCB=None,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
) -> None:
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinear8bitLt(
child.in_features,
child.out_features,
bias=has_bias,
has_fp16_weights=False,
threshold=outlier_threshold,
)
replacement.weight.data = child.weight.data
if has_bias:
replacement.bias.data = child.bias.data
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_llm_8bit(
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
)
def get_parameter_device(parameter: torch.nn.Module):
return next(parameter.parameters()).device
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
_convert_linear_layers_to_llm_8bit(
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
)
return model

View File

@ -5,6 +5,10 @@ import torch
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeLinearNF4(bnb.nn.LinearNF4):
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:

View File

@ -0,0 +1,89 @@
import time
from contextlib import contextmanager
from pathlib import Path
import accelerate
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from safetensors.torch import load_file, save_file
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
@contextmanager
def log_time(name: str):
"""Helper context manager to log the time taken by a block of code."""
start = time.time()
try:
yield None
finally:
end = time.time()
print(f"'{name}' took {end - start:.4f} secs")
def main():
# Load the FLUX transformer model onto the meta device.
model_path = Path(
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
)
with log_time("Initialize FLUX transformer on meta device"):
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
with accelerate.init_empty_weights():
empty_model = FluxTransformer2DModel.from_config(model_config)
assert isinstance(empty_model, FluxTransformer2DModel)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path / "bnb_llm_int8"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path / "model.safetensors")
model.load_state_dict(sd, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
# Load sharded state dict.
files = list(model_path.glob("*.safetensors"))
state_dict = dict()
for file in files:
sd = load_file(file)
state_dict.update(sd)
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.mkdir(parents=True, exist_ok=True)
output_path = model_int8_path / "model.safetensors"
save_file(model.state_dict(), output_path)
print(f"Successfully quantized and saved model to '{output_path}'.")
assert isinstance(model, FluxTransformer2DModel)
return model
if __name__ == "__main__":
main()