mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
NF4 loading working... I think.
This commit is contained in:
parent
b63df9bab9
commit
5c2f95ef50
@ -51,6 +51,42 @@ import torch
|
||||
# self.SCB = SCB
|
||||
|
||||
|
||||
class InvokeLinear4Bit(bnb.nn.Linear4bit):
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
|
||||
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
|
||||
|
||||
I'm not sure why this was not included in the original `Linear4bit` implementation.
|
||||
"""
|
||||
# During serialization, the quant_state is stored as subkeys of "weight.". Here we extract those keys.
|
||||
quant_state_keys = [k for k in state_dict.keys() if k.startswith(prefix + "weight.")]
|
||||
|
||||
if len(quant_state_keys) > 0:
|
||||
# We are loading a quantized state dict.
|
||||
quant_state_sd = {k: state_dict.pop(k) for k in quant_state_keys}
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
|
||||
if len(state_dict) != 0:
|
||||
raise RuntimeError(f"Unexpected keys in state_dict: {state_dict.keys()}")
|
||||
|
||||
self.weight = bnb.nn.Params4bit.from_prequantized(
|
||||
data=weight, quantized_stats=quant_state_sd, device=weight.device
|
||||
)
|
||||
if bias is None:
|
||||
self.bias = None
|
||||
else:
|
||||
self.bias = torch.nn.Parameter(bias)
|
||||
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
return super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
||||
|
||||
class Invoke2Linear8bitLt(torch.nn.Linear):
|
||||
"""This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm."""
|
||||
|
||||
@ -509,7 +545,7 @@ def _convert_linear_layers_to_nf4(
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = bnb.nn.Linear4bit(
|
||||
replacement = InvokeLinear4Bit(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
|
@ -4,7 +4,7 @@ from pathlib import Path
|
||||
import accelerate
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.backend.bnb import quantize_model_nf4
|
||||
|
||||
@ -43,6 +43,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
model.to_empty(device="cpu")
|
||||
sd = load_file(model_nf4_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True)
|
||||
model = model.to("cuda")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist yet, quantize and save it.
|
||||
@ -79,7 +80,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
model = model.to("cuda")
|
||||
|
||||
model_nf4_path.mkdir(parents=True, exist_ok=True)
|
||||
save_file(model.state_dict(), model_nf4_path / "model.safetensors")
|
||||
# save_file(model.state_dict(), model_nf4_path / "model.safetensors")
|
||||
|
||||
# ---------------------
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user