2024-08-15 16:30:47 +00:00
from typing import Any , Optional , Set , Type
2024-08-14 04:06:16 +00:00
import bitsandbytes as bnb
import torch
# The utils in this file take ideas from
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# Patterns:
# - Quantize:
#   - Initialize model on meta device
#   - Replace layers
#   - Load state_dict to cpu
#   - Load state_dict into model
#   - Quantize on GPU
#   - Extract state_dict
#   - Save
# - Load:
#   - Initialize model on meta device
#   - Replace layers
#   - Load state_dict to cpu
#   - Load state_dict into model on cpu
#   - Move to GPU
# class InvokeInt8Params(bnb.nn.Int8Params):
# """Overrides `bnb.nn.Int8Params` to add the following functionality:
# - Make it possible to load a quantized state dict without putting the weight on a "cuda" device.
# """
# def quantize(self, device: Optional[torch.device] = None):
# device = device or torch.device("cuda")
# if device.type != "cuda":
# raise RuntimeError(f"Int8Params quantization is only supported on CUDA devices ({device=}).")
# # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
# B = self.data.contiguous().half().cuda(device)
# if self.has_fp16_weights:
# self.data = B
# else:
# # we store the 8-bit rows-major weight
# # we convert this weight to the turing/ampere weight during the first inference pass
# CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
# del CBt
# del SCBt
# self.data = CB
# self.CB = CB
# self.SCB = SCB
class Invoke2Linear8bitLt(torch.nn.Linear):
    """This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.

    An 8-bit linear layer built on `bnb.matmul` / `bnb.MatmulLtState`. The weight is stored as a
    `bnb.nn.Int8Params`; the bitsandbytes helpers it needs (`maybe_rearrange_weight`,
    `LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING`) are imported lazily inside the methods that use them.
    """

    def __init__(
        self,
        input_features: int,
        output_features: int,
        bias=True,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
        device=None,
    ):
        """
        Initialize Linear8bitLt class.

        Args:
            input_features (`int`):
                Number of input features of the linear layer.
            output_features (`int`):
                Number of output features of the linear layer.
            bias (`bool`, defaults to `True`):
                Whether the linear class uses the bias term as well.
            has_fp16_weights (`bool`, defaults to `True`):
                Stored on `self.state`; also used as `requires_grad` for the `Int8Params` weight.
            memory_efficient_backward (`bool`, defaults to `False`):
                Deprecated; must be False (asserted).
            threshold (`float`, defaults to `0.0`):
                Stored on `self.state.threshold`. A positive value combined with `has_fp16_weights=False`
                also sets `self.state.use_pool`.
            index:
                Stored as `self.index`.
            device:
                Forwarded to `torch.nn.Linear`.
        """
        # Not defined in this module; imported from bitsandbytes (it previously raised a NameError here).
        from bitsandbytes.nn.modules import maybe_rearrange_weight

        super().__init__(input_features, output_features, bias, device)
        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = bnb.nn.Int8Params(
            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
        )
        # NOTE(review): presumably consumes the `weight_format` entry written by `_save_to_state_dict` and
        # restores a row-major weight before loading -- bitsandbytes-internal, confirm for the installed version.
        self._register_load_state_dict_pre_hook(maybe_rearrange_weight)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save the standard parameters plus `<prefix>SCB` and `<prefix>weight_format` for int8 weights."""
        super()._save_to_state_dict(destination, prefix, keep_vars)

        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
        scb_name = "SCB"

        # case 1: .cuda was called, SCB is in self.weight
        param_from_weight = getattr(self.weight, scb_name)
        # case 2: self.init_8bit_state was called, SCB is in self.state
        param_from_state = getattr(self.state, scb_name)
        # case 3: SCB is in self.state, weight layout reordered after first forward()
        layout_reordered = self.state.CxB is not None

        key_name = prefix + scb_name
        format_name = prefix + "weight_format"

        if self.state.has_fp16_weights:
            # fp16 weights are fully described by the standard parameters.
            return

        if param_from_weight is not None:
            destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
            destination[format_name] = torch.tensor(0, dtype=torch.uint8)
        elif param_from_state is not None and not layout_reordered:
            destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
            destination[format_name] = torch.tensor(0, dtype=torch.uint8)
        elif param_from_state is not None:
            # Not defined in this module; imported from bitsandbytes (it previously raised a NameError here).
            from bitsandbytes.utils import LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING

            destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
            weights_format = self.state.formatB
            # At this point `weights_format` is an str
            if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
                raise ValueError(f"Unrecognized weights format {weights_format}")

            weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]

            destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load the standard parameters, then consume an extra `<prefix>SCB` entry if present.

        Raises:
            RuntimeError: If the state dict contains `SCB` but `self.weight.SCB` has not been created yet.
        """
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        # Iterate over a copy because matching keys are removed from `unexpected_keys` below.
        for key in list(unexpected_keys):
            if key[len(prefix):] != "SCB":
                continue
            if self.weight.SCB is None:
                # buffers not yet initialized, can't access them directly without quantizing first
                raise RuntimeError(
                    "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                    "not supported. Please call module.cuda() before module.load_state_dict()",
                )

            self.weight.SCB.copy_(state_dict[key])

            if self.state.SCB is not None:
                self.state.SCB = self.weight.SCB

            unexpected_keys.remove(key)

    def init_8bit_state(self):
        """Move the quantized buffers CB/SCB from the weight onto `self.state`."""
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        """Run the int8 matmul via `bnb.matmul`, casting the bias to `x.dtype` first if needed."""
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
    """Wraps `bnb.nn.Linear8bitLt` and adds explicit quantization helpers.

    Differences from the parent class:
    - `threshold` defaults to 6.0.
    - `quantize_()` / `quantize()` quantize a (possibly externally supplied) weight onto a CUDA device in place.

    NOTE: Quantizing at construction time and re-quantizing when a state dict is loaded are NOT currently
    enabled; the code that would do so is commented out in `__init__`.
    """

    def __init__(
        self, *args: Any, device: Optional[torch.device] = None, threshold: float = 6.0, **kwargs: Any
    ) -> None:
        super().__init__(*args, device=device, threshold=threshold, **kwargs)
        # Disabled: if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we
        # don't end up filling the device memory with float32 weights which could lead to OOM.
        # if torch.tensor(0, device=device).device.type == "cuda":
        #     self.quantize_()
        # self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
        # self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load the standard parameters, then consume an extra `<prefix>SCB` entry if present.

        Raises:
            RuntimeError: If the state dict contains `SCB` but `self.weight.SCB` has not been created yet.
        """
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        # Iterate over a copy because matching keys are removed from `unexpected_keys` below.
        for key in list(unexpected_keys):
            if key[len(prefix):] != "SCB":
                continue
            if self.weight.SCB is None:
                # buffers not yet initialized, can't access them directly without quantizing first
                raise RuntimeError(
                    "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                    "not supported. Please call module.cuda() before module.load_state_dict()",
                )

            self.weight.SCB.copy_(state_dict[key])

            if self.state.SCB is not None:
                self.state.SCB = self.weight.SCB

            unexpected_keys.remove(key)

    def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
        """Inplace quantize.

        Args:
            weight: Tensor to quantize. Defaults to the layer's current weight data.
            device: Target device, forwarded to `quantize` (defaults to CUDA there).

        Does nothing if `weight` is already `torch.int8`.
        """
        if weight is None:
            weight = self.weight.data
        if weight.data.dtype == torch.int8:
            # already quantized
            return
        assert isinstance(self.weight, bnb.nn.Int8Params)
        self.weight = self.quantize(self.weight, weight, device)

    @staticmethod
    def quantize(
        int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device]
    ) -> bnb.nn.Int8Params:
        """Quantize `weight` into `int8params` on `device`, mutating and returning `int8params`.

        If `int8params.has_fp16_weights` is set, the weight is only converted to float16 on `device`.
        Otherwise it is quantized with `bnb.functional.double_quant` and `CB`/`SCB` are attached to `int8params`.

        Raises:
            RuntimeError: If `device` (default: CUDA) is not a CUDA device.
        """
        device = device or torch.device("cuda")
        if device.type != "cuda":
            raise RuntimeError(f"Unexpected device type: {device.type}")
        # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
        B = weight.contiguous().to(device=device, dtype=torch.float16)
        if int8params.has_fp16_weights:
            int8params.data = B
        else:
            # Keep only the row-major int8 weight (CB) and SCB; drop the other outputs right away so their
            # device memory can be released.
            CB, CBt, SCB, SCBt, _ = bnb.functional.double_quant(B)
            del CBt, SCBt
            int8params.data = CB
            int8params.CB = CB
            int8params.SCB = SCB
        return int8params
# class _Linear4bit(bnb.nn.Linear4bit):
# """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the
# state dict, meta-device initialization, and materialization."""
# def __init__(self, *args: Any, device: Optional[torch.device] = None, **kwargs: Any) -> None:
# super().__init__(*args, device=device, **kwargs)
# self.weight = cast(bnb.nn.Params4bit, self.weight) # type: ignore[has-type]
# self.bias = cast(Optional[torch.nn.Parameter], self.bias) # type: ignore[has-type]
# # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
# # filling the device memory with float32 weights which could lead to OOM
# if torch.tensor(0, device=device).device.type == "cuda":
# self.quantize_()
# self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
# self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
# def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
# """Inplace quantize."""
# if weight is None:
# weight = self.weight.data
# if weight.data.dtype == torch.uint8:
# # already quantized
# return
# assert isinstance(self.weight, bnb.nn.Params4bit)
# self.weight = self.quantize(self.weight, weight, device)
# @staticmethod
# def quantize(
# params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device]
# ) -> bnb.nn.Params4bit:
# device = device or torch.device("cuda")
# if device.type != "cuda":
# raise RuntimeError(f"Unexpected device type: {device.type}")
# # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159
# w = weight.contiguous().to(device=device, dtype=torch.half)
# w_4bit, quant_state = bnb.functional.quantize_4bit(
# w,
# blocksize=params4bit.blocksize,
# compress_statistics=params4bit.compress_statistics,
# quant_type=params4bit.quant_type,
# )
# return _replace_param(params4bit, w_4bit, quant_state)
# def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
# if self.weight.dtype == torch.uint8: # was quantized
# # cannot init the quantized params directly
# weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
# else:
# weight = torch.empty_like(self.weight.data, device=device)
# device = torch.device(device)
# if device.type == "cuda": # re-quantize
# self.quantize_(weight, device)
# else:
# self.weight = _replace_param(self.weight, weight)
# if self.bias is not None:
# self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
# return self
def convert_model_to_bnb_llm_int8(model: torch.nn.Module, ignore_modules: set[str]):
    """Swap the `torch.nn.Linear` layers of `model` for `InvokeLinear8bitLt` layers, in place.

    Args:
        model: Model to convert. It is modified in place and nothing is returned.
        ignore_modules: Dotted-name prefixes of submodules whose Linear layers must be left unconverted.
    """
    _convert_linear_layers(model, InvokeLinear8bitLt, ignore_modules)
