More improvements for LLM.int8() - not fully tested.

This commit is contained in:
Ryan Dick 2024-08-15 19:59:31 +00:00 committed by Brandon
parent f01f56a98e
commit d3a5ca5247

View File

@ -11,6 +11,33 @@ import torch
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeInt8Params(bnb.nn.Int8Params):
"""We override cuda() to avoid re-quantizing the weights in the following cases:
- We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
- We are moving the model back-and-forth between the cpu and gpu.
"""
def cuda(self, device):
if self.has_fp16_weights:
return super().cuda(device)
elif self.CB is not None and self.SCB is not None:
self.data = self.data.cuda()
self.CB = self.CB.cuda()
self.SCB = self.SCB.cuda()
else:
# we store the 8-bit rows-major weight
# we convert this weight to the turning/ampere weight during the first inference pass
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
self.data = CB
self.CB = CB
self.SCB = SCB
return self
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
def _load_from_state_dict(
self,
@ -36,7 +63,7 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
if scb is not None:
# We are loading a pre-quantized state dict.
self.weight = bnb.nn.Int8Params(
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
@ -53,7 +80,7 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = bnb.nn.Int8Params(
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
@ -89,10 +116,6 @@ def _convert_linear_layers_to_llm_8bit(
)
def get_parameter_device(parameter: torch.nn.Module):
return next(parameter.parameters()).device
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
_convert_linear_layers_to_llm_8bit(