mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
NF4 inference working
This commit is contained in:
parent
96b0450b20
commit
1b80832b22
@ -1,12 +1,13 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||
from optimum.quanto import qfloat8
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from transformers.models.auto import AutoModelForTextEncoding
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
@ -20,6 +21,7 @@ from invokeai.app.invocations.fields import (
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.bnb import quantize_model_nf4
|
||||
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
@ -107,8 +109,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
t5_embeddings = t5_embeddings.to(dtype=transformer.dtype)
|
||||
clip_embeddings = clip_embeddings.to(dtype=transformer.dtype)
|
||||
dtype = torch.bfloat16
|
||||
t5_embeddings = t5_embeddings.to(dtype=dtype)
|
||||
clip_embeddings = clip_embeddings.to(dtype=dtype)
|
||||
|
||||
latents = flux_pipeline_with_transformer(
|
||||
height=self.height,
|
||||
@ -160,32 +163,49 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
||||
if self.use_8bit:
|
||||
model_8bit_path = path / "quantized"
|
||||
if model_8bit_path.exists():
|
||||
# The quantized model exists, load it.
|
||||
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
|
||||
# something that we should be able to make much faster.
|
||||
q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
|
||||
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
|
||||
# Access the underlying wrapped model.
|
||||
# We access the wrapped model, even though it is private, because it simplifies the type checking by
|
||||
# always returning a FluxTransformer2DModel from this function.
|
||||
model = q_model._wrapped
|
||||
else:
|
||||
# The quantized model does not exist yet, quantize and save it.
|
||||
# TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
|
||||
# GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
|
||||
# here.
|
||||
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
model_nf4_path = path / "bnb_nf4"
|
||||
if model_nf4_path.exists():
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
|
||||
q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
|
||||
# model.to_empty(device="cpu")
|
||||
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
||||
# this on GPUs without bfloat16 support.
|
||||
sd = load_file(model_nf4_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True, assign=True)
|
||||
# model = model.to("cuda")
|
||||
|
||||
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||
q_model.save_pretrained(model_8bit_path)
|
||||
# model_8bit_path = path / "quantized"
|
||||
# if model_8bit_path.exists():
|
||||
# # The quantized model exists, load it.
|
||||
# # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
|
||||
# # something that we should be able to make much faster.
|
||||
# q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
|
||||
|
||||
# (See earlier comment about accessing the wrapped model.)
|
||||
model = q_model._wrapped
|
||||
# # Access the underlying wrapped model.
|
||||
# # We access the wrapped model, even though it is private, because it simplifies the type checking by
|
||||
# # always returning a FluxTransformer2DModel from this function.
|
||||
# model = q_model._wrapped
|
||||
# else:
|
||||
# # The quantized model does not exist yet, quantize and save it.
|
||||
# # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
|
||||
# # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
|
||||
# # here.
|
||||
# model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||
# assert isinstance(model, FluxTransformer2DModel)
|
||||
|
||||
# q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
|
||||
|
||||
# model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||
# q_model.save_pretrained(model_8bit_path)
|
||||
|
||||
# # (See earlier comment about accessing the wrapped model.)
|
||||
# model = q_model._wrapped
|
||||
else:
|
||||
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||
|
||||
|
@ -51,7 +51,7 @@ import torch
|
||||
# self.SCB = SCB
|
||||
|
||||
|
||||
class InvokeLinear4Bit(bnb.nn.Linear4bit):
|
||||
class InvokeLinearNF4(bnb.nn.LinearNF4):
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
@ -60,31 +60,36 @@ class InvokeLinear4Bit(bnb.nn.Linear4bit):
|
||||
|
||||
I'm not sure why this was not included in the original `Linear4bit` implementation.
|
||||
"""
|
||||
# During serialization, the quant_state is stored as subkeys of "weight.". Here we extract those keys.
|
||||
quant_state_keys = [k for k in state_dict.keys() if k.startswith(prefix + "weight.")]
|
||||
|
||||
if len(quant_state_keys) > 0:
|
||||
# We are loading a quantized state dict.
|
||||
quant_state_sd = {k: state_dict.pop(k) for k in quant_state_keys}
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
# During serialization, the quant_state is stored as subkeys of "weight.".
|
||||
# We expect the remaining keys to be quant_state keys. We validate that they at least have the correct prefix.
|
||||
quant_state_sd = state_dict
|
||||
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
|
||||
|
||||
if len(state_dict) != 0:
|
||||
raise RuntimeError(f"Unexpected keys in state_dict: {state_dict.keys()}")
|
||||
|
||||
if len(quant_state_sd) > 0:
|
||||
# We are loading a quantized state dict.
|
||||
self.weight = bnb.nn.Params4bit.from_prequantized(
|
||||
data=weight, quantized_stats=quant_state_sd, device=weight.device
|
||||
)
|
||||
if bias is None:
|
||||
self.bias = None
|
||||
else:
|
||||
self.bias = torch.nn.Parameter(bias)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
|
||||
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
return super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict` method here, but then we wouldn't be able to load
|
||||
# into from a state_dict into a model on the "meta" device. By initializing a new `Params4bit` object, we
|
||||
# work around this issue.
|
||||
self.weight = bnb.nn.Params4bit(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
compress_statistics=self.weight.compress_statistics,
|
||||
quant_type=self.weight.quant_type,
|
||||
quant_storage=self.weight.quant_storage,
|
||||
module=self,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
|
||||
class Invoke2Linear8bitLt(torch.nn.Linear):
|
||||
@ -545,7 +550,7 @@ def _convert_linear_layers_to_nf4(
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinear4Bit(
|
||||
replacement = InvokeLinearNF4(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
@ -553,9 +558,14 @@ def _convert_linear_layers_to_nf4(
|
||||
# TODO(ryand): Test compress_statistics=True.
|
||||
# compress_statistics=True,
|
||||
)
|
||||
replacement.weight.data = child.weight.data
|
||||
# replacement.weight.data = child.weight.data
|
||||
# if has_bias:
|
||||
# replacement.bias.data = child.bias.data
|
||||
if has_bias:
|
||||
replacement.bias.data = child.bias.data
|
||||
replacement.bias = _replace_param(replacement.bias, child.bias.data)
|
||||
replacement.weight = _replace_param(
|
||||
replacement.weight, child.weight.data, quant_state=replacement.weight.quant_state
|
||||
)
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
|
@ -4,7 +4,7 @@ from pathlib import Path
|
||||
import accelerate
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from safetensors.torch import load_file
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.bnb import quantize_model_nf4
|
||||
|
||||
@ -62,6 +62,9 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
|
||||
# ---------------------
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
|
||||
# Load sharded state dict.
|
||||
files = list(path.glob("*.safetensors"))
|
||||
state_dict = dict()
|
||||
@ -69,8 +72,9 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
|
||||
empty_model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
# model.to_empty(device="cpu")
|
||||
# model.to(dtype=torch.float16)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
# Load the state dict into the model. The bitsandbytes layers know how to load from both quantized and
|
||||
# non-quantized state dicts.
|
||||
@ -80,7 +84,7 @@ def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
||||
model = model.to("cuda")
|
||||
|
||||
model_nf4_path.mkdir(parents=True, exist_ok=True)
|
||||
# save_file(model.state_dict(), model_nf4_path / "model.safetensors")
|
||||
save_file(model.state_dict(), model_nf4_path / "model.safetensors")
|
||||
|
||||
# ---------------------
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user