diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 930f4c40ce..de34a6eb5e 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -21,6 +21,7 @@ from invokeai.app.invocations.fields import ( ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel @@ -49,6 +50,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Text-to-image generation using a FLUX model.""" model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.") + quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField( + default="raw", description="The type of quantization to use for the transformer model." + ) use_8bit: bool = InputField( default=False, description="Whether to quantize the transformer model to 8-bit precision." ) @@ -162,52 +166,37 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): return image def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel: - if self.use_8bit: + if self.quantization_type == "raw": + model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) + elif self.quantization_type == "NF4": model_config = FluxTransformer2DModel.load_config(path, local_files_only=True) with accelerate.init_empty_weights(): empty_model = FluxTransformer2DModel.from_config(model_config) assert isinstance(empty_model, FluxTransformer2DModel) model_nf4_path = path / "bnb_nf4" - if model_nf4_path.exists(): - with accelerate.init_empty_weights(): - model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) + assert model_nf4_path.exists() + with accelerate.init_empty_weights(): + model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) - # model.to_empty(device="cpu") - # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle - # this on GPUs without bfloat16 support. - sd = load_file(model_nf4_path / "model.safetensors") - model.load_state_dict(sd, strict=True, assign=True) - # model = model.to("cuda") + # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle + # this on GPUs without bfloat16 support. + sd = load_file(model_nf4_path / "model.safetensors") + model.load_state_dict(sd, strict=True, assign=True) + elif self.quantization_type == "llm_int8": + model_config = FluxTransformer2DModel.load_config(path, local_files_only=True) + with accelerate.init_empty_weights(): + empty_model = FluxTransformer2DModel.from_config(model_config) + assert isinstance(empty_model, FluxTransformer2DModel) + model_int8_path = path / "bnb_llm_int8" + assert model_int8_path.exists() + with accelerate.init_empty_weights(): + model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set()) - # model_8bit_path = path / "quantized" - # if model_8bit_path.exists(): - # # The quantized model exists, load it. - # # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like - # # something that we should be able to make much faster. - # q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path) - - # # Access the underlying wrapped model. - # # We access the wrapped model, even though it is private, because it simplifies the type checking by - # # always returning a FluxTransformer2DModel from this function. - # model = q_model._wrapped - # else: - # # The quantized model does not exist yet, quantize and save it. - # # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on - # # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it - # # here. - # model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) - # assert isinstance(model, FluxTransformer2DModel) - - # q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8) - - # model_8bit_path.mkdir(parents=True, exist_ok=True) - # q_model.save_pretrained(model_8bit_path) - - # # (See earlier comment about accessing the wrapped model.) - # model = q_model._wrapped + sd = load_file(model_int8_path / "model.safetensors") + model.load_state_dict(sd, strict=True, assign=True) else: - model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) + raise ValueError(f"Unsupported quantization type: {self.quantization_type}") assert isinstance(model, FluxTransformer2DModel) return model diff --git a/invokeai/backend/bnb.py b/invokeai/backend/bnb.py index 168bb1b686..1022a1d1dc 100644 --- a/invokeai/backend/bnb.py +++ b/invokeai/backend/bnb.py @@ -1,6 +1,5 @@ from typing import Any, Optional, Set, Type -import accelerate import bitsandbytes as bnb import torch @@ -460,27 +459,6 @@ def _convert_linear_layers( _convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname) -def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules: Set[str], prefix: str = "") -> None: - for name, child in module.named_children(): - fullname = f"{prefix}.{name}" if prefix else name - if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): - has_bias = child.bias is not None - replacement = InvokeLinear8bitLt( - child.in_features, - child.out_features, - bias=has_bias, - has_fp16_weights=False, - # device=device, - ) - replacement.weight.data = child.weight.data - if has_bias: - replacement.bias.data = child.bias.data - replacement.requires_grad_(False) - module.__setattr__(name, replacement) - else: - _convert_linear_layers_to_llm_8bit(child, ignore_modules, prefix=fullname) - - # def _replace_linear_layers( # model: torch.nn.Module, # linear_layer_type: Literal["Linear8bitLt", "Linear4bit"], @@ -537,21 +515,3 @@ def _convert_linear_layers_to_llm_8bit(module: torch.nn.Module, ignore_modules: # # Remove the last key for recursion # current_key_name.pop(-1) # return model, has_been_replaced - - -def get_parameter_device(parameter: torch.nn.Module): - return next(parameter.parameters()).device - - -def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str]): - """Apply bitsandbytes LLM.8bit() quantization to the model.""" - model_device = get_parameter_device(model) - if model_device.type != "meta": - # Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the - # meta device, so we enforce it for now. - raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.") - - with accelerate.init_empty_weights(): - _convert_linear_layers_to_llm_8bit(module=model, ignore_modules=modules_to_not_convert) - - return model diff --git a/invokeai/backend/load_flux_model_bnb_llm_int8.py b/invokeai/backend/load_flux_model_bnb_llm_int8_old.py similarity index 100% rename from invokeai/backend/load_flux_model_bnb_llm_int8.py rename to invokeai/backend/load_flux_model_bnb_llm_int8_old.py diff --git a/invokeai/backend/quantization/bnb_llm_int8.py b/invokeai/backend/quantization/bnb_llm_int8.py new file mode 100644 index 0000000000..900c55a085 --- /dev/null +++ b/invokeai/backend/quantization/bnb_llm_int8.py @@ -0,0 +1,102 @@ +import bitsandbytes as bnb +import torch + +# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization. +# The utils in this file are partially inspired by: +# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py + + +# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much +# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to +# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes. + + +class InvokeLinear8bitLt(bnb.nn.Linear8bitLt): + def _load_from_state_dict( + self, + state_dict: dict[str, torch.Tensor], + prefix: str, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias", None) + + # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format. + scb = state_dict.pop(prefix + "SCB", None) + # weight_format is unused, but we pop it so we can validate that there are no unexpected keys. + _weight_format = state_dict.pop(prefix + "weight_format", None) + + # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs` + # rather than raising an exception to correctly implement this API. + assert len(state_dict) == 0 + + if scb is not None: + # We are loading a pre-quantized state dict. + self.weight = bnb.nn.Int8Params( + data=weight, + requires_grad=self.weight.requires_grad, + has_fp16_weights=False, + # Note: After quantization, CB is the same as weight. + CB=weight, + SCB=scb, + ) + self.bias = bias if bias is None else torch.nn.Parameter(bias) + else: + # We are loading a non-quantized state dict. + + # We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to + # load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta" + # device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()` + # implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new + # `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done. + self.weight = bnb.nn.Int8Params( + data=weight, + requires_grad=self.weight.requires_grad, + has_fp16_weights=False, + CB=None, + SCB=None, + ) + self.bias = bias if bias is None else torch.nn.Parameter(bias) + + +def _convert_linear_layers_to_llm_8bit( + module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = "" +) -> None: + """Convert all linear layers in the module to bnb.nn.Linear8bitLt layers.""" + for name, child in module.named_children(): + fullname = f"{prefix}.{name}" if prefix else name + if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): + has_bias = child.bias is not None + replacement = InvokeLinear8bitLt( + child.in_features, + child.out_features, + bias=has_bias, + has_fp16_weights=False, + threshold=outlier_threshold, + ) + replacement.weight.data = child.weight.data + if has_bias: + replacement.bias.data = child.bias.data + replacement.requires_grad_(False) + module.__setattr__(name, replacement) + else: + _convert_linear_layers_to_llm_8bit( + child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname + ) + + +def get_parameter_device(parameter: torch.nn.Module): + return next(parameter.parameters()).device + + +def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0): + """Apply bitsandbytes LLM.8bit() quantization to the model.""" + _convert_linear_layers_to_llm_8bit( + module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold + ) + + return model diff --git a/invokeai/backend/quantization/bnb_nf4.py b/invokeai/backend/quantization/bnb_nf4.py index 02a2a732bf..28a0861449 100644 --- a/invokeai/backend/quantization/bnb_nf4.py +++ b/invokeai/backend/quantization/bnb_nf4.py @@ -5,6 +5,10 @@ import torch # The utils in this file are partially inspired by: # https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py +# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much +# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick +# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes. + class InvokeLinearNF4(bnb.nn.LinearNF4): """A class that extends `bnb.nn.LinearNF4` to add the following functionality: diff --git a/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py b/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py new file mode 100644 index 0000000000..fd54210cbe --- /dev/null +++ b/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py @@ -0,0 +1,89 @@ +import time +from contextlib import contextmanager +from pathlib import Path + +import accelerate +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from safetensors.torch import load_file, save_file + +from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 + + +@contextmanager +def log_time(name: str): + """Helper context manager to log the time taken by a block of code.""" + start = time.time() + try: + yield None + finally: + end = time.time() + print(f"'{name}' took {end - start:.4f} secs") + + +def main(): + # Load the FLUX transformer model onto the meta device. + model_path = Path( + "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/" + ) + + with log_time("Initialize FLUX transformer on meta device"): + model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True) + with accelerate.init_empty_weights(): + empty_model = FluxTransformer2DModel.from_config(model_config) + assert isinstance(empty_model, FluxTransformer2DModel) + + # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate + # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize. + modules_to_not_convert: set[str] = set() + + model_int8_path = model_path / "bnb_llm_int8" + if model_int8_path.exists(): + # The quantized model already exists, load it and return it. + print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...") + + # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device). + with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights(): + model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert) + + with log_time("Load state dict into model"): + sd = load_file(model_int8_path / "model.safetensors") + model.load_state_dict(sd, strict=True, assign=True) + + with log_time("Move model to cuda"): + model = model.to("cuda") + + print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.") + + else: + # The quantized model does not exist, quantize the model and save it. + print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...") + + with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights(): + model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert) + + with log_time("Load state dict into model"): + # Load sharded state dict. + files = list(model_path.glob("*.safetensors")) + state_dict = dict() + for file in files: + sd = load_file(file) + state_dict.update(sd) + + model.load_state_dict(state_dict, strict=True, assign=True) + + with log_time("Move model to cuda and quantize"): + model = model.to("cuda") + + with log_time("Save quantized model"): + model_int8_path.mkdir(parents=True, exist_ok=True) + output_path = model_int8_path / "model.safetensors" + save_file(model.state_dict(), output_path) + + print(f"Successfully quantized and saved model to '{output_path}'.") + + assert isinstance(model, FluxTransformer2DModel) + return model + + +if __name__ == "__main__": + main() diff --git a/invokeai/backend/load_flux_model_bnb_nf4.py b/invokeai/backend/quantization/load_flux_model_bnb_nf4.py similarity index 100% rename from invokeai/backend/load_flux_model_bnb_nf4.py rename to invokeai/backend/quantization/load_flux_model_bnb_nf4.py