LLM.int8() quantization is working, but still some rough edges to solve.

This commit is contained in:
Ryan Dick 2024-08-15 19:34:34 +00:00 committed by Brandon
parent 99b0f79784
commit f01f56a98e
7 changed files with 221 additions and 77 deletions

View File

@ -21,6 +21,7 @@ from invokeai.app.invocations.fields import (
) )
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
@ -49,6 +50,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model.""" """Text-to-image generation using a FLUX model."""
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.") model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
default="raw", description="The type of quantization to use for the transformer model."
)
use_8bit: bool = InputField( use_8bit: bool = InputField(
default=False, description="Whether to quantize the transformer model to 8-bit precision." default=False, description="Whether to quantize the transformer model to 8-bit precision."
) )
@ -162,52 +166,37 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return image return image
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel: def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
if self.use_8bit: if self.quantization_type == "raw":
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
elif self.quantization_type == "NF4":
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True) model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
empty_model = FluxTransformer2DModel.from_config(model_config) empty_model = FluxTransformer2DModel.from_config(model_config)
assert isinstance(empty_model, FluxTransformer2DModel) assert isinstance(empty_model, FluxTransformer2DModel)
model_nf4_path = path / "bnb_nf4" model_nf4_path = path / "bnb_nf4"
if model_nf4_path.exists(): assert model_nf4_path.exists()
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
# model.to_empty(device="cpu") # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle # this on GPUs without bfloat16 support.
# this on GPUs without bfloat16 support. sd = load_file(model_nf4_path / "model.safetensors")
sd = load_file(model_nf4_path / "model.safetensors") model.load_state_dict(sd, strict=True, assign=True)
model.load_state_dict(sd, strict=True, assign=True) elif self.quantization_type == "llm_int8":
# model = model.to("cuda") model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
with accelerate.init_empty_weights():
empty_model = FluxTransformer2DModel.from_config(model_config)
assert isinstance(empty_model, FluxTransformer2DModel)
model_int8_path = path / "bnb_llm_int8"
assert model_int8_path.exists()
with accelerate.init_empty_weights():
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
# model_8bit_path = path / "quantized" sd = load_file(model_int8_path / "model.safetensors")
# if model_8bit_path.exists(): model.load_state_dict(sd, strict=True, assign=True)
# # The quantized model exists, load it.
# # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
# # something that we should be able to make much faster.
# q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
# # Access the underlying wrapped model.
# # We access the wrapped model, even though it is private, because it simplifies the type checking by
# # always returning a FluxTransformer2DModel from this function.
# model = q_model._wrapped
# else:
# # The quantized model does not exist yet, quantize and save it.
# # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
# # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
# # here.
# model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
# assert isinstance(model, FluxTransformer2DModel)
# q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
# model_8bit_path.mkdir(parents=True, exist_ok=True)
# q_model.save_pretrained(model_8bit_path)
# # (See earlier comment about accessing the wrapped model.)
# model = q_model._wrapped
else: else:
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
assert isinstance(model, FluxTransformer2DModel) assert isinstance(model, FluxTransformer2DModel)
return model return model

View File

@ -1,6 +1,5 @@
from typing import Any, Optional, Set, Type from typing import Any, Optional, Set, Type
import accelerate
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
@ -460,27 +459,6 @@ def _convert_linear_layers(
_convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname) _convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname)
def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules: Set[str], prefix: str = "") -> None:
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinear8bitLt(
child.in_features,
child.out_features,
bias=has_bias,
has_fp16_weights=False,
# device=device,
)
replacement.weight.data = child.weight.data
if has_bias:
replacement.bias.data = child.bias.data
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_llm_8bit(child, ignore_modules, prefix=fullname)
# def _replace_linear_layers( # def _replace_linear_layers(
# model: torch.nn.Module, # model: torch.nn.Module,
# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"], # linear_layer_type: Literal["Linear8bitLt", "Linear4bit"],
@ -537,21 +515,3 @@ def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules:
# # Remove the last key for recursion # # Remove the last key for recursion
# current_key_name.pop(-1) # current_key_name.pop(-1)
# return model, has_been_replaced # return model, has_been_replaced
def get_parameter_device(parameter: torch.nn.Module):
return next(parameter.parameters()).device
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str]):
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
model_device = get_parameter_device(model)
if model_device.type != "meta":
# Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the
# meta device, so we enforce it for now.
raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.")
with accelerate.init_empty_weights():
_convert_linear_layers_to_llm_8bit(module=model, ignore_modules=modules_to_not_convert)
return model

View File

@ -0,0 +1,102 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None)
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
_weight_format = state_dict.pop(prefix + "weight_format", None)
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert len(state_dict) == 0
if scb is not None:
# We are loading a pre-quantized state dict.
self.weight = bnb.nn.Int8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
# Note: After quantization, CB is the same as weight.
CB=weight,
SCB=scb,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = bnb.nn.Int8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
CB=None,
SCB=None,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
) -> None:
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinear8bitLt(
child.in_features,
child.out_features,
bias=has_bias,
has_fp16_weights=False,
threshold=outlier_threshold,
)
replacement.weight.data = child.weight.data
if has_bias:
replacement.bias.data = child.bias.data
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_llm_8bit(
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
)
def get_parameter_device(parameter: torch.nn.Module):
return next(parameter.parameters()).device
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
_convert_linear_layers_to_llm_8bit(
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
)
return model

View File

@ -5,6 +5,10 @@ import torch
# The utils in this file are partially inspired by: # The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py # https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeLinearNF4(bnb.nn.LinearNF4): class InvokeLinearNF4(bnb.nn.LinearNF4):
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality: """A class that extends `bnb.nn.LinearNF4` to add the following functionality:

View File

@ -0,0 +1,89 @@
import time
from contextlib import contextmanager
from pathlib import Path
import accelerate
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from safetensors.torch import load_file, save_file
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
@contextmanager
def log_time(name: str):
"""Helper context manager to log the time taken by a block of code."""
start = time.time()
try:
yield None
finally:
end = time.time()
print(f"'{name}' took {end - start:.4f} secs")
def main():
# Load the FLUX transformer model onto the meta device.
model_path = Path(
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
)
with log_time("Initialize FLUX transformer on meta device"):
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
with accelerate.init_empty_weights():
empty_model = FluxTransformer2DModel.from_config(model_config)
assert isinstance(empty_model, FluxTransformer2DModel)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path / "bnb_llm_int8"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path / "model.safetensors")
model.load_state_dict(sd, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
# Load sharded state dict.
files = list(model_path.glob("*.safetensors"))
state_dict = dict()
for file in files:
sd = load_file(file)
state_dict.update(sd)
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.mkdir(parents=True, exist_ok=True)
output_path = model_int8_path / "model.safetensors"
save_file(model.state_dict(), output_path)
print(f"Successfully quantized and saved model to '{output_path}'.")
assert isinstance(model, FluxTransformer2DModel)
return model
if __name__ == "__main__":
main()