diff --git a/invokeai/backend/quantization/bnb_llm_int8.py b/invokeai/backend/quantization/bnb_llm_int8.py
index 900c55a085..f196ebc43e 100644
--- a/invokeai/backend/quantization/bnb_llm_int8.py
+++ b/invokeai/backend/quantization/bnb_llm_int8.py
@@ -11,6 +11,33 @@ import torch
 # stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
 
 
+class InvokeInt8Params(bnb.nn.Int8Params):
+    """We override cuda() to avoid re-quantizing the weights in the following cases:
+    - We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
+    - We are moving the model back-and-forth between the cpu and gpu.
+    """
+
+    def cuda(self, device):
+        if self.has_fp16_weights:
+            return super().cuda(device)
+        elif self.CB is not None and self.SCB is not None:
+            self.data = self.data.cuda()
+            self.CB = self.CB.cuda()
+            self.SCB = self.SCB.cuda()
+        else:
+            # we store the 8-bit row-major weight
+            # we convert this weight to the Turing/Ampere format during the first inference pass
+            B = self.data.contiguous().half().cuda(device)
+            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+            del CBt
+            del SCBt
+            self.data = CB
+            self.CB = CB
+            self.SCB = SCB
+
+        return self
+
+
 class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
     def _load_from_state_dict(
         self,
@@ -36,7 +63,7 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
 
         if scb is not None:
             # We are loading a pre-quantized state dict.
-            self.weight = bnb.nn.Int8Params(
+            self.weight = InvokeInt8Params(
                 data=weight,
                 requires_grad=self.weight.requires_grad,
                 has_fp16_weights=False,
@@ -53,7 +80,7 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
             #   device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
             #   implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
             #   `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
-            self.weight = bnb.nn.Int8Params(
+            self.weight = InvokeInt8Params(
                 data=weight,
                 requires_grad=self.weight.requires_grad,
                 has_fp16_weights=False,
@@ -89,10 +116,6 @@ def _convert_linear_layers_to_llm_8bit(
     )
 
 
-def get_parameter_device(parameter: torch.nn.Module):
-    return next(parameter.parameters()).device
-
-
 def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
    """Apply bitsandbytes LLM.8bit() quantization to the model."""
     _convert_linear_layers_to_llm_8bit(
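
For context, here is a minimal usage sketch (not part of the patch) of the scenario the `InvokeInt8Params.cuda()` override targets. It assumes a CUDA device is available, assumes `quantize_model_llm_int8` converts the module in place, and uses a toy model plus a hypothetical `quantized_state_dict` purely for illustration.

```python
import torch

from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8

# Toy model for illustration; any torch.nn.Module containing Linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))

# Replace eligible Linear layers with the LLM.int8() wrappers defined in this file
# (assumed here to convert the module in place).
quantize_model_llm_int8(model, modules_to_not_convert=set())

# Loading a pre-quantized state dict on the CPU populates weight.CB / weight.SCB.
# `quantized_state_dict` is a placeholder for a state dict saved from an already-quantized model.
# model.load_state_dict(quantized_state_dict)

# Moving to the GPU (bitsandbytes is CUDA-only): if a pre-quantized state dict was loaded
# (weight.CB / weight.SCB set), the override moves the int8 tensors as-is instead of
# re-quantizing; otherwise the fp16 weights are quantized on this first transfer.
model.cuda()
```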