mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
LLM.int8() quantization is working, but still some rough edges to solve.
This commit is contained in:
parent
99b0f79784
commit
f01f56a98e
@ -21,6 +21,7 @@ from invokeai.app.invocations.fields import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||||
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
||||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||||
@ -49,6 +50,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
"""Text-to-image generation using a FLUX model."""
|
"""Text-to-image generation using a FLUX model."""
|
||||||
|
|
||||||
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
||||||
|
quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
|
||||||
|
default="raw", description="The type of quantization to use for the transformer model."
|
||||||
|
)
|
||||||
use_8bit: bool = InputField(
|
use_8bit: bool = InputField(
|
||||||
default=False, description="Whether to quantize the transformer model to 8-bit precision."
|
default=False, description="Whether to quantize the transformer model to 8-bit precision."
|
||||||
)
|
)
|
||||||
@ -162,52 +166,37 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
||||||
if self.use_8bit:
|
if self.quantization_type == "raw":
|
||||||
|
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||||
|
elif self.quantization_type == "NF4":
|
||||||
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||||
with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||||
|
|
||||||
model_nf4_path = path / "bnb_nf4"
|
model_nf4_path = path / "bnb_nf4"
|
||||||
if model_nf4_path.exists():
|
assert model_nf4_path.exists()
|
||||||
with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||||
|
|
||||||
# model.to_empty(device="cpu")
|
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
||||||
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
# this on GPUs without bfloat16 support.
|
||||||
# this on GPUs without bfloat16 support.
|
sd = load_file(model_nf4_path / "model.safetensors")
|
||||||
sd = load_file(model_nf4_path / "model.safetensors")
|
model.load_state_dict(sd, strict=True, assign=True)
|
||||||
model.load_state_dict(sd, strict=True, assign=True)
|
elif self.quantization_type == "llm_int8":
|
||||||
# model = model.to("cuda")
|
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||||
|
with accelerate.init_empty_weights():
|
||||||
|
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||||
|
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||||
|
model_int8_path = path / "bnb_llm_int8"
|
||||||
|
assert model_int8_path.exists()
|
||||||
|
with accelerate.init_empty_weights():
|
||||||
|
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
||||||
|
|
||||||
# model_8bit_path = path / "quantized"
|
sd = load_file(model_int8_path / "model.safetensors")
|
||||||
# if model_8bit_path.exists():
|
model.load_state_dict(sd, strict=True, assign=True)
|
||||||
# # The quantized model exists, load it.
|
|
||||||
# # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
|
|
||||||
# # something that we should be able to make much faster.
|
|
||||||
# q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
|
|
||||||
|
|
||||||
# # Access the underlying wrapped model.
|
|
||||||
# # We access the wrapped model, even though it is private, because it simplifies the type checking by
|
|
||||||
# # always returning a FluxTransformer2DModel from this function.
|
|
||||||
# model = q_model._wrapped
|
|
||||||
# else:
|
|
||||||
# # The quantized model does not exist yet, quantize and save it.
|
|
||||||
# # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
|
|
||||||
# # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
|
|
||||||
# # here.
|
|
||||||
# model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
|
||||||
# assert isinstance(model, FluxTransformer2DModel)
|
|
||||||
|
|
||||||
# q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
|
|
||||||
|
|
||||||
# model_8bit_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
# q_model.save_pretrained(model_8bit_path)
|
|
||||||
|
|
||||||
# # (See earlier comment about accessing the wrapped model.)
|
|
||||||
# model = q_model._wrapped
|
|
||||||
else:
|
else:
|
||||||
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
|
||||||
|
|
||||||
assert isinstance(model, FluxTransformer2DModel)
|
assert isinstance(model, FluxTransformer2DModel)
|
||||||
return model
|
return model
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from typing import Any, Optional, Set, Type
|
from typing import Any, Optional, Set, Type
|
||||||
|
|
||||||
import accelerate
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -460,27 +459,6 @@ def _convert_linear_layers(
|
|||||||
_convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname)
|
_convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname)
|
||||||
|
|
||||||
|
|
||||||
def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules: Set[str], prefix: str = "") -> None:
|
|
||||||
for name, child in module.named_children():
|
|
||||||
fullname = f"{prefix}.{name}" if prefix else name
|
|
||||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
|
||||||
has_bias = child.bias is not None
|
|
||||||
replacement = InvokeLinear8bitLt(
|
|
||||||
child.in_features,
|
|
||||||
child.out_features,
|
|
||||||
bias=has_bias,
|
|
||||||
has_fp16_weights=False,
|
|
||||||
# device=device,
|
|
||||||
)
|
|
||||||
replacement.weight.data = child.weight.data
|
|
||||||
if has_bias:
|
|
||||||
replacement.bias.data = child.bias.data
|
|
||||||
replacement.requires_grad_(False)
|
|
||||||
module.__setattr__(name, replacement)
|
|
||||||
else:
|
|
||||||
_convert_linear_layers_to_llm_8bit(child, ignore_modules, prefix=fullname)
|
|
||||||
|
|
||||||
|
|
||||||
# def _replace_linear_layers(
|
# def _replace_linear_layers(
|
||||||
# model: torch.nn.Module,
|
# model: torch.nn.Module,
|
||||||
# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"],
|
# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"],
|
||||||
@ -537,21 +515,3 @@ def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules:
|
|||||||
# # Remove the last key for recursion
|
# # Remove the last key for recursion
|
||||||
# current_key_name.pop(-1)
|
# current_key_name.pop(-1)
|
||||||
# return model, has_been_replaced
|
# return model, has_been_replaced
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_device(parameter: torch.nn.Module):
|
|
||||||
return next(parameter.parameters()).device
|
|
||||||
|
|
||||||
|
|
||||||
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str]):
|
|
||||||
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
|
|
||||||
model_device = get_parameter_device(model)
|
|
||||||
if model_device.type != "meta":
|
|
||||||
# Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the
|
|
||||||
# meta device, so we enforce it for now.
|
|
||||||
raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.")
|
|
||||||
|
|
||||||
with accelerate.init_empty_weights():
|
|
||||||
_convert_linear_layers_to_llm_8bit(module=model, ignore_modules=modules_to_not_convert)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
102
invokeai/backend/quantization/bnb_llm_int8.py
Normal file
102
invokeai/backend/quantization/bnb_llm_int8.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import bitsandbytes as bnb
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
|
||||||
|
# The utils in this file are partially inspired by:
|
||||||
|
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||||
|
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
|
||||||
|
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||||
|
def _load_from_state_dict(
|
||||||
|
self,
|
||||||
|
state_dict: dict[str, torch.Tensor],
|
||||||
|
prefix: str,
|
||||||
|
local_metadata,
|
||||||
|
strict,
|
||||||
|
missing_keys,
|
||||||
|
unexpected_keys,
|
||||||
|
error_msgs,
|
||||||
|
):
|
||||||
|
weight = state_dict.pop(prefix + "weight")
|
||||||
|
bias = state_dict.pop(prefix + "bias", None)
|
||||||
|
|
||||||
|
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||||
|
scb = state_dict.pop(prefix + "SCB", None)
|
||||||
|
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
|
||||||
|
_weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||||
|
|
||||||
|
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||||
|
# rather than raising an exception to correctly implement this API.
|
||||||
|
assert len(state_dict) == 0
|
||||||
|
|
||||||
|
if scb is not None:
|
||||||
|
# We are loading a pre-quantized state dict.
|
||||||
|
self.weight = bnb.nn.Int8Params(
|
||||||
|
data=weight,
|
||||||
|
requires_grad=self.weight.requires_grad,
|
||||||
|
has_fp16_weights=False,
|
||||||
|
# Note: After quantization, CB is the same as weight.
|
||||||
|
CB=weight,
|
||||||
|
SCB=scb,
|
||||||
|
)
|
||||||
|
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||||
|
else:
|
||||||
|
# We are loading a non-quantized state dict.
|
||||||
|
|
||||||
|
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
|
||||||
|
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
|
||||||
|
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||||
|
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||||
|
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||||
|
self.weight = bnb.nn.Int8Params(
|
||||||
|
data=weight,
|
||||||
|
requires_grad=self.weight.requires_grad,
|
||||||
|
has_fp16_weights=False,
|
||||||
|
CB=None,
|
||||||
|
SCB=None,
|
||||||
|
)
|
||||||
|
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_linear_layers_to_llm_8bit(
|
||||||
|
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||||
|
) -> None:
|
||||||
|
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
|
||||||
|
for name, child in module.named_children():
|
||||||
|
fullname = f"{prefix}.{name}" if prefix else name
|
||||||
|
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||||
|
has_bias = child.bias is not None
|
||||||
|
replacement = InvokeLinear8bitLt(
|
||||||
|
child.in_features,
|
||||||
|
child.out_features,
|
||||||
|
bias=has_bias,
|
||||||
|
has_fp16_weights=False,
|
||||||
|
threshold=outlier_threshold,
|
||||||
|
)
|
||||||
|
replacement.weight.data = child.weight.data
|
||||||
|
if has_bias:
|
||||||
|
replacement.bias.data = child.bias.data
|
||||||
|
replacement.requires_grad_(False)
|
||||||
|
module.__setattr__(name, replacement)
|
||||||
|
else:
|
||||||
|
_convert_linear_layers_to_llm_8bit(
|
||||||
|
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_device(parameter: torch.nn.Module):
|
||||||
|
return next(parameter.parameters()).device
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
|
||||||
|
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
|
||||||
|
_convert_linear_layers_to_llm_8bit(
|
||||||
|
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
@ -5,6 +5,10 @@ import torch
|
|||||||
# The utils in this file are partially inspired by:
|
# The utils in this file are partially inspired by:
|
||||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||||
|
|
||||||
|
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||||
|
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
|
||||||
|
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||||
|
|
||||||
|
|
||||||
class InvokeLinearNF4(bnb.nn.LinearNF4):
|
class InvokeLinearNF4(bnb.nn.LinearNF4):
|
||||||
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
|
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
|
||||||
|
@ -0,0 +1,89 @@
|
|||||||
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import accelerate
|
||||||
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
|
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def log_time(name: str):
|
||||||
|
"""Helper context manager to log the time taken by a block of code."""
|
||||||
|
start = time.time()
|
||||||
|
try:
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
end = time.time()
|
||||||
|
print(f"'{name}' took {end - start:.4f} secs")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Load the FLUX transformer model onto the meta device.
|
||||||
|
model_path = Path(
|
||||||
|
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
|
||||||
|
)
|
||||||
|
|
||||||
|
with log_time("Initialize FLUX transformer on meta device"):
|
||||||
|
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
|
||||||
|
with accelerate.init_empty_weights():
|
||||||
|
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||||
|
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||||
|
|
||||||
|
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||||
|
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||||
|
modules_to_not_convert: set[str] = set()
|
||||||
|
|
||||||
|
model_int8_path = model_path / "bnb_llm_int8"
|
||||||
|
if model_int8_path.exists():
|
||||||
|
# The quantized model already exists, load it and return it.
|
||||||
|
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
|
||||||
|
|
||||||
|
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
|
||||||
|
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||||
|
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
|
||||||
|
|
||||||
|
with log_time("Load state dict into model"):
|
||||||
|
sd = load_file(model_int8_path / "model.safetensors")
|
||||||
|
model.load_state_dict(sd, strict=True, assign=True)
|
||||||
|
|
||||||
|
with log_time("Move model to cuda"):
|
||||||
|
model = model.to("cuda")
|
||||||
|
|
||||||
|
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# The quantized model does not exist, quantize the model and save it.
|
||||||
|
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
|
||||||
|
|
||||||
|
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||||
|
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
|
||||||
|
|
||||||
|
with log_time("Load state dict into model"):
|
||||||
|
# Load sharded state dict.
|
||||||
|
files = list(model_path.glob("*.safetensors"))
|
||||||
|
state_dict = dict()
|
||||||
|
for file in files:
|
||||||
|
sd = load_file(file)
|
||||||
|
state_dict.update(sd)
|
||||||
|
|
||||||
|
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||||
|
|
||||||
|
with log_time("Move model to cuda and quantize"):
|
||||||
|
model = model.to("cuda")
|
||||||
|
|
||||||
|
with log_time("Save quantized model"):
|
||||||
|
model_int8_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
output_path = model_int8_path / "model.safetensors"
|
||||||
|
save_file(model.state_dict(), output_path)
|
||||||
|
|
||||||
|
print(f"Successfully quantized and saved model to '{output_path}'.")
|
||||||
|
|
||||||
|
assert isinstance(model, FluxTransformer2DModel)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user