NF4 loading working... I think.
commit 96b0450b20
parent 45792cc152
@@ -51,6 +51,42 @@ import torch
         # self.SCB = SCB
 
+
+class InvokeLinear4Bit(bnb.nn.Linear4bit):
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        """This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
+        https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
+
+        I'm not sure why this was not included in the original `Linear4bit` implementation.
+        """
+        # During serialization, the quant_state is stored as subkeys of "weight.". Here we extract those keys.
+        quant_state_keys = [k for k in state_dict.keys() if k.startswith(prefix + "weight.")]
+
+        if len(quant_state_keys) > 0:
+            # We are loading a quantized state dict.
+            quant_state_sd = {k: state_dict.pop(k) for k in quant_state_keys}
+            weight = state_dict.pop(prefix + "weight")
+            bias = state_dict.pop(prefix + "bias", None)
+
+            if len(state_dict) != 0:
+                raise RuntimeError(f"Unexpected keys in state_dict: {state_dict.keys()}")
+
+            self.weight = bnb.nn.Params4bit.from_prequantized(
+                data=weight, quantized_stats=quant_state_sd, device=weight.device
+            )
+            if bias is None:
+                self.bias = None
+            else:
+                self.bias = torch.nn.Parameter(bias)
+
+        else:
+            # We are loading a non-quantized state dict.
+            return super()._load_from_state_dict(
+                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+            )
+
+
 class Invoke2Linear8bitLt(torch.nn.Linear):
     """This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm."""
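A quick round-trip sanity check for this loader, sketched under assumptions: the layer sizes and the `quant_type` argument are illustrative, and a CUDA device is assumed, since bitsandbytes quantizes `Linear4bit` weights when the module is moved to the GPU.

import bitsandbytes as bnb
import torch

# Quantize a small Linear4bit layer; quantization happens on the move to CUDA.
src = bnb.nn.Linear4bit(64, 64, bias=True, quant_type="nf4").to("cuda")

# The state dict now holds the packed "weight" plus quant-state subkeys
# (e.g. "weight.absmax"), which is exactly what _load_from_state_dict expects.
sd = src.state_dict()

# Loading routes through InvokeLinear4Bit._load_from_state_dict above.
dst = InvokeLinear4Bit(64, 64, bias=True, quant_type="nf4")
dst.load_state_dict(sd)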
@@ -509,7 +545,7 @@ def _convert_linear_layers_to_nf4(
         fullname = f"{prefix}.{name}" if prefix else name
         if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
             has_bias = child.bias is not None
-            replacement = bnb.nn.Linear4bit(
+            replacement = InvokeLinear4Bit(
                 child.in_features,
                 child.out_features,
                 bias=has_bias,
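For context, the replacement loop this hunk sits in, sketched under assumptions: only the lines above are visible in the diff, so the recursion, the in-place setattr swap, and the `quant_type` argument are guesses about the surrounding function body.

def _convert_linear_layers_to_nf4_sketch(
    module: torch.nn.Module, ignore_modules: set[str], prefix: str = ""
) -> None:
    for name, child in module.named_children():
        fullname = f"{prefix}.{name}" if prefix else name
        if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
            has_bias = child.bias is not None
            replacement = InvokeLinear4Bit(
                child.in_features,
                child.out_features,
                bias=has_bias,
                quant_type="nf4",
            )
            # Swap the Linear layer for its quantized counterpart in place.
            setattr(module, name, replacement)
        else:
            # Recurse so Linear layers nested inside sub-blocks are converted too.
            _convert_linear_layers_to_nf4_sketch(child, ignore_modules, prefix=fullname)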
@@ -4,7 +4,7 @@ from pathlib import Path
 import accelerate
 import torch
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from safetensors.torch import load_file, save_file
+from safetensors.torch import load_file
 
 from invokeai.backend.bnb import quantize_model_nf4
 
@@ -43,6 +43,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
         model.to_empty(device="cpu")
         sd = load_file(model_nf4_path / "model.safetensors")
         model.load_state_dict(sd, strict=True)
+        model = model.to("cuda")
 
     else:
         # The quantized model does not exist yet, quantize and save it.
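The load branch above relies on the meta-device pattern: build the module skeleton without allocating weights, then materialize storage and fill it from disk. A sketch of the assumed flow; `config` and the exact `quantize_model_nf4` call are not shown in this diff and are assumptions.

with accelerate.init_empty_weights():
    # Parameters live on the "meta" device: shapes only, no storage.
    model = FluxTransformer2DModel.from_config(config)
    model = quantize_model_nf4(model)  # assumed call; swaps Linear layers for InvokeLinear4Bit

model.to_empty(device="cpu")  # allocate real (uninitialized) CPU tensors
sd = load_file(model_nf4_path / "model.safetensors")
model.load_state_dict(sd, strict=True)  # strict load overwrites every uninitialized tensor
model = model.to("cuda")  # the line added by this commit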
@@ -79,7 +80,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
         model = model.to("cuda")
 
         model_nf4_path.mkdir(parents=True, exist_ok=True)
-        save_file(model.state_dict(), model_nf4_path / "model.safetensors")
+        # save_file(model.state_dict(), model_nf4_path / "model.safetensors")
 
 # ---------------------
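Once the commented-out save is re-enabled, the checkpoint stores, for each converted layer, the packed weight plus its quant-state subkeys. A sketch of what that looks like; the key names are illustrative of bitsandbytes NF4 serialization, not verified against this exact model.

from safetensors.torch import save_file

sd = model.state_dict()
# Example keys contributed by one converted layer (illustrative):
#   "transformer_blocks.0.attn.to_q.weight"                                <- packed 4-bit data
#   "transformer_blocks.0.attn.to_q.weight.absmax"                         <- per-block scales
#   "transformer_blocks.0.attn.to_q.weight.quant_map"                      <- NF4 codebook
#   "transformer_blocks.0.attn.to_q.weight.quant_state.bitsandbytes__nf4"  <- packed metadata
save_file(sd, model_nf4_path / "model.safetensors")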