Add CustomInvokeLinearNF4 to enable CPU -> GPU streaming for InvokeLinearNF4 layers.
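The new layer lets a quantized NF4 linear keep its weights on the CPU and have forward() cast them to the input's device on the fly, rather than requiring the whole module to be moved to the GPU up front. A minimal usage sketch, modeled on the tests added below (it assumes a CUDA device and a working bitsandbytes install; the class swap, assigning CustomInvokeLinearNF4 to the layer's __class__, is the same mechanism the tests use):

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
    CustomInvokeLinearNF4,
)
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4

# Quantize a small linear layer on the GPU, then offload its weights back to the CPU.
layer = InvokeLinearNF4(input_features=32, output_features=64)
layer.load_state_dict(torch.nn.Linear(32, 64).state_dict())
layer.to("cuda")
layer.load_state_dict({k: v.to("cpu") for k, v in layer.state_dict().items()})

# Swap in the custom class; forward() now streams the weights to the input's device.
layer.__class__ = CustomInvokeLinearNF4
y = layer(torch.randn(10, 32, device="cuda"))  # weights stay on the CPU between calls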
@@ -1,9 +1,11 @@
import copy
from typing import TypeVar

import bitsandbytes as bnb
import torch

from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
@@ -84,3 +86,37 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
        # its dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
        # on the wrong device.
        return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)


class CustomInvokeLinearNF4(InvokeLinearNF4):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)

        # HACK(ryand): Casting self.weight to the device also casts the self.weight.quant_state in-place (i.e. it
        # does not follow the tensor semantics of returning a new copy when converting to a different device). This
        # means that quant_state elements that started on the CPU would be left on the GPU, which we don't want. To
        # avoid this side effect we make a shallow copy of the original quant_state so that we can restore it. Fixing
        # this properly would require more invasive changes to the bitsandbytes library.

        # Make a shallow copy of the quant_state so that we can undo the in-place modification that occurs when casting
        # to a new device.
        old_quant_state = copy.copy(self.weight.quant_state)
        weight = cast_to_device(self.weight, x.device)
        self.weight.quant_state = old_quant_state

        bias = cast_to_device(self.bias, x.device)
        return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
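Both forward() implementations above lean on cast_to_device, which lives elsewhere in this module and is not shown in the diff. Judging from the T = TypeVar(...) declaration in the first hunk and the call sites, it is presumably a small None-tolerant helper along these lines (a sketch of the assumed behavior, not the actual implementation):

def cast_to_device(t: T, to_device: torch.device) -> T:
    # None passes through unchanged (e.g. a layer without a bias); tensors are
    # returned as a copy on the requested device.
    if t is None:
        return None
    return t.to(to_device)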
@@ -3,8 +3,10 @@ import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
    CustomInvokeLinear8bitLt,
    CustomInvokeLinearNF4,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4


@pytest.fixture
@@ -65,3 +67,70 @@ def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
    # Assert that the quantized and custom layers produce the same output.
    assert torch.allclose(y_quantized, y_custom, atol=1e-5)


@pytest.fixture
def linear_nf4_layer():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    torch.manual_seed(1)

    orig_layer = torch.nn.Linear(32, 64)
    orig_layer_state_dict = orig_layer.state_dict()

    # Prepare a quantized InvokeLinearNF4 layer.
    quantized_layer = InvokeLinearNF4(input_features=32, output_features=64)
    quantized_layer.load_state_dict(orig_layer_state_dict)
    quantized_layer.to("cuda")

    # Assert that the InvokeLinearNF4 layer is quantized.
    assert quantized_layer.weight.bnb_quantized

    return quantized_layer


def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4):
    """Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
    # Run inference on the original layer.
    x = torch.randn(10, 32).to("cuda")
    y_quantized = linear_nf4_layer(x)

    # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
    linear_nf4_layer.__class__ = CustomInvokeLinearNF4
    y_custom = linear_nf4_layer(x)

    # Assert that the quantized and custom layers produce the same output.
    assert torch.allclose(y_quantized, y_custom, atol=1e-5)


def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4):
    """Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
    # Run inference on the original layer.
    x = torch.randn(10, 32).to(device="cuda")
    y_quantized = linear_nf4_layer(x)

    # Copy the state dict to the CPU and reload it.
    state_dict = linear_nf4_layer.state_dict()
    state_dict = {k: v.to("cpu") for k, v in state_dict.items()}
    linear_nf4_layer.load_state_dict(state_dict)

    # Inference of the original layer should fail.
    with pytest.raises(RuntimeError):
        linear_nf4_layer(x)

    # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
    linear_nf4_layer.__class__ = CustomInvokeLinearNF4
    y_custom = linear_nf4_layer(x)

    # Assert that the state dict (and the tensors that it references) are still on the CPU.
    assert all(v.device == torch.device("cpu") for v in state_dict.values())

    # Assert that the weight, bias, and quant_state are all on the CPU.
    assert linear_nf4_layer.weight.device == torch.device("cpu")
    assert linear_nf4_layer.bias.device == torch.device("cpu")
    assert linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu")
    assert linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu")

    # Assert that the quantized and custom layers produce the same output.
    assert torch.allclose(y_quantized, y_custom, atol=1e-5)
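The new tests follow the existing 8-bit tests' pattern and skip automatically when CUDA is unavailable. Since the test file's path is not shown in this diff, keyword selection is the simplest way to target them, e.g. pytest -k custom_invoke_linear_nf4.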