Fix bug in InvokeLinear8bitLt that was causing old state information to persist after loading from a state dict. This manifested as state tensors being left on the GPU even when a model had been offloaded to the CPU cache.

This commit is contained in:
Ryan Dick 2024-08-28 14:06:29 +00:00
parent 77090070bd
commit 29fe1533f2

View File

@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format. # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None) scb = state_dict.pop(prefix + "SCB", None)
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
_weight_format = state_dict.pop(prefix + "weight_format", None) # Currently, we only support weight_format=0.
weight_format = state_dict.pop(prefix + "weight_format", None)
assert weight_format == 0
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs` # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API. # rather than raising an exception to correctly implement this API.
@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
) )
self.bias = bias if bias is None else torch.nn.Parameter(bias) self.bias = bias if bias is None else torch.nn.Parameter(bias)
# Reset the state. The persisted fields are based on the initialization behaviour in
# `bnb.nn.Linear8bitLt.__init__()`.
new_state = bnb.MatmulLtState()
new_state.threshold = self.state.threshold
new_state.has_fp16_weights = False
new_state.use_pool = self.state.use_pool
self.state = new_state
def _convert_linear_layers_to_llm_8bit( def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = "" module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""