diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 1b90048417..542fa6d6b5 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -1,12 +1,13 @@ from pathlib import Path from typing import Literal +import accelerate import torch from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from diffusers.pipelines.flux.pipeline_flux import FluxPipeline -from optimum.quanto import qfloat8 from PIL import Image +from safetensors.torch import load_file from transformers.models.auto import AutoModelForTextEncoding from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation @@ -20,6 +21,7 @@ from invokeai.app.invocations.fields import ( ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.bnb import quantize_model_nf4 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo @@ -107,8 +109,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): transformer=transformer, ) - t5_embeddings = t5_embeddings.to(dtype=transformer.dtype) - clip_embeddings = clip_embeddings.to(dtype=transformer.dtype) + dtype = torch.bfloat16 + t5_embeddings = t5_embeddings.to(dtype=dtype) + clip_embeddings = clip_embeddings.to(dtype=dtype) latents = flux_pipeline_with_transformer( height=self.height, @@ -160,32 +163,49 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel: if self.use_8bit: - model_8bit_path = path / "quantized" - if model_8bit_path.exists(): - # The quantized model exists, load it. - # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like - # something that we should be able to make much faster. - q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path) + model_config = FluxTransformer2DModel.load_config(path, local_files_only=True) + with accelerate.init_empty_weights(): + empty_model = FluxTransformer2DModel.from_config(model_config) + assert isinstance(empty_model, FluxTransformer2DModel) - # Access the underlying wrapped model. - # We access the wrapped model, even though it is private, because it simplifies the type checking by - # always returning a FluxTransformer2DModel from this function. - model = q_model._wrapped - else: - # The quantized model does not exist yet, quantize and save it. - # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on - # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it - # here. - model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) - assert isinstance(model, FluxTransformer2DModel) + model_nf4_path = path / "bnb_nf4" + if model_nf4_path.exists(): + with accelerate.init_empty_weights(): + model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) - q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8) + # model.to_empty(device="cpu") + # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle + # this on GPUs without bfloat16 support. + sd = load_file(model_nf4_path / "model.safetensors") + model.load_state_dict(sd, strict=True, assign=True) + # model = model.to("cuda") - model_8bit_path.mkdir(parents=True, exist_ok=True) - q_model.save_pretrained(model_8bit_path) + # model_8bit_path = path / "quantized" + # if model_8bit_path.exists(): + # # The quantized model exists, load it. + # # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like + # # something that we should be able to make much faster. + # q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path) - # (See earlier comment about accessing the wrapped model.) - model = q_model._wrapped + # # Access the underlying wrapped model. + # # We access the wrapped model, even though it is private, because it simplifies the type checking by + # # always returning a FluxTransformer2DModel from this function. + # model = q_model._wrapped + # else: + # # The quantized model does not exist yet, quantize and save it. + # # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on + # # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it + # # here. + # model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) + # assert isinstance(model, FluxTransformer2DModel) + + # q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8) + + # model_8bit_path.mkdir(parents=True, exist_ok=True) + # q_model.save_pretrained(model_8bit_path) + + # # (See earlier comment about accessing the wrapped model.) + # model = q_model._wrapped else: model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) diff --git a/invokeai/backend/bnb.py b/invokeai/backend/bnb.py index d0cb6f7c99..8c1f080e98 100644 --- a/invokeai/backend/bnb.py +++ b/invokeai/backend/bnb.py @@ -51,7 +51,7 @@ import torch # self.SCB = SCB -class InvokeLinear4Bit(bnb.nn.Linear4bit): +class InvokeLinearNF4(bnb.nn.LinearNF4): def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): @@ -60,31 +60,36 @@ class InvokeLinear4Bit(bnb.nn.Linear4bit): I'm not sure why this was not included in the original `Linear4bit` implementation. """ - # During serialization, the quant_state is stored as subkeys of "weight.". Here we extract those keys. - quant_state_keys = [k for k in state_dict.keys() if k.startswith(prefix + "weight.")] - if len(quant_state_keys) > 0: + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias", None) + # During serialization, the quant_state is stored as subkeys of "weight.". + # We expect the remaining keys to be quant_state keys. We validate that they at least have the correct prefix. + quant_state_sd = state_dict + assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys()) + + if len(quant_state_sd) > 0: # We are loading a quantized state dict. - quant_state_sd = {k: state_dict.pop(k) for k in quant_state_keys} - weight = state_dict.pop(prefix + "weight") - bias = state_dict.pop(prefix + "bias", None) - - if len(state_dict) != 0: - raise RuntimeError(f"Unexpected keys in state_dict: {state_dict.keys()}") - self.weight = bnb.nn.Params4bit.from_prequantized( data=weight, quantized_stats=quant_state_sd, device=weight.device ) - if bias is None: - self.bias = None - else: - self.bias = torch.nn.Parameter(bias) + self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False) else: # We are loading a non-quantized state dict. - return super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + + # We could simply call the `super()._load_from_state_dict` method here, but then we wouldn't be able to load + # into from a state_dict into a model on the "meta" device. By initializing a new `Params4bit` object, we + # work around this issue. + self.weight = bnb.nn.Params4bit( + data=weight, + requires_grad=self.weight.requires_grad, + compress_statistics=self.weight.compress_statistics, + quant_type=self.weight.quant_type, + quant_storage=self.weight.quant_storage, + module=self, ) + self.bias = bias if bias is None else torch.nn.Parameter(bias) class Invoke2Linear8bitLt(torch.nn.Linear): @@ -545,7 +550,7 @@ def _convert_linear_layers_to_nf4( fullname = f"{prefix}.{name}" if prefix else name if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): has_bias = child.bias is not None - replacement = InvokeLinear4Bit( + replacement = InvokeLinearNF4( child.in_features, child.out_features, bias=has_bias, @@ -553,9 +558,14 @@ def _convert_linear_layers_to_nf4( # TODO(ryand): Test compress_statistics=True. # compress_statistics=True, ) - replacement.weight.data = child.weight.data + # replacement.weight.data = child.weight.data + # if has_bias: + # replacement.bias.data = child.bias.data if has_bias: - replacement.bias.data = child.bias.data + replacement.bias = _replace_param(replacement.bias, child.bias.data) + replacement.weight = _replace_param( + replacement.weight, child.weight.data, quant_state=replacement.weight.quant_state + ) replacement.requires_grad_(False) module.__setattr__(name, replacement) else: diff --git a/invokeai/backend/load_flux_model_bnb_nf4.py b/invokeai/backend/load_flux_model_bnb_nf4.py index 1a4e67c1c7..5cff6f07d4 100644 --- a/invokeai/backend/load_flux_model_bnb_nf4.py +++ b/invokeai/backend/load_flux_model_bnb_nf4.py @@ -4,7 +4,7 @@ from pathlib import Path import accelerate import torch from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from safetensors.torch import load_file +from safetensors.torch import load_file, save_file from invokeai.backend.bnb import quantize_model_nf4 @@ -62,6 +62,9 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel: # --------------------- + with accelerate.init_empty_weights(): + model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) + # Load sharded state dict. files = list(path.glob("*.safetensors")) state_dict = dict() @@ -69,8 +72,9 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel: sd = load_file(file) state_dict.update(sd) - empty_model.load_state_dict(state_dict, strict=True, assign=True) - model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) + # model.to_empty(device="cpu") + # model.to(dtype=torch.float16) + model.load_state_dict(state_dict, strict=True, assign=True) # Load the state dict into the model. The bitsandbytes layers know how to load from both quantized and # non-quantized state dicts. @@ -80,7 +84,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel: model = model.to("cuda") model_nf4_path.mkdir(parents=True, exist_ok=True) - # save_file(model.state_dict(), model_nf4_path / "model.safetensors") + save_file(model.state_dict(), model_nf4_path / "model.safetensors") # ---------------------