# TODO(ryand): Is this necessary?
# set the compute dtype if necessary
# for m in model.modules():
# if isinstance(m, bnb.nn.Linear4bit):
# m.compute_dtype = self.dtype
# m.compute_type_is_set = False
# class BitsandbytesPrecision(Precision):
# """Plugin for quantizing weights with `bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__.
# .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
# .. note::
# The optimizer is not automatically replaced with ``bitsandbytes.optim.Adam8bit`` or equivalent 8-bit optimizers.
# Args:
# mode: The quantization mode to use.
# dtype: The compute dtype to use.
# ignore_modules: The submodules whose Linear layers should not be replaced, for example. ``{"lm_head"}``.
# This might be desirable for numerical stability. The string will be checked in as a prefix, so a value like
# "transformer.blocks" will ignore all linear layers in all of the transformer blocks.
# """
# def __init__(
# self,
# mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"],
# dtype: Optional[torch.dtype] = None,
# ignore_modules: Optional[Set[str]] = None,
# ) -> None:
# if dtype is None:
# # try to be smart about the default selection
# if mode.startswith("int8"):
# dtype = torch.float16
# else:
# dtype = (
# torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
# )
# if mode.startswith("int8") and dtype is not torch.float16:
# # this limitation is mentioned in https://huggingface.co/blog/hf-bitsandbytes-integration#usage
# raise ValueError(f"{mode!r} only works with `dtype=torch.float16`, but you chose `{dtype}`")
# globals_ = globals()
# mode_to_cls = {
# "nf4": globals_["_NF4Linear"],
# "nf4-dq": globals_["_NF4DQLinear"],
# "fp4": globals_["_FP4Linear"],
# "fp4-dq": globals_["_FP4DQLinear"],
# "int8-training": globals_["_Linear8bitLt"],
# "int8": globals_["_Int8LinearInference"],
# }
# self._linear_cls = mode_to_cls[mode]
# self.dtype = dtype
# self.ignore_modules = ignore_modules or set()
# @override
# def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
# # avoid naive users thinking they quantized their model
# if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
# raise TypeError(
# "You are using the bitsandbytes precision plugin, but your model has no Linear layers. This plugin"
# " won't work for your model."
# )
# # convert modules if they haven't been converted already
# if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()):
# # this will not quantize the model but only replace the layer classes
# _convert_layers(module, self._linear_cls, self.ignore_modules)
# # set the compute dtype if necessary
# for m in module.modules():
# if isinstance(m, bnb.nn.Linear4bit):
# m.compute_dtype = self.dtype
# m.compute_type_is_set = False
# return module
# def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None:
# # There is only one key that ends with `*.weight`, the other one is the bias
# weight_key = next((name for name in state_dict if name.endswith("weight")), None)
# if weight_key is None:
# return
# # Load the weight from the state dict and re-quantize it
# weight = state_dict.pop(weight_key)
# quantize_fn(weight)
# def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None:
# # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false
# # positive
# for key in reversed(incompatible_keys.missing_keys):
# if key.endswith("weight"):
# incompatible_keys.missing_keys.remove(key)
def _convert_linear_layers(
    module: torch.nn.Module, linear_cls: Type, ignore_modules: Set[str], prefix: str = ""
) -> None:
    """Recursively replace the `torch.nn.Linear` submodules of `module` with `linear_cls` instances, in place.

    The replacement is built as `linear_cls(in_features, out_features, bias=..., device=...)`. It then
    receives clones of the original weight and bias.

    Args:
        module: Module whose descendants are converted. Modified in place.
        linear_cls: Class to instantiate for each converted Linear layer.
        ignore_modules: Dotted-name prefixes, relative to the root module, whose Linear layers are left
            untouched, e.g. ``{"lm_head"}`` or ``{"transformer.blocks"}``.
        prefix: Dotted name of `module` relative to the root; used internally during recursion.
    """
    # `_Linear4bit` (and its `_replace_param` helper) are optional in this module and currently commented out.
    # Look the class up lazily: indexing `globals()["_Linear4bit"]` raised a KeyError on the first Linear layer.
    linear4bit_cls = globals().get("_Linear4bit")
    is_4bit = linear4bit_cls is not None and issubclass(linear_cls, linear4bit_cls)
    # Since we are going to copy over the child's data, the device doesn't matter. CPU is used to avoid spiking
    # CUDA memory even though initialization is slower. 4-bit layers support quantizing from meta-device params,
    # so this is only relevant for 8-bit.
    device = torch.device("meta" if is_4bit else "cpu")

    for name, child in module.named_children():
        fullname = f"{prefix}.{name}" if prefix else name
        if not isinstance(child, torch.nn.Linear) or any(fullname.startswith(s) for s in ignore_modules):
            _convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname)
            continue

        has_bias = child.bias is not None
        replacement = linear_cls(
            child.in_features,
            child.out_features,
            bias=has_bias,
            device=device,
        )
        if is_4bit:
            # The 4-bit path relies on `_replace_param`, which must be defined in this module alongside
            # `_Linear4bit`.
            replace_param = globals()["_replace_param"]
            if has_bias:
                replacement.bias = replace_param(replacement.bias, child.bias.data.clone())
            replacement.weight = replace_param(
                replacement.weight, child.weight.data.clone(), quant_state=replacement.weight.quant_state
            )
        else:
            # The replacement's params live on CPU, so their data can be swapped in place. This keeps the
            # parameter objects (and their classes, e.g. `bnb.nn.Int8Params`) intact.
            if has_bias:
                replacement.bias.data = child.bias.data.clone()
            replacement.weight.data = child.weight.data.clone()
        setattr(module, name, replacement)
