From 96b0450b20865f0a0651171b4aa68795ac05aca3 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 14 Aug 2024 14:47:03 +0000
Subject: [PATCH] NF4 loading working... I think.

---
 invokeai/backend/bnb.py                     | 38 ++++++++++++++++++++-
 invokeai/backend/load_flux_model_bnb_nf4.py |  5 +--
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/invokeai/backend/bnb.py b/invokeai/backend/bnb.py
index 766de08f6d..d0cb6f7c99 100644
--- a/invokeai/backend/bnb.py
+++ b/invokeai/backend/bnb.py
@@ -51,6 +51,42 @@ import torch
 #         self.SCB = SCB
 
 
+class InvokeLinear4Bit(bnb.nn.Linear4bit):
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        """This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
+        https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
+
+        I'm not sure why this was not included in the original `Linear4bit` implementation.
+        """
+        # During serialization, the quant_state is stored as subkeys of "weight.". Here we extract those keys.
+        quant_state_keys = [k for k in state_dict.keys() if k.startswith(prefix + "weight.")]
+
+        if len(quant_state_keys) > 0:
+            # We are loading a quantized state dict.
+            quant_state_sd = {k: state_dict.pop(k) for k in quant_state_keys}
+            weight = state_dict.pop(prefix + "weight")
+            bias = state_dict.pop(prefix + "bias", None)
+
+            if len(state_dict) != 0:
+                raise RuntimeError(f"Unexpected keys in state_dict: {state_dict.keys()}")
+
+            self.weight = bnb.nn.Params4bit.from_prequantized(
+                data=weight, quantized_stats=quant_state_sd, device=weight.device
+            )
+            if bias is None:
+                self.bias = None
+            else:
+                self.bias = torch.nn.Parameter(bias)
+
+        else:
+            # We are loading a non-quantized state dict.
+            return super()._load_from_state_dict(
+                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+            )
+
+
 class Invoke2Linear8bitLt(torch.nn.Linear):
     """This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm."""
 
@@ -509,7 +545,7 @@ def _convert_linear_layers_to_nf4(
         fullname = f"{prefix}.{name}" if prefix else name
         if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
             has_bias = child.bias is not None
-            replacement = bnb.nn.Linear4bit(
+            replacement = InvokeLinear4Bit(
                 child.in_features,
                 child.out_features,
                 bias=has_bias,
diff --git a/invokeai/backend/load_flux_model_bnb_nf4.py b/invokeai/backend/load_flux_model_bnb_nf4.py
index 1629c5a01c..1a4e67c1c7 100644
--- a/invokeai/backend/load_flux_model_bnb_nf4.py
+++ b/invokeai/backend/load_flux_model_bnb_nf4.py
@@ -4,7 +4,7 @@ from pathlib import Path
 import accelerate
 import torch
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from safetensors.torch import load_file, save_file
+from safetensors.torch import load_file
 
 from invokeai.backend.bnb import quantize_model_nf4
 
@@ -43,6 +43,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
         model.to_empty(device="cpu")
         sd = load_file(model_nf4_path / "model.safetensors")
         model.load_state_dict(sd, strict=True)
+        model = model.to("cuda")
 
     else:
         # The quantized model does not exist yet, quantize and save it.
@@ -79,7 +80,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
         model = model.to("cuda")
 
         model_nf4_path.mkdir(parents=True, exist_ok=True)
-        save_file(model.state_dict(), model_nf4_path / "model.safetensors")
+        # save_file(model.state_dict(), model_nf4_path / "model.safetensors")
 
 
 # ---------------------
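
For reference, here is a minimal round-trip sketch (not part of the patch) of the load path that `InvokeLinear4Bit._load_from_state_dict` enables. It assumes bitsandbytes and a CUDA device are available and that `InvokeLinear4Bit` is importable from `invokeai.backend.bnb` as added in the diff above; the layer sizes are arbitrary.

```python
import torch

from invokeai.backend.bnb import InvokeLinear4Bit  # class added by this patch

# Build an NF4 linear layer; moving it to CUDA triggers the 4-bit quantization.
src = InvokeLinear4Bit(64, 32, bias=True, compute_dtype=torch.float16, quant_type="nf4")
src = src.to("cuda")

# The serialized state dict holds the packed uint8 weight plus the quant_state
# subkeys under "weight." (e.g. "weight.absmax"), which is exactly what the
# patched _load_from_state_dict looks for.
sd = src.state_dict()
assert any(k.startswith("weight.") for k in sd)

# Reload into a fresh layer: the "weight.*" keys route the load through
# bnb.nn.Params4bit.from_prequantized() instead of the default path.
dst = InvokeLinear4Bit(64, 32, bias=True, compute_dtype=torch.float16, quant_type="nf4")
dst.load_state_dict(sd, strict=True)
dst = dst.to("cuda")

x = torch.randn(1, 64, dtype=torch.float16, device="cuda")
torch.testing.assert_close(src(x), dst(x))  # both layers use the same quantized weights
```

Because `Linear4bit` serializes its quant_state as `weight.*` subkeys, detecting those keys is sufficient to choose between the prequantized path and falling back to the stock `Linear4bit` load behavior for non-quantized checkpoints.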