From 29fe1533f290cf1c23021aa992c64cd4a7012423 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 28 Aug 2024 14:06:29 +0000
Subject: [PATCH] Fix bug in InvokeLinear8bitLt that was causing old state
 information to persist after loading from a state dict.

This manifested as state tensors being left on the GPU even when a model
had been offloaded to the CPU cache.
---
 invokeai/backend/quantization/bnb_llm_int8.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/invokeai/backend/quantization/bnb_llm_int8.py b/invokeai/backend/quantization/bnb_llm_int8.py
index b92717cbc5..02f94936e9 100644
--- a/invokeai/backend/quantization/bnb_llm_int8.py
+++ b/invokeai/backend/quantization/bnb_llm_int8.py
@@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
 
         # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
         scb = state_dict.pop(prefix + "SCB", None)
-        # weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
-        _weight_format = state_dict.pop(prefix + "weight_format", None)
+
+        # Currently, we only support weight_format=0.
+        weight_format = state_dict.pop(prefix + "weight_format", None)
+        assert weight_format == 0
 
         # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
         # rather than raising an exception to correctly implement this API.
@@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
         )
         self.bias = bias if bias is None else torch.nn.Parameter(bias)
 
+        # Reset the state. The persisted fields are based on the initialization behaviour in
+        # `bnb.nn.Linear8bitLt.__init__()`.
+        new_state = bnb.MatmulLtState()
+        new_state.threshold = self.state.threshold
+        new_state.has_fp16_weights = False
+        new_state.use_pool = self.state.use_pool
+        self.state = new_state
+
 
 def _convert_linear_layers_to_llm_8bit(
     module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
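
Commentary (not part of the patch; `git am` ignores everything outside the
diff): a rough, untested sketch of how the bug manifested, assuming a CUDA
device and bitsandbytes installed. The layer sizes are illustrative; only
the class name and module path are taken from the patch above.

    import torch
    from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt

    # Quantize on the GPU; the first forward pass fills layer.state
    # (a bnb.MatmulLtState) with CUDA tensors.
    layer = InvokeLinear8bitLt(32, 64, has_fp16_weights=False)
    layer.cuda()
    _ = layer(torch.randn(1, 32, dtype=torch.float16, device="cuda"))

    # Round-trip through a CPU state dict, as the model cache does on offload.
    cpu_state_dict = {k: v.cpu() for k, v in layer.state_dict().items()}
    layer.load_state_dict(cpu_state_dict)  # runs the patched _load_from_state_dict()

    # Before this patch, layer.state still referenced the CUDA tensors from the
    # earlier forward pass, so GPU memory was not released after offload. With
    # the reset, layer.state is a fresh MatmulLtState that carries over only
    # threshold and use_pool (has_fp16_weights is pinned to False).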