mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix bug in InvokeLinear8bitLt that was causing old state information to persist after loading from a state dict. This manifested as state tensors being left on the GPU even when a model had been offloaded to the CPU cache.
This commit is contained in:
parent
77090070bd
commit
29fe1533f2
@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
|||||||
|
|
||||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||||
scb = state_dict.pop(prefix + "SCB", None)
|
scb = state_dict.pop(prefix + "SCB", None)
|
||||||
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
|
|
||||||
_weight_format = state_dict.pop(prefix + "weight_format", None)
|
# Currently, we only support weight_format=0.
|
||||||
|
weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||||
|
assert weight_format == 0
|
||||||
|
|
||||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||||
# rather than raising an exception to correctly implement this API.
|
# rather than raising an exception to correctly implement this API.
|
||||||
@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
|||||||
)
|
)
|
||||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||||
|
|
||||||
|
# Reset the state. The persisted fields are based on the initialization behaviour in
|
||||||
|
# `bnb.nn.Linear8bitLt.__init__()`.
|
||||||
|
new_state = bnb.MatmulLtState()
|
||||||
|
new_state.threshold = self.state.threshold
|
||||||
|
new_state.has_fp16_weights = False
|
||||||
|
new_state.use_pool = self.state.use_pool
|
||||||
|
self.state = new_state
|
||||||
|
|
||||||
|
|
||||||
def _convert_linear_layers_to_llm_8bit(
|
def _convert_linear_layers_to_llm_8bit(
|
||||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||||
|
Loading…
Reference in New Issue
Block a user