NF4 inference working

Ryan Dick
2024-08-14 23:30:53 +00:00
committed by Brandon
parent 5c2f95ef50
commit e1eb104345
3 changed files with 83 additions and 49 deletions
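
For context, NF4 ("normal float 4") is the 4-bit quantization format provided by bitsandbytes. The following is a minimal, standalone sketch of NF4 inference with bitsandbytes, not code from this commit; the layer sizes and the use of a plain Linear as the weight source are illustrative only:

    import bitsandbytes as bnb
    import torch

    # Start from an ordinary linear layer and wrap its weights in an NF4-quantized layer.
    source = torch.nn.Linear(64, 32, bias=True)
    nf4_layer = bnb.nn.LinearNF4(64, 32, bias=True, compute_dtype=torch.bfloat16)
    nf4_layer.weight = bnb.nn.Params4bit(data=source.weight.data, requires_grad=False, quant_type="nf4")
    nf4_layer.bias = torch.nn.Parameter(source.bias.data, requires_grad=False)

    # Moving the layer to CUDA is what triggers the actual 4-bit quantization of the weight.
    nf4_layer = nf4_layer.to("cuda")

    x = torch.randn(1, 64, dtype=torch.bfloat16, device="cuda")
    y = nf4_layer(x)  # inference against the NF4-quantized weight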


@@ -51,7 +51,7 @@ import torch
         # self.SCB = SCB
-class InvokeLinear4Bit(bnb.nn.Linear4bit):
+class InvokeLinearNF4(bnb.nn.LinearNF4):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
@@ -60,31 +60,36 @@ class InvokeLinear4Bit(bnb.nn.Linear4bit):
         I'm not sure why this was not included in the original `Linear4bit` implementation.
         """
-        # During serialization, the quant_state is stored as subkeys of "weight.". Here we extract those keys.
-        quant_state_keys = [k for k in state_dict.keys() if k.startswith(prefix + "weight.")]
-        if len(quant_state_keys) > 0:
+        weight = state_dict.pop(prefix + "weight")
+        bias = state_dict.pop(prefix + "bias", None)
+        # During serialization, the quant_state is stored as subkeys of "weight.".
+        # We expect the remaining keys to be quant_state keys. We validate that they at least have the correct prefix.
+        quant_state_sd = state_dict
+        assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
+        if len(quant_state_sd) > 0:
             # We are loading a quantized state dict.
-            quant_state_sd = {k: state_dict.pop(k) for k in quant_state_keys}
-            weight = state_dict.pop(prefix + "weight")
-            bias = state_dict.pop(prefix + "bias", None)
-            if len(state_dict) != 0:
-                raise RuntimeError(f"Unexpected keys in state_dict: {state_dict.keys()}")
             self.weight = bnb.nn.Params4bit.from_prequantized(
                 data=weight, quantized_stats=quant_state_sd, device=weight.device
             )
-            if bias is None:
-                self.bias = None
-            else:
-                self.bias = torch.nn.Parameter(bias)
+            self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
         else:
             # We are loading a non-quantized state dict.
-            return super()._load_from_state_dict(
-                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+            # We could simply call the `super()._load_from_state_dict` method here, but then we wouldn't be able to
+            # load from a state_dict into a model on the "meta" device. By initializing a new `Params4bit` object, we
+            # work around this issue.
+            self.weight = bnb.nn.Params4bit(
+                data=weight,
+                requires_grad=self.weight.requires_grad,
+                compress_statistics=self.weight.compress_statistics,
+                quant_type=self.weight.quant_type,
+                quant_storage=self.weight.quant_storage,
+                module=self,
             )
+            self.bias = bias if bias is None else torch.nn.Parameter(bias)

 class Invoke2Linear8bitLt(torch.nn.Linear):
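
The serialization format that the quantized branch above relies on can be seen with a small round trip. This is a hedged sketch, not repository code; the exact "weight.*" key names depend on the bitsandbytes version:

    import bitsandbytes as bnb
    import torch

    layer = bnb.nn.LinearNF4(64, 32, bias=False)
    layer = layer.to("cuda")  # quantizes the weight to NF4

    sd = layer.state_dict()
    # A quantized layer stores its quant_state as sub-keys of "weight.",
    # e.g. "weight.absmax" and "weight.quant_state.bitsandbytes__nf4" (names may vary).
    quant_state_sd = {k: v for k, v in sd.items() if k.startswith("weight.")}

    # Rebuild the quantized parameter from the serialized pieces, mirroring the quantized
    # branch of InvokeLinearNF4._load_from_state_dict.
    restored = bnb.nn.Params4bit.from_prequantized(
        data=sd["weight"], quantized_stats=quant_state_sd, device="cuda"
    )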
@@ -545,7 +550,7 @@ def _convert_linear_layers_to_nf4(
         fullname = f"{prefix}.{name}" if prefix else name
         if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
             has_bias = child.bias is not None
-            replacement = InvokeLinear4Bit(
+            replacement = InvokeLinearNF4(
                 child.in_features,
                 child.out_features,
                 bias=has_bias,
@@ -553,9 +558,14 @@ def _convert_linear_layers_to_nf4(
                 # TODO(ryand): Test compress_statistics=True.
                 # compress_statistics=True,
             )
-            replacement.weight.data = child.weight.data
+            # replacement.weight.data = child.weight.data
+            # if has_bias:
+            #     replacement.bias.data = child.bias.data
             if has_bias:
-                replacement.bias.data = child.bias.data
+                replacement.bias = _replace_param(replacement.bias, child.bias.data)
+            replacement.weight = _replace_param(
+                replacement.weight, child.weight.data, quant_state=replacement.weight.quant_state
+            )
             replacement.requires_grad_(False)
             module.__setattr__(name, replacement)
         else:
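
`_replace_param` is referenced above but not defined in this diff. The following is a hypothetical sketch of what such a helper would need to do (the real helper lives elsewhere in the repo and may differ; similar helpers exist in Lightning Fabric's bitsandbytes integration): update a parameter's tensor in place when possible, and re-create the parameter when it still lives on the "meta" device, where in-place assignment is not allowed.

    import bitsandbytes as bnb
    import torch

    def _replace_param(param, data, quant_state=None):
        # Hypothetical sketch of the helper used above.
        # Assigning `param.data = data` raises a RuntimeError when the parameter still lives on
        # the "meta" device, so in that case the parameter is re-created instead of updated.
        if param.device.type == "meta":
            if isinstance(param, bnb.nn.Params4bit):
                return bnb.nn.Params4bit(
                    data,
                    requires_grad=data.requires_grad,
                    quant_state=quant_state,
                    quant_type=param.quant_type,
                    compress_statistics=param.compress_statistics,
                )
            return torch.nn.Parameter(data, requires_grad=data.requires_grad)
        param.data = data
        if isinstance(param, bnb.nn.Params4bit):
            param.quant_state = quant_state
        return param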


@@ -4,7 +4,7 @@ from pathlib import Path
 import accelerate
 import torch
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from safetensors.torch import load_file
+from safetensors.torch import load_file, save_file

 from invokeai.backend.bnb import quantize_model_nf4
@@ -62,6 +62,9 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
         # ---------------------

+        with accelerate.init_empty_weights():
+            model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+
         # Load sharded state dict.
         files = list(path.glob("*.safetensors"))
         state_dict = dict()
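
`accelerate.init_empty_weights()` builds the module with all of its parameters on the "meta" device, so quantizing the empty model allocates no real weight memory. A minimal standalone illustration (sizes are arbitrary):

    import accelerate
    import torch

    with accelerate.init_empty_weights():
        m = torch.nn.Linear(4096, 4096)

    print(m.weight.device)  # meta: shape and dtype only, no storage allocated
    # Real tensors are attached later, e.g. with m.load_state_dict(sd, assign=True),
    # which is why _load_from_state_dict earlier in this commit must handle meta parameters.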
@@ -69,8 +72,9 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
            sd = load_file(file)
            state_dict.update(sd)
-        empty_model.load_state_dict(state_dict, strict=True, assign=True)
-        model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
         # model.to_empty(device="cpu")
         # model.to(dtype=torch.float16)
+        model.load_state_dict(state_dict, strict=True, assign=True)
+        # Load the state dict into the model. The bitsandbytes layers know how to load from both quantized and
+        # non-quantized state dicts.
@@ -80,7 +84,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
         model = model.to("cuda")

         model_nf4_path.mkdir(parents=True, exist_ok=True)
-        # save_file(model.state_dict(), model_nf4_path / "model.safetensors")
+        save_file(model.state_dict(), model_nf4_path / "model.safetensors")

         # ---------------------
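
Once the quantized state dict has been written with `save_file(...)` above, the same pattern should allow reloading it directly into a quantized, meta-initialized model. A hedged sketch; the helper name and checkpoint path argument are illustrative, not part of this commit:

    import accelerate
    import torch
    from safetensors.torch import load_file

    from invokeai.backend.bnb import quantize_model_nf4

    def load_nf4_checkpoint(build_empty_model, checkpoint_path):
        # Build the model skeleton on the meta device and convert its Linear layers to NF4.
        with accelerate.init_empty_weights():
            model = build_empty_model()
            model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)

        # The patched InvokeLinearNF4 layers rebuild their Params4bit weights from the
        # "weight.*" quant_state keys, so the quantized state dict loads directly.
        sd = load_file(checkpoint_path)
        model.load_state_dict(sd, strict=True, assign=True)
        return model.to("cuda")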