diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 542fa6d6b5..930f4c40ce 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -21,7 +21,7 @@ from invokeai.app.invocations.fields import ( ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.bnb import quantize_model_nf4 +from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo diff --git a/invokeai/backend/bnb.py b/invokeai/backend/bnb.py index 8c1f080e98..168bb1b686 100644 --- a/invokeai/backend/bnb.py +++ b/invokeai/backend/bnb.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Set, Tuple, Type +from typing import Any, Optional, Set, Type import accelerate import bitsandbytes as bnb @@ -51,47 +51,6 @@ import torch # self.SCB = SCB -class InvokeLinearNF4(bnb.nn.LinearNF4): - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - """This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`: - https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71 - - I'm not sure why this was not included in the original `Linear4bit` implementation. - """ - - weight = state_dict.pop(prefix + "weight") - bias = state_dict.pop(prefix + "bias", None) - # During serialization, the quant_state is stored as subkeys of "weight.". - # We expect the remaining keys to be quant_state keys. We validate that they at least have the correct prefix. - quant_state_sd = state_dict - assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys()) - - if len(quant_state_sd) > 0: - # We are loading a quantized state dict. - self.weight = bnb.nn.Params4bit.from_prequantized( - data=weight, quantized_stats=quant_state_sd, device=weight.device - ) - self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False) - - else: - # We are loading a non-quantized state dict. - - # We could simply call the `super()._load_from_state_dict` method here, but then we wouldn't be able to load - # into from a state_dict into a model on the "meta" device. By initializing a new `Params4bit` object, we - # work around this issue. - self.weight = bnb.nn.Params4bit( - data=weight, - requires_grad=self.weight.requires_grad, - compress_statistics=self.weight.compress_statistics, - quant_type=self.weight.quant_type, - quant_storage=self.weight.quant_storage, - module=self, - ) - self.bias = bias if bias is None else torch.nn.Parameter(bias) - - class Invoke2Linear8bitLt(torch.nn.Linear): """This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.""" @@ -474,27 +433,6 @@ def convert_model_to_bnb_llm_int8(model: torch.nn.Module, ignore_modules: set[st # incompatible_keys.missing_keys.remove(key) -def _replace_param( - param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[Tuple] = None -) -> torch.nn.Parameter: - # doing `param.data = weight` raises a RuntimeError if param.data was on meta-device, so - # we need to re-create the parameters instead of overwriting the data - if param.device.type == "meta": - if isinstance(param, bnb.nn.Params4bit): - return bnb.nn.Params4bit( - data, - requires_grad=data.requires_grad, - quant_state=quant_state, - compress_statistics=param.compress_statistics, - quant_type=param.quant_type, - ) - return torch.nn.Parameter(data, requires_grad=data.requires_grad) - param.data = data - if isinstance(param, bnb.nn.Params4bit): - param.quant_state = quant_state - return param - - def _convert_linear_layers( module: torch.nn.Module, linear_cls: Type, ignore_modules: Set[str], prefix: str = "" ) -> None: @@ -543,35 +481,6 @@ def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules: _convert_linear_layers_to_llm_8bit(child, ignore_modules, prefix=fullname) -def _convert_linear_layers_to_nf4( - module: torch.nn.Module, ignore_modules: Set[str], compute_dtype: torch.dtype, prefix: str = "" -) -> None: - for name, child in module.named_children(): - fullname = f"{prefix}.{name}" if prefix else name - if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): - has_bias = child.bias is not None - replacement = InvokeLinearNF4( - child.in_features, - child.out_features, - bias=has_bias, - compute_dtype=torch.float16, - # TODO(ryand): Test compress_statistics=True. - # compress_statistics=True, - ) - # replacement.weight.data = child.weight.data - # if has_bias: - # replacement.bias.data = child.bias.data - if has_bias: - replacement.bias = _replace_param(replacement.bias, child.bias.data) - replacement.weight = _replace_param( - replacement.weight, child.weight.data, quant_state=replacement.weight.quant_state - ) - replacement.requires_grad_(False) - module.__setattr__(name, replacement) - else: - _convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname) - - # def _replace_linear_layers( # model: torch.nn.Module, # linear_layer_type: Literal["Linear8bitLt", "Linear4bit"], @@ -646,17 +555,3 @@ def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[ _convert_linear_layers_to_llm_8bit(module=model, ignore_modules=modules_to_not_convert) return model - - -def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype): - """Apply bitsandbytes nf4 quantization to the model.""" - # model_device = get_parameter_device(model) - # if model_device.type != "meta": - # # Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the - # # meta device, so we enforce it for now. - # raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.") - - # with accelerate.init_empty_weights(): - _convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype) - - return model diff --git a/invokeai/backend/load_flux_model_bnb_nf4.py b/invokeai/backend/load_flux_model_bnb_nf4.py index 5cff6f07d4..b55c56a032 100644 --- a/invokeai/backend/load_flux_model_bnb_nf4.py +++ b/invokeai/backend/load_flux_model_bnb_nf4.py @@ -1,4 +1,5 @@ import time +from contextlib import contextmanager from pathlib import Path import accelerate @@ -6,100 +7,88 @@ import torch from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from safetensors.torch import load_file, save_file -from invokeai.backend.bnb import quantize_model_nf4 - -# Docs: -# https://huggingface.co/docs/accelerate/usage_guides/quantization -# https://huggingface.co/docs/bitsandbytes/v0.43.3/en/integrations#accelerate +from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 -def get_parameter_device(parameter: torch.nn.Module): - return next(parameter.parameters()).device +@contextmanager +def log_time(name: str): + """Helper context manager to log the time taken by a block of code.""" + start = time.time() + try: + yield None + finally: + end = time.time() + print(f"'{name}' took {end - start:.4f} secs") -def load_flux_transformer(path: Path) -> FluxTransformer2DModel: - model_config = FluxTransformer2DModel.load_config(path, local_files_only=True) - with accelerate.init_empty_weights(): - empty_model = FluxTransformer2DModel.from_config(model_config) - assert isinstance(empty_model, FluxTransformer2DModel) +def main(): + # Load the FLUX transformer model onto the meta device. + model_path = Path( + "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/" + ) - model_nf4_path = path / "bnb_nf4" + with log_time("Intialize FLUX transformer on meta device"): + model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True) + with accelerate.init_empty_weights(): + empty_model = FluxTransformer2DModel.from_config(model_config) + assert isinstance(empty_model, FluxTransformer2DModel) + + # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate + # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize. + modules_to_not_convert: set[str] = set() + + model_nf4_path = model_path / "bnb_nf4" if model_nf4_path.exists(): # The quantized model already exists, load it and return it. - # Note that the model loading code is the same when loading from quantized vs original weights. The only - # difference is the weights_location. - # model = load_and_quantize_model( - # empty_model, - # weights_location=model_8bit_path, - # bnb_quantization_config=bnb_quantization_config, - # # device_map="auto", - # device_map={"": "cpu"}, - # ) + print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...") - # TODO: Handle the keys that were not quantized (get_keys_to_not_convert). - with accelerate.init_empty_weights(): - model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) + # Replace the linear layers with NF4 quantized linear layers (still on the meta device). + with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights(): + model = quantize_model_nf4( + empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16 + ) - model.to_empty(device="cpu") - sd = load_file(model_nf4_path / "model.safetensors") - model.load_state_dict(sd, strict=True) - model = model.to("cuda") + with log_time("Load state dict into model"): + sd = load_file(model_nf4_path / "model.safetensors") + model.load_state_dict(sd, strict=True, assign=True) + + with log_time("Move model to cuda"): + model = model.to("cuda") + + print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.") else: - # The quantized model does not exist yet, quantize and save it. - # model = load_and_quantize_model( - # empty_model, - # weights_location=path, - # bnb_quantization_config=bnb_quantization_config, - # device_map="auto", - # ) + # The quantized model does not exist, quantize the model and save it. + print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...") - # keys_to_not_convert = get_keys_to_not_convert(empty_model) # TODO + with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights(): + model = quantize_model_nf4( + empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16 + ) - # model_8bit_path.mkdir(parents=True, exist_ok=True) - # accl = accelerate.Accelerator() - # accl.save_model(model, model_8bit_path) + with log_time("Load state dict into model"): + # Load sharded state dict. + files = list(model_path.glob("*.safetensors")) + state_dict = dict() + for file in files: + sd = load_file(file) + state_dict.update(sd) - # --------------------- + model.load_state_dict(state_dict, strict=True, assign=True) - with accelerate.init_empty_weights(): - model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) + with log_time("Move model to cuda and quantize"): + model = model.to("cuda") - # Load sharded state dict. - files = list(path.glob("*.safetensors")) - state_dict = dict() - for file in files: - sd = load_file(file) - state_dict.update(sd) + with log_time("Save quantized model"): + model_nf4_path.mkdir(parents=True, exist_ok=True) + output_path = model_nf4_path / "model.safetensors" + save_file(model.state_dict(), output_path) - # model.to_empty(device="cpu") - # model.to(dtype=torch.float16) - model.load_state_dict(state_dict, strict=True, assign=True) - - # Load the state dict into the model. The bitsandbytes layers know how to load from both quantized and - # non-quantized state dicts. - # model.to_empty(device="cpu") - # model.to(dtype=torch.float16) - # result = model.load_state_dict(state_dict, strict=True) - model = model.to("cuda") - - model_nf4_path.mkdir(parents=True, exist_ok=True) - save_file(model.state_dict(), model_nf4_path / "model.safetensors") - - # --------------------- + print(f"Successfully quantized and saved model to '{output_path}'.") assert isinstance(model, FluxTransformer2DModel) return model -def main(): - start = time.time() - model = load_flux_transformer( - Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/") - ) - print(f"Time to load: {time.time() - start}s") - print("hi") - - if __name__ == "__main__": main() diff --git a/invokeai/backend/quantization/bnb_nf4.py b/invokeai/backend/quantization/bnb_nf4.py new file mode 100644 index 0000000000..02a2a732bf --- /dev/null +++ b/invokeai/backend/quantization/bnb_nf4.py @@ -0,0 +1,152 @@ +import bitsandbytes as bnb +import torch + +# This file contains utils for working with models that use bitsandbytes NF4 quantization. +# The utils in this file are partially inspired by: +# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py + + +class InvokeLinearNF4(bnb.nn.LinearNF4): + """A class that extends `bnb.nn.LinearNF4` to add the following functionality: + - Ability to load Linear NF4 layers from a pre-quantized state_dict. + - Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device. + """ + + def _load_from_state_dict( + self, + state_dict: dict[str, torch.Tensor], + prefix: str, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`: + https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71 + """ + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias", None) + # We expect the remaining keys to be quant_state keys. + quant_state_sd = state_dict + + # During serialization, the quant_state is stored as subkeys of "weight." (See + # `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix. + # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs` + # rather than raising an exception to correctly implement this API. + assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys()) + + if len(quant_state_sd) > 0: + # We are loading a pre-quantized state dict. + self.weight = bnb.nn.Params4bit.from_prequantized( + data=weight, quantized_stats=quant_state_sd, device=weight.device + ) + self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False) + else: + # We are loading a non-quantized state dict. + + # We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to + # load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta" + # device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()` + # implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new + # `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done. + self.weight = bnb.nn.Params4bit( + data=weight, + requires_grad=self.weight.requires_grad, + compress_statistics=self.weight.compress_statistics, + quant_type=self.weight.quant_type, + quant_storage=self.weight.quant_storage, + module=self, + ) + self.bias = bias if bias is None else torch.nn.Parameter(bias) + + +def _replace_param( + param: torch.nn.Parameter | bnb.nn.Params4bit, + data: torch.Tensor, +) -> torch.nn.Parameter: + """A helper function to replace the data of a model parameter with new data in a way that allows replacing params on + the "meta" device. + + Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters. + """ + if param.device.type == "meta": + # Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to + # re-create the param instead of overwriting the data. + if isinstance(param, bnb.nn.Params4bit): + return bnb.nn.Params4bit( + data, + requires_grad=data.requires_grad, + quant_state=param.quant_state, + compress_statistics=param.compress_statistics, + quant_type=param.quant_type, + ) + return torch.nn.Parameter(data, requires_grad=data.requires_grad) + + param.data = data + return param + + +def _convert_linear_layers_to_nf4( + module: torch.nn.Module, + ignore_modules: set[str], + compute_dtype: torch.dtype, + compress_statistics: bool = False, + prefix: str = "", +) -> None: + """Convert all linear layers in the model to NF4 quantized linear layers. + + Args: + module: All linear layers in this module will be converted. + ignore_modules: A set of module prefixes to ignore when converting linear layers. + compute_dtype: The dtype to use for computation in the quantized linear layers. + compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization + constants from the first quantization are quantized again. + prefix: The prefix of the current module in the model. Used to call this function recursively. + """ + for name, child in module.named_children(): + fullname = f"{prefix}.{name}" if prefix else name + if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): + has_bias = child.bias is not None + replacement = InvokeLinearNF4( + child.in_features, + child.out_features, + bias=has_bias, + compute_dtype=torch.float16, + compress_statistics=compress_statistics, + ) + if has_bias: + replacement.bias = _replace_param(replacement.bias, child.bias.data) + replacement.weight = _replace_param(replacement.weight, child.weight.data) + replacement.requires_grad_(False) + module.__setattr__(name, replacement) + else: + _convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname) + + +def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype): + """Apply bitsandbytes nf4 quantization to the model. + + You likely want to call this function inside a `accelerate.init_empty_weights()` context. + + Example usage: + ``` + # Initialize the model from a config on the meta device. + with accelerate.init_empty_weights(): + model = ModelClass.from_config(...) + + # Add NF4 quantization linear layers to the model - still on the meta device. + with accelerate.init_empty_weights(): + model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16) + + # Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.) + model.load_state_dict(state_dict, strict=True, assign=True) + + # Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes + # place. + model.to("cuda") + ``` + """ + _convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype) + + return model