mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Clean up NF4 implementation.
This commit is contained in:
parent
e1eb104345
commit
99b0f79784
@ -21,7 +21,7 @@ from invokeai.app.invocations.fields import (
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.bnb import quantize_model_nf4
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional, Set, Tuple, Type
|
||||
from typing import Any, Optional, Set, Type
|
||||
|
||||
import accelerate
|
||||
import bitsandbytes as bnb
|
||||
@ -51,47 +51,6 @@ import torch
|
||||
# self.SCB = SCB
|
||||
|
||||
|
||||
class InvokeLinearNF4(bnb.nn.LinearNF4):
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
|
||||
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
|
||||
|
||||
I'm not sure why this was not included in the original `Linear4bit` implementation.
|
||||
"""
|
||||
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
# During serialization, the quant_state is stored as subkeys of "weight.".
|
||||
# We expect the remaining keys to be quant_state keys. We validate that they at least have the correct prefix.
|
||||
quant_state_sd = state_dict
|
||||
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
|
||||
|
||||
if len(quant_state_sd) > 0:
|
||||
# We are loading a quantized state dict.
|
||||
self.weight = bnb.nn.Params4bit.from_prequantized(
|
||||
data=weight, quantized_stats=quant_state_sd, device=weight.device
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
|
||||
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict` method here, but then we wouldn't be able to load
|
||||
# into from a state_dict into a model on the "meta" device. By initializing a new `Params4bit` object, we
|
||||
# work around this issue.
|
||||
self.weight = bnb.nn.Params4bit(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
compress_statistics=self.weight.compress_statistics,
|
||||
quant_type=self.weight.quant_type,
|
||||
quant_storage=self.weight.quant_storage,
|
||||
module=self,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
|
||||
class Invoke2Linear8bitLt(torch.nn.Linear):
|
||||
"""This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm."""
|
||||
|
||||
@ -474,27 +433,6 @@ def convert_model_to_bnb_llm_int8(model: torch.nn.Module, ignore_modules: set[st
|
||||
# incompatible_keys.missing_keys.remove(key)
|
||||
|
||||
|
||||
def _replace_param(
|
||||
param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[Tuple] = None
|
||||
) -> torch.nn.Parameter:
|
||||
# doing `param.data = weight` raises a RuntimeError if param.data was on meta-device, so
|
||||
# we need to re-create the parameters instead of overwriting the data
|
||||
if param.device.type == "meta":
|
||||
if isinstance(param, bnb.nn.Params4bit):
|
||||
return bnb.nn.Params4bit(
|
||||
data,
|
||||
requires_grad=data.requires_grad,
|
||||
quant_state=quant_state,
|
||||
compress_statistics=param.compress_statistics,
|
||||
quant_type=param.quant_type,
|
||||
)
|
||||
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
|
||||
param.data = data
|
||||
if isinstance(param, bnb.nn.Params4bit):
|
||||
param.quant_state = quant_state
|
||||
return param
|
||||
|
||||
|
||||
def _convert_linear_layers(
|
||||
module: torch.nn.Module, linear_cls: Type, ignore_modules: Set[str], prefix: str = ""
|
||||
) -> None:
|
||||
@ -543,35 +481,6 @@ def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules:
|
||||
_convert_linear_layers_to_llm_8bit(child, ignore_modules, prefix=fullname)
|
||||
|
||||
|
||||
def _convert_linear_layers_to_nf4(
|
||||
module: torch.nn.Module, ignore_modules: Set[str], compute_dtype: torch.dtype, prefix: str = ""
|
||||
) -> None:
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinearNF4(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
compute_dtype=torch.float16,
|
||||
# TODO(ryand): Test compress_statistics=True.
|
||||
# compress_statistics=True,
|
||||
)
|
||||
# replacement.weight.data = child.weight.data
|
||||
# if has_bias:
|
||||
# replacement.bias.data = child.bias.data
|
||||
if has_bias:
|
||||
replacement.bias = _replace_param(replacement.bias, child.bias.data)
|
||||
replacement.weight = _replace_param(
|
||||
replacement.weight, child.weight.data, quant_state=replacement.weight.quant_state
|
||||
)
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)
|
||||
|
||||
|
||||
# def _replace_linear_layers(
|
||||
# model: torch.nn.Module,
|
||||
# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"],
|
||||
@ -646,17 +555,3 @@ def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[
|
||||
_convert_linear_layers_to_llm_8bit(module=model, ignore_modules=modules_to_not_convert)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
|
||||
"""Apply bitsandbytes nf4 quantization to the model."""
|
||||
# model_device = get_parameter_device(model)
|
||||
# if model_device.type != "meta":
|
||||
# # Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the
|
||||
# # meta device, so we enforce it for now.
|
||||
# raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.")
|
||||
|
||||
# with accelerate.init_empty_weights():
|
||||
_convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)
|
||||
|
||||
return model
|
||||
|
@ -1,4 +1,5 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@ -6,100 +7,88 @@ import torch
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.bnb import quantize_model_nf4
|
||||
|
||||
# Docs:
|
||||
# https://huggingface.co/docs/accelerate/usage_guides/quantization
|
||||
# https://huggingface.co/docs/bitsandbytes/v0.43.3/en/integrations#accelerate
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
return next(parameter.parameters()).device
|
||||
@contextmanager
|
||||
def log_time(name: str):
|
||||
"""Helper context manager to log the time taken by a block of code."""
|
||||
start = time.time()
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
end = time.time()
|
||||
print(f"'{name}' took {end - start:.4f} secs")
|
||||
|
||||
|
||||
def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
def main():
|
||||
# Load the FLUX transformer model onto the meta device.
|
||||
model_path = Path(
|
||||
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
|
||||
)
|
||||
|
||||
model_nf4_path = path / "bnb_nf4"
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_nf4_path = model_path / "bnb_nf4"
|
||||
if model_nf4_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
# Note that the model loading code is the same when loading from quantized vs original weights. The only
|
||||
# difference is the weights_location.
|
||||
# model = load_and_quantize_model(
|
||||
# empty_model,
|
||||
# weights_location=model_8bit_path,
|
||||
# bnb_quantization_config=bnb_quantization_config,
|
||||
# # device_map="auto",
|
||||
# device_map={"": "cpu"},
|
||||
# )
|
||||
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
|
||||
|
||||
# TODO: Handle the keys that were not quantized (get_keys_to_not_convert).
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
model.to_empty(device="cpu")
|
||||
sd = load_file(model_nf4_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True)
|
||||
model = model.to("cuda")
|
||||
with log_time("Load state dict into model"):
|
||||
sd = load_file(model_nf4_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist yet, quantize and save it.
|
||||
# model = load_and_quantize_model(
|
||||
# empty_model,
|
||||
# weights_location=path,
|
||||
# bnb_quantization_config=bnb_quantization_config,
|
||||
# device_map="auto",
|
||||
# )
|
||||
# The quantized model does not exist, quantize the model and save it.
|
||||
print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")
|
||||
|
||||
# keys_to_not_convert = get_keys_to_not_convert(empty_model) # TODO
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||
# accl = accelerate.Accelerator()
|
||||
# accl.save_model(model, model_8bit_path)
|
||||
with log_time("Load state dict into model"):
|
||||
# Load sharded state dict.
|
||||
files = list(model_path.glob("*.safetensors"))
|
||||
state_dict = dict()
|
||||
for file in files:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
|
||||
# ---------------------
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
# Load sharded state dict.
|
||||
files = list(path.glob("*.safetensors"))
|
||||
state_dict = dict()
|
||||
for file in files:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
with log_time("Save quantized model"):
|
||||
model_nf4_path.mkdir(parents=True, exist_ok=True)
|
||||
output_path = model_nf4_path / "model.safetensors"
|
||||
save_file(model.state_dict(), output_path)
|
||||
|
||||
# model.to_empty(device="cpu")
|
||||
# model.to(dtype=torch.float16)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
# Load the state dict into the model. The bitsandbytes layers know how to load from both quantized and
|
||||
# non-quantized state dicts.
|
||||
# model.to_empty(device="cpu")
|
||||
# model.to(dtype=torch.float16)
|
||||
# result = model.load_state_dict(state_dict, strict=True)
|
||||
model = model.to("cuda")
|
||||
|
||||
model_nf4_path.mkdir(parents=True, exist_ok=True)
|
||||
save_file(model.state_dict(), model_nf4_path / "model.safetensors")
|
||||
|
||||
# ---------------------
|
||||
print(f"Successfully quantized and saved model to '{output_path}'.")
|
||||
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
start = time.time()
|
||||
model = load_flux_transformer(
|
||||
Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
|
||||
)
|
||||
print(f"Time to load: {time.time() - start}s")
|
||||
print("hi")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
152
invokeai/backend/quantization/bnb_nf4.py
Normal file
152
invokeai/backend/quantization/bnb_nf4.py
Normal file
@ -0,0 +1,152 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# This file contains utils for working with models that use bitsandbytes NF4 quantization.
|
||||
# The utils in this file are partially inspired by:
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
|
||||
class InvokeLinearNF4(bnb.nn.LinearNF4):
|
||||
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
|
||||
- Ability to load Linear NF4 layers from a pre-quantized state_dict.
|
||||
- Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
|
||||
"""
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
|
||||
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
|
||||
"""
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
# We expect the remaining keys to be quant_state keys.
|
||||
quant_state_sd = state_dict
|
||||
|
||||
# During serialization, the quant_state is stored as subkeys of "weight." (See
|
||||
# `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
|
||||
|
||||
if len(quant_state_sd) > 0:
|
||||
# We are loading a pre-quantized state dict.
|
||||
self.weight = bnb.nn.Params4bit.from_prequantized(
|
||||
data=weight, quantized_stats=quant_state_sd, device=weight.device
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
|
||||
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
|
||||
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||
self.weight = bnb.nn.Params4bit(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
compress_statistics=self.weight.compress_statistics,
|
||||
quant_type=self.weight.quant_type,
|
||||
quant_storage=self.weight.quant_storage,
|
||||
module=self,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
|
||||
def _replace_param(
|
||||
param: torch.nn.Parameter | bnb.nn.Params4bit,
|
||||
data: torch.Tensor,
|
||||
) -> torch.nn.Parameter:
|
||||
"""A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
|
||||
the "meta" device.
|
||||
|
||||
Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
|
||||
"""
|
||||
if param.device.type == "meta":
|
||||
# Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
|
||||
# re-create the param instead of overwriting the data.
|
||||
if isinstance(param, bnb.nn.Params4bit):
|
||||
return bnb.nn.Params4bit(
|
||||
data,
|
||||
requires_grad=data.requires_grad,
|
||||
quant_state=param.quant_state,
|
||||
compress_statistics=param.compress_statistics,
|
||||
quant_type=param.quant_type,
|
||||
)
|
||||
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
|
||||
|
||||
param.data = data
|
||||
return param
|
||||
|
||||
|
||||
def _convert_linear_layers_to_nf4(
|
||||
module: torch.nn.Module,
|
||||
ignore_modules: set[str],
|
||||
compute_dtype: torch.dtype,
|
||||
compress_statistics: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""Convert all linear layers in the model to NF4 quantized linear layers.
|
||||
|
||||
Args:
|
||||
module: All linear layers in this module will be converted.
|
||||
ignore_modules: A set of module prefixes to ignore when converting linear layers.
|
||||
compute_dtype: The dtype to use for computation in the quantized linear layers.
|
||||
compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
|
||||
constants from the first quantization are quantized again.
|
||||
prefix: The prefix of the current module in the model. Used to call this function recursively.
|
||||
"""
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinearNF4(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
compute_dtype=torch.float16,
|
||||
compress_statistics=compress_statistics,
|
||||
)
|
||||
if has_bias:
|
||||
replacement.bias = _replace_param(replacement.bias, child.bias.data)
|
||||
replacement.weight = _replace_param(replacement.weight, child.weight.data)
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)
|
||||
|
||||
|
||||
def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
|
||||
"""Apply bitsandbytes nf4 quantization to the model.
|
||||
|
||||
You likely want to call this function inside a `accelerate.init_empty_weights()` context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
# Initialize the model from a config on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = ModelClass.from_config(...)
|
||||
|
||||
# Add NF4 quantization linear layers to the model - still on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)
|
||||
|
||||
# Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
# Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
|
||||
# place.
|
||||
model.to("cuda")
|
||||
```
|
||||
"""
|
||||
_convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)
|
||||
|
||||
return model
|
Loading…
Reference in New Issue
Block a user