Add torch module autocast unit test for GGUF-quantized models.

commit 97d56f7dc9 (parent fe0ef2c27c)
Author: Ryan Dick
Date: 2024-12-21 15:22:06 +00:00


@@ -1,3 +1,4 @@
+import gguf
 import pytest
 import torch
@@ -5,9 +6,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
     apply_custom_layers_to_model,
     remove_custom_layers_from_model,
 )
-from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
+from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor
 
-mps_and_cuda = pytest.mark.parametrize(
+cuda_and_mps = pytest.mark.parametrize(
     "device",
     [
         pytest.param(
@@ -21,14 +22,36 @@ mps_and_cuda = pytest.mark.parametrize(
 )
 
 
-@mps_and_cuda
-def test_torch_module_autocast(device: torch.device):
-    model = DummyModule()
+class ModelWithLinearLayer(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(32, 64)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear(x)
+
+
+@pytest.fixture(params=["none", "gguf"])
+def model(request: pytest.FixtureRequest) -> torch.nn.Module:
+    if request.param == "none":
+        return ModelWithLinearLayer()
+    elif request.param == "gguf":
+        # Initialize ModelWithLinearLayer and replace the linear layer weight with a GGML quantized weight.
+        model = ModelWithLinearLayer()
+        ggml_quantized_weight = quantize_tensor(model.linear.weight, gguf.GGMLQuantizationType.Q8_0)
+        model.linear.weight = torch.nn.Parameter(ggml_quantized_weight)
+        return model
+    else:
+        raise ValueError(f"Invalid quantization type: {request.param}")
+
+
+@cuda_and_mps
+def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.nn.Module):
     # Model parameters should start off on the CPU.
     assert all(p.device.type == "cpu" for p in model.parameters())
 
     # Run inference on the CPU.
-    x = torch.randn(10, 10, device="cpu")
+    x = torch.randn(10, 32, device="cpu")
     expected = model(x)
     assert expected.device.type == "cpu"
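
Note that the device parametrization (`cuda_and_mps`) and the parametrized `model` fixture compose: pytest collects the test once per (device, model) combination, so it runs on both CUDA/MPS and for both the unquantized and GGUF-quantized models. A minimal standalone sketch of this cross-product behavior (the names below are illustrative, not from the diff):

import pytest

@pytest.fixture(params=["none", "gguf"])
def model_kind(request: pytest.FixtureRequest) -> str:
    # Every test using this fixture runs once per param value.
    return request.param

@pytest.mark.parametrize("device_name", ["cuda", "mps"])
def test_cross_product(device_name: str, model_kind: str) -> None:
    # Collected four times, e.g. test_cross_product[cuda-none],
    # test_cross_product[mps-gguf], etc. (exact ID ordering may vary).
    assert device_name in {"cuda", "mps"} and model_kind in {"none", "gguf"}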
@@ -56,5 +79,5 @@ def test_torch_module_autocast(device: torch.device):
     assert after_result.device.type == "cpu"
 
     # The results from all inference runs should be the same.
-    assert torch.allclose(autocast_result.to("cpu"), expected)
-    assert torch.allclose(after_result, expected)
+    assert torch.allclose(autocast_result.to("cpu"), expected, atol=1e-5)
+    assert torch.allclose(after_result, expected, atol=1e-5)
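
For context on what `quantize_tensor` produces: it is a helper in the test suite (tests/backend/quantization/gguf/test_ggml_tensor.py, not shown in this diff) and presumably builds on the quantization primitives in the gguf package. A rough, illustrative sketch of those primitives under that assumption:

import gguf
import torch
from gguf.quants import dequantize, quantize

# Shape matches torch.nn.Linear(32, 64).weight; Q8_0 requires float32 input
# whose last dim is a multiple of the block size (32).
weight = torch.randn(64, 32)

# Quantize to Q8_0: blocks of 32 int8 values with one fp16 scale per block.
q_bytes = quantize(weight.numpy(), gguf.GGMLQuantizationType.Q8_0)

# Dequantize back to float32. Q8_0 is lossy, so this differs slightly from
# `weight`; the test's tight atol=1e-5 works because all three inference runs
# share the same quantized weight, not because the round trip is exact.
restored = torch.from_numpy(dequantize(q_bytes, gguf.GGMLQuantizationType.Q8_0))
assert restored.shape == weight.shape
assert torch.allclose(restored, weight, atol=0.1)  # coarse tolerance for Q8_0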