NF4 loading working... I think.

This commit is contained in:
Ryan Dick 2024-08-14 14:47:03 +00:00 committed by Brandon
parent b63df9bab9
commit 5c2f95ef50
2 changed files with 40 additions and 3 deletions

View File

@ -51,6 +51,42 @@ import torch
# self.SCB = SCB # self.SCB = SCB
class InvokeLinear4Bit(bnb.nn.Linear4bit):
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
I'm not sure why this was not included in the original `Linear4bit` implementation.
"""
# During serialization, the quant_state is stored as subkeys of "weight.". Here we extract those keys.
quant_state_keys = [k for k in state_dict.keys() if k.startswith(prefix + "weight.")]
if len(quant_state_keys) > 0:
# We are loading a quantized state dict.
quant_state_sd = {k: state_dict.pop(k) for k in quant_state_keys}
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
if len(state_dict) != 0:
raise RuntimeError(f"Unexpected keys in state_dict: {state_dict.keys()}")
self.weight = bnb.nn.Params4bit.from_prequantized(
data=weight, quantized_stats=quant_state_sd, device=weight.device
)
if bias is None:
self.bias = None
else:
self.bias = torch.nn.Parameter(bias)
else:
# We are loading a non-quantized state dict.
return super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
class Invoke2Linear8bitLt(torch.nn.Linear): class Invoke2Linear8bitLt(torch.nn.Linear):
"""This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.""" """This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm."""
@ -509,7 +545,7 @@ def _convert_linear_layers_to_nf4(
fullname = f"{prefix}.{name}" if prefix else name fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None has_bias = child.bias is not None
replacement = bnb.nn.Linear4bit( replacement = InvokeLinear4Bit(
child.in_features, child.in_features,
child.out_features, child.out_features,
bias=has_bias, bias=has_bias,

View File

@ -4,7 +4,7 @@ from pathlib import Path
import accelerate import accelerate
import torch import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from safetensors.torch import load_file, save_file from safetensors.torch import load_file
from invokeai.backend.bnb import quantize_model_nf4 from invokeai.backend.bnb import quantize_model_nf4
@ -43,6 +43,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
model.to_empty(device="cpu") model.to_empty(device="cpu")
sd = load_file(model_nf4_path / "model.safetensors") sd = load_file(model_nf4_path / "model.safetensors")
model.load_state_dict(sd, strict=True) model.load_state_dict(sd, strict=True)
model = model.to("cuda")
else: else:
# The quantized model does not exist yet, quantize and save it. # The quantized model does not exist yet, quantize and save it.
@ -79,7 +80,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
model = model.to("cuda") model = model.to("cuda")
model_nf4_path.mkdir(parents=True, exist_ok=True) model_nf4_path.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_nf4_path / "model.safetensors") # save_file(model.state_dict(), model_nf4_path / "model.safetensors")
# --------------------- # ---------------------