mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
LLM.int8() quantization is working, but still some rough edges to solve.
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
from typing import Any, Optional, Set, Type
|
||||
|
||||
import accelerate
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
@ -460,27 +459,6 @@ def _convert_linear_layers(
|
||||
_convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname)
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules: Set[str], prefix: str = "") -> None:
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinear8bitLt(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
has_fp16_weights=False,
|
||||
# device=device,
|
||||
)
|
||||
replacement.weight.data = child.weight.data
|
||||
if has_bias:
|
||||
replacement.bias.data = child.bias.data
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_llm_8bit(child, ignore_modules, prefix=fullname)
|
||||
|
||||
|
||||
# def _replace_linear_layers(
|
||||
# model: torch.nn.Module,
|
||||
# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"],
|
||||
@ -537,21 +515,3 @@ def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules:
|
||||
# # Remove the last key for recursion
|
||||
# current_key_name.pop(-1)
|
||||
# return model, has_been_replaced
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
return next(parameter.parameters()).device
|
||||
|
||||
|
||||
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str]):
|
||||
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
|
||||
model_device = get_parameter_device(model)
|
||||
if model_device.type != "meta":
|
||||
# Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the
|
||||
# meta device, so we enforce it for now.
|
||||
raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.")
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
_convert_linear_layers_to_llm_8bit(module=model, ignore_modules=modules_to_not_convert)
|
||||
|
||||
return model
|
||||
|
102
invokeai/backend/quantization/bnb_llm_int8.py
Normal file
102
invokeai/backend/quantization/bnb_llm_int8.py
Normal file
@ -0,0 +1,102 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
|
||||
# The utils in this file are partially inspired by:
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
|
||||
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
|
||||
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||
|
||||
|
||||
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
|
||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||
scb = state_dict.pop(prefix + "SCB", None)
|
||||
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
|
||||
_weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
assert len(state_dict) == 0
|
||||
|
||||
if scb is not None:
|
||||
# We are loading a pre-quantized state dict.
|
||||
self.weight = bnb.nn.Int8Params(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
has_fp16_weights=False,
|
||||
# Note: After quantization, CB is the same as weight.
|
||||
CB=weight,
|
||||
SCB=scb,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
|
||||
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
|
||||
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||
self.weight = bnb.nn.Int8Params(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
has_fp16_weights=False,
|
||||
CB=None,
|
||||
SCB=None,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(
|
||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||
) -> None:
|
||||
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinear8bitLt(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
has_fp16_weights=False,
|
||||
threshold=outlier_threshold,
|
||||
)
|
||||
replacement.weight.data = child.weight.data
|
||||
if has_bias:
|
||||
replacement.bias.data = child.bias.data
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_llm_8bit(
|
||||
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
|
||||
)
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
return next(parameter.parameters()).device
|
||||
|
||||
|
||||
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
|
||||
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
|
||||
_convert_linear_layers_to_llm_8bit(
|
||||
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
|
||||
)
|
||||
|
||||
return model
|
@ -5,6 +5,10 @@ import torch
|
||||
# The utils in this file are partially inspired by:
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
|
||||
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||
|
||||
|
||||
class InvokeLinearNF4(bnb.nn.LinearNF4):
|
||||
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
|
||||
|
@ -0,0 +1,89 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_time(name: str):
|
||||
"""Helper context manager to log the time taken by a block of code."""
|
||||
start = time.time()
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
end = time.time()
|
||||
print(f"'{name}' took {end - start:.4f} secs")
|
||||
|
||||
|
||||
def main():
|
||||
# Load the FLUX transformer model onto the meta device.
|
||||
model_path = Path(
|
||||
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
|
||||
)
|
||||
|
||||
with log_time("Initialize FLUX transformer on meta device"):
|
||||
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_int8_path = model_path / "bnb_llm_int8"
|
||||
if model_int8_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
|
||||
|
||||
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
sd = load_file(model_int8_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist, quantize the model and save it.
|
||||
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
|
||||
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
# Load sharded state dict.
|
||||
files = list(model_path.glob("*.safetensors"))
|
||||
state_dict = dict()
|
||||
for file in files:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
with log_time("Save quantized model"):
|
||||
model_int8_path.mkdir(parents=True, exist_ok=True)
|
||||
output_path = model_int8_path / "model.safetensors"
|
||||
save_file(model.state_dict(), output_path)
|
||||
|
||||
print(f"Successfully quantized and saved model to '{output_path}'.")
|
||||
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user