NF4 inference working

Ryan Dick 2024-08-14 23:30:53 +00:00 committed by Brandon
parent 5c2f95ef50
commit e1eb104345
3 changed files with 83 additions and 49 deletions

File 1 of 3: the FLUX text-to-image invocation (`FluxTextToImageInvocation`)

@@ -1,12 +1,13 @@
 from pathlib import Path
 from typing import Literal
 
+import accelerate
 import torch
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from optimum.quanto import qfloat8
 from PIL import Image
+from safetensors.torch import load_file
 from transformers.models.auto import AutoModelForTextEncoding
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -20,6 +21,7 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.bnb import quantize_model_nf4
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
@@ -107,8 +109,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             transformer=transformer,
         )
 
-        t5_embeddings = t5_embeddings.to(dtype=transformer.dtype)
-        clip_embeddings = clip_embeddings.to(dtype=transformer.dtype)
+        dtype = torch.bfloat16
+        t5_embeddings = t5_embeddings.to(dtype=dtype)
+        clip_embeddings = clip_embeddings.to(dtype=dtype)
 
         latents = flux_pipeline_with_transformer(
             height=self.height,
@@ -160,32 +163,49 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a FluxTransformer2DModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
-                # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
-                # here.
-                model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-                assert isinstance(model, FluxTransformer2DModel)
-                q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
+            model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
+            with accelerate.init_empty_weights():
+                empty_model = FluxTransformer2DModel.from_config(model_config)
+            assert isinstance(empty_model, FluxTransformer2DModel)
+
+            model_nf4_path = path / "bnb_nf4"
+            if model_nf4_path.exists():
+                with accelerate.init_empty_weights():
+                    model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+
+                # model.to_empty(device="cpu")
+                # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
+                # this on GPUs without bfloat16 support.
+                sd = load_file(model_nf4_path / "model.safetensors")
+                model.load_state_dict(sd, strict=True, assign=True)
+                # model = model.to("cuda")
+
+            # model_8bit_path = path / "quantized"
+            # if model_8bit_path.exists():
+            #     # The quantized model exists, load it.
+            #     # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+            #     # something that we should be able to make much faster.
+            #     q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
+            #     # Access the underlying wrapped model.
+            #     # We access the wrapped model, even though it is private, because it simplifies the type checking by
+            #     # always returning a FluxTransformer2DModel from this function.
+            #     model = q_model._wrapped
+            # else:
+            #     # The quantized model does not exist yet, quantize and save it.
+            #     # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
+            #     # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
+            #     # here.
+            #     model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
+            #     assert isinstance(model, FluxTransformer2DModel)
+            #     q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
+            #     model_8bit_path.mkdir(parents=True, exist_ok=True)
+            #     q_model.save_pretrained(model_8bit_path)
+            #     # (See earlier comment about accessing the wrapped model.)
+            #     model = q_model._wrapped
         else:
             model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
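
Stripped of the commented-out quanto path, the NF4 branch of `_load_flux_transformer` reduces to roughly the following. This is a readability sketch of the same logic shown in the hunk above, not additional code from the commit, and it assumes a prequantized checkpoint already exists at `<model_path>/bnb_nf4/model.safetensors`; the helper name is illustrative only.

from pathlib import Path

import accelerate
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from safetensors.torch import load_file

from invokeai.backend.bnb import quantize_model_nf4


def load_nf4_flux_transformer(path: Path) -> FluxTransformer2DModel:
    # Build the module tree on the "meta" device so that no real weights are allocated yet.
    model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
    with accelerate.init_empty_weights():
        empty_model = FluxTransformer2DModel.from_config(model_config)
        # Swap every eligible torch.nn.Linear for an NF4-quantized bitsandbytes layer.
        model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)

    # Load the prequantized weights. assign=True replaces the meta tensors outright rather than
    # copying into them, which is what makes loading into a meta-initialized model possible.
    sd = load_file(path / "bnb_nf4" / "model.safetensors")
    model.load_state_dict(sd, strict=True, assign=True)
    return model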

File 2 of 3: the bitsandbytes NF4 linear-layer wrappers (`InvokeLinearNF4`, `_convert_linear_layers_to_nf4`)

@@ -51,7 +51,7 @@ import torch
 # self.SCB = SCB
 
 
-class InvokeLinear4Bit(bnb.nn.Linear4bit):
+class InvokeLinearNF4(bnb.nn.LinearNF4):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
@@ -60,31 +60,36 @@ class InvokeLinear4Bit(bnb.nn.Linear4bit):
         I'm not sure why this was not included in the original `Linear4bit` implementation.
         """
-        # During serialization, the quant_state is stored as subkeys of "weight.". Here we extract those keys.
-        quant_state_keys = [k for k in state_dict.keys() if k.startswith(prefix + "weight.")]
-        if len(quant_state_keys) > 0:
+        weight = state_dict.pop(prefix + "weight")
+        bias = state_dict.pop(prefix + "bias", None)
+        # During serialization, the quant_state is stored as subkeys of "weight.".
+        # We expect the remaining keys to be quant_state keys. We validate that they at least have the correct prefix.
+        quant_state_sd = state_dict
+        assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
+
+        if len(quant_state_sd) > 0:
             # We are loading a quantized state dict.
-            quant_state_sd = {k: state_dict.pop(k) for k in quant_state_keys}
-            weight = state_dict.pop(prefix + "weight")
-            bias = state_dict.pop(prefix + "bias", None)
-            if len(state_dict) != 0:
-                raise RuntimeError(f"Unexpected keys in state_dict: {state_dict.keys()}")
             self.weight = bnb.nn.Params4bit.from_prequantized(
                 data=weight, quantized_stats=quant_state_sd, device=weight.device
             )
-            if bias is None:
-                self.bias = None
-            else:
-                self.bias = torch.nn.Parameter(bias)
+            self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
         else:
             # We are loading a non-quantized state dict.
-            return super()._load_from_state_dict(
-                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+            # We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able
+            # to load a state_dict into a model on the "meta" device. By initializing a new `Params4bit` object, we
+            # work around this issue.
+            self.weight = bnb.nn.Params4bit(
+                data=weight,
+                requires_grad=self.weight.requires_grad,
+                compress_statistics=self.weight.compress_statistics,
+                quant_type=self.weight.quant_type,
+                quant_storage=self.weight.quant_storage,
+                module=self,
             )
+            self.bias = bias if bias is None else torch.nn.Parameter(bias)
 
 
 class Invoke2Linear8bitLt(torch.nn.Linear):
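
The "meta"-device limitation that the new comment above refers to can be reproduced with plain PyTorch. This is a minimal illustration, not part of the commit; it assumes PyTorch 2.1+ (for `load_state_dict(..., assign=True)`) and a recent accelerate.

import accelerate
import torch

# Parameters created under init_empty_weights() live on the "meta" device: they have shapes
# and dtypes but no storage, so copying real values into them (the default load path) fails.
with accelerate.init_empty_weights():
    lin = torch.nn.Linear(8, 8)
print(lin.weight.device)  # meta

# assign=True replaces the meta parameters with the tensors from the state dict instead of
# copying into them, which is the same trick the InvokeLinearNF4 override relies on.
sd = torch.nn.Linear(8, 8).state_dict()
lin.load_state_dict(sd, strict=True, assign=True)
print(lin.weight.device)  # cpu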
@@ -545,7 +550,7 @@ def _convert_linear_layers_to_nf4(
         fullname = f"{prefix}.{name}" if prefix else name
         if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
             has_bias = child.bias is not None
-            replacement = InvokeLinear4Bit(
+            replacement = InvokeLinearNF4(
                 child.in_features,
                 child.out_features,
                 bias=has_bias,
@@ -553,9 +558,14 @@ def _convert_linear_layers_to_nf4(
                 # TODO(ryand): Test compress_statistics=True.
                 # compress_statistics=True,
             )
-            replacement.weight.data = child.weight.data
+            # replacement.weight.data = child.weight.data
+            # if has_bias:
+            #     replacement.bias.data = child.bias.data
             if has_bias:
-                replacement.bias.data = child.bias.data
+                replacement.bias = _replace_param(replacement.bias, child.bias.data)
+            replacement.weight = _replace_param(
+                replacement.weight, child.weight.data, quant_state=replacement.weight.quant_state
+            )
             replacement.requires_grad_(False)
             module.__setattr__(name, replacement)
         else:
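
For context on the two cases that `_load_from_state_dict` handles: a quantized bitsandbytes 4-bit linear layer serializes its packed weight plus several "weight."-prefixed metadata tensors, and `bnb.nn.Params4bit.from_prequantized()` rebuilds the parameter from them. The sketch below is illustrative only (not from the commit); it assumes bitsandbytes is installed with a CUDA device available, since quantization is triggered by moving the weight to the GPU, and the exact metadata key names can vary across bitsandbytes versions.

import bitsandbytes as bnb
import torch

# Quantization happens when the Params4bit weight is moved to the GPU.
layer = bnb.nn.LinearNF4(64, 64, bias=True).cuda()

sd = layer.state_dict()
print(sorted(sd.keys()))
# Roughly: ['bias', 'weight', 'weight.absmax', 'weight.quant_map',
#           'weight.quant_state.bitsandbytes__nf4', ...]
# Every key after "weight" is quant_state metadata, i.e. exactly the "weight."-prefixed
# keys that the InvokeLinearNF4 override treats as quant_state_sd.

# Rebuilding a layer from that serialized form, mirroring the quantized branch:
weight = sd.pop("weight")
bias = sd.pop("bias", None)
quant_state_sd = {k: v for k, v in sd.items() if k.startswith("weight.")}

restored = bnb.nn.LinearNF4(64, 64, bias=bias is not None)
restored.weight = bnb.nn.Params4bit.from_prequantized(
    data=weight, quantized_stats=quant_state_sd, device=weight.device
)
if bias is not None:
    restored.bias = torch.nn.Parameter(bias, requires_grad=False)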

File 3 of 3: the FLUX transformer NF4 quantize-and-save script (`load_flux_transformer`)

@@ -4,7 +4,7 @@ from pathlib import Path
 import accelerate
 import torch
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from safetensors.torch import load_file
+from safetensors.torch import load_file, save_file
 
 from invokeai.backend.bnb import quantize_model_nf4
@@ -62,6 +62,9 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
     # ---------------------
 
+    with accelerate.init_empty_weights():
+        model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+
     # Load sharded state dict.
     files = list(path.glob("*.safetensors"))
     state_dict = dict()
@@ -69,8 +72,9 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
         sd = load_file(file)
         state_dict.update(sd)
 
-    empty_model.load_state_dict(state_dict, strict=True, assign=True)
-    model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+    # model.to_empty(device="cpu")
+    # model.to(dtype=torch.float16)
+    model.load_state_dict(state_dict, strict=True, assign=True)
 
     # Load the state dict into the model. The bitsandbytes layers know how to load from both quantized and
     # non-quantized state dicts.
@@ -80,7 +84,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
     model = model.to("cuda")
 
     model_nf4_path.mkdir(parents=True, exist_ok=True)
-    # save_file(model.state_dict(), model_nf4_path / "model.safetensors")
+    save_file(model.state_dict(), model_nf4_path / "model.safetensors")
 
     # ---------------------
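
Taken together, the hunks above turn this script into a quantize-and-save flow that looks roughly like the sketch below. The helper name `quantize_and_save_flux_transformer` is illustrative (the script's actual function is `load_flux_transformer`, per the hunk headers), and the unchanged parts of the file are omitted, so the real structure may differ in detail.

from pathlib import Path

import accelerate
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from safetensors.torch import load_file, save_file

from invokeai.backend.bnb import quantize_model_nf4


def quantize_and_save_flux_transformer(path: Path) -> FluxTransformer2DModel:
    # 1. Build the FLUX transformer on the "meta" device, then swap in NF4 linear layers.
    config = FluxTransformer2DModel.load_config(path, local_files_only=True)
    with accelerate.init_empty_weights():
        empty_model = FluxTransformer2DModel.from_config(config)
        model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)

    # 2. Load the original (sharded, non-quantized) checkpoint. assign=True lets the
    #    InvokeLinearNF4 layers take ownership of the tensors despite the meta init.
    state_dict = {}
    for file in path.glob("*.safetensors"):
        state_dict.update(load_file(file))
    model.load_state_dict(state_dict, strict=True, assign=True)

    # 3. Moving to CUDA is what actually quantizes the Params4bit weights.
    model = model.to("cuda")

    # 4. Persist the quantized state dict so inference can load it directly.
    model_nf4_path = path / "bnb_nf4"
    model_nf4_path.mkdir(parents=True, exist_ok=True)
    save_file(model.state_dict(), model_nf4_path / "model.safetensors")
    return model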