mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
More improvements for LLM.int8() - not fully tested.
This commit is contained in:
parent
f01f56a98e
commit
d3a5ca5247
@ -11,6 +11,33 @@ import torch
|
|||||||
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeInt8Params(bnb.nn.Int8Params):
|
||||||
|
"""We override cuda() to avoid re-quantizing the weights in the following cases:
|
||||||
|
- We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
|
||||||
|
- We are moving the model back-and-forth between the cpu and gpu.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def cuda(self, device):
|
||||||
|
if self.has_fp16_weights:
|
||||||
|
return super().cuda(device)
|
||||||
|
elif self.CB is not None and self.SCB is not None:
|
||||||
|
self.data = self.data.cuda()
|
||||||
|
self.CB = self.CB.cuda()
|
||||||
|
self.SCB = self.SCB.cuda()
|
||||||
|
else:
|
||||||
|
# we store the 8-bit rows-major weight
|
||||||
|
# we convert this weight to the turning/ampere weight during the first inference pass
|
||||||
|
B = self.data.contiguous().half().cuda(device)
|
||||||
|
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||||
|
del CBt
|
||||||
|
del SCBt
|
||||||
|
self.data = CB
|
||||||
|
self.CB = CB
|
||||||
|
self.SCB = SCB
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||||
def _load_from_state_dict(
|
def _load_from_state_dict(
|
||||||
self,
|
self,
|
||||||
@ -36,7 +63,7 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
|||||||
|
|
||||||
if scb is not None:
|
if scb is not None:
|
||||||
# We are loading a pre-quantized state dict.
|
# We are loading a pre-quantized state dict.
|
||||||
self.weight = bnb.nn.Int8Params(
|
self.weight = InvokeInt8Params(
|
||||||
data=weight,
|
data=weight,
|
||||||
requires_grad=self.weight.requires_grad,
|
requires_grad=self.weight.requires_grad,
|
||||||
has_fp16_weights=False,
|
has_fp16_weights=False,
|
||||||
@ -53,7 +80,7 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
|||||||
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||||
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||||
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||||
self.weight = bnb.nn.Int8Params(
|
self.weight = InvokeInt8Params(
|
||||||
data=weight,
|
data=weight,
|
||||||
requires_grad=self.weight.requires_grad,
|
requires_grad=self.weight.requires_grad,
|
||||||
has_fp16_weights=False,
|
has_fp16_weights=False,
|
||||||
@ -89,10 +116,6 @@ def _convert_linear_layers_to_llm_8bit(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_device(parameter: torch.nn.Module):
|
|
||||||
return next(parameter.parameters()).device
|
|
||||||
|
|
||||||
|
|
||||||
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
|
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
|
||||||
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
|
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
|
||||||
_convert_linear_layers_to_llm_8bit(
|
_convert_linear_layers_to_llm_8bit(
|
||||||
|
Loading…
Reference in New Issue
Block a user