# def _replace_linear_layers(
# model: torch.nn.Module,
# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"],
# modules_to_not_convert: set[str],
# current_key_name: str | None = None,
# ):
# has_been_replaced = False
# for name, module in model.named_children():
# if current_key_name is None:
# current_key_name = []
# current_key_name.append(name)
# if isinstance(module, torch.nn.Linear) and name not in modules_to_not_convert:
# # Check if the current key is not in the `modules_to_not_convert`
# current_key_name_str = ".".join(current_key_name)
# proceed = True
# for key in modules_to_not_convert:
# if (
# (key in current_key_name_str) and (key + "." in current_key_name_str)
# ) or key == current_key_name_str:
# proceed = False
# break
# if proceed:
# # Load bnb module with empty weight and replace ``nn.Linear` module
# if bnb_quantization_config.load_in_8bit:
# bnb_module = bnb.nn.Linear8bitLt(
# module.in_features,
# module.out_features,
# module.bias is not None,
# has_fp16_weights=False,
# threshold=bnb_quantization_config.llm_int8_threshold,
# )
# elif bnb_quantization_config.load_in_4bit:
# bnb_module = bnb.nn.Linear4bit(
# module.in_features,
# module.out_features,
# module.bias is not None,
# bnb_quantization_config.bnb_4bit_compute_dtype,
# compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
# quant_type=bnb_quantization_config.bnb_4bit_quant_type,
# )
# else:
# raise ValueError("load_in_8bit and load_in_4bit can't be both False")
# bnb_module.weight.data = module.weight.data
# if module.bias is not None:
# bnb_module.bias.data = module.bias.data
# bnb_module.requires_grad_(False)
# setattr(model, name, bnb_module)
# has_been_replaced = True
# if len(list(module.children())) > 0:
# _, _has_been_replaced = _replace_with_bnb_layers(
# module, bnb_quantization_config, modules_to_not_convert, current_key_name
# )
# has_been_replaced = has_been_replaced | _has_been_replaced
# # Remove the last key for recursion
# current_key_name.pop(-1)
# return model, has_been_replaced