Add torch module autocast utilities.

Ryan Dick
2024-12-21 14:40:27 +00:00
parent 65fcbf5f60
commit fe0ef2c27c
4 changed files with 161 additions and 0 deletions

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py

@@ -0,0 +1,61 @@
from typing import TypeVar

import torch

# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that it is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OriginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)


def cast_to_device(t: T, to_device: torch.device) -> T:
    if t is None:
        return t

    # Only the device *type* is compared, so a tensor already on some device of the target type
    # (e.g. on "cuda:1" for a "cuda:0" target) is not moved.
    if t.device.type != to_device.type:
        return t.to(to_device)
    return t
class CustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)


class CustomConv1d(torch.nn.Conv1d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomConv2d(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomGroupNorm(torch.nn.GroupNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


class CustomEmbedding(torch.nn.Embedding):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        return torch.nn.functional.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
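
To illustrate the pattern, a minimal usage sketch (not part of this commit, and assuming a CUDA device is available): swapping a module's __class__ to its custom counterpart, which is the same in-place swap the helpers in the next file perform, lets CPU-resident weights serve a CUDA input.

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear

linear = torch.nn.Linear(8, 4)   # parameters allocated on the CPU
linear.__class__ = CustomLinear  # in-place class swap, as apply_custom_layers_to_model does below

if torch.cuda.is_available():
    x = torch.randn(2, 8, device="cuda")
    y = linear(x)                              # weight/bias are cast to CUDA for this call only
    assert y.device.type == "cuda"             # output lands on the input's device
    assert linear.weight.device.type == "cpu"  # the stored parameters never move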

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py

@@ -0,0 +1,40 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
    CustomConv1d,
    CustomConv2d,
    CustomEmbedding,
    CustomGroupNorm,
    CustomLinear,
)

AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
    torch.nn.Linear: CustomLinear,
    torch.nn.Conv1d: CustomConv1d,
    torch.nn.Conv2d: CustomConv2d,
    torch.nn.GroupNorm: CustomGroupNorm,
    torch.nn.Embedding: CustomEmbedding,
}


def apply_custom_layers_to_model(model: torch.nn.Module):
    def apply_custom_layers(module: torch.nn.Module):
        override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
        if override_type is not None:
            module.__class__ = override_type

    # model.apply(...) calls apply_custom_layers(...) on each module in the model.
    model.apply(apply_custom_layers)


def remove_custom_layers_from_model(model: torch.nn.Module):
    # Invert AUTOCAST_MODULE_TYPE_MAPPING.
    original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}

    def remove_custom_layers(module: torch.nn.Module):
        override_type = original_module_type_mapping.get(type(module), None)
        if override_type is not None:
            module.__class__ = override_type

    # model.apply(...) calls remove_custom_layers(...) on each module in the model.
    model.apply(remove_custom_layers)
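
A quick round-trip sketch (again not part of the diff; the Sequential model is a stand-in for any model whose layer types appear in AUTOCAST_MODULE_TYPE_MAPPING) showing the in-place class swap and its inverse:

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
    remove_custom_layers_from_model,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))

apply_custom_layers_to_model(model)
assert type(model[0]) is CustomLinear     # classes swapped in place; parameters untouched

remove_custom_layers_from_model(model)
assert type(model[0]) is torch.nn.Linear  # original classes restored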

tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py

@@ -0,0 +1,60 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
    remove_custom_layers_from_model,
)
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule

mps_and_cuda = pytest.mark.parametrize(
    "device",
    [
        pytest.param(
            torch.device("cuda"), marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
        ),
        pytest.param(
            torch.device("mps"),
            marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS device"),
        ),
    ],
)


@mps_and_cuda
def test_torch_module_autocast(device: torch.device):
    model = DummyModule()
    # Model parameters should start off on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    # Run inference on the CPU.
    x = torch.randn(10, 10, device="cpu")
    expected = model(x)
    assert expected.device.type == "cpu"

    # Apply the custom layers to the model.
    apply_custom_layers_to_model(model)

    # Run the model on the device.
    autocast_result = model(x.to(device))
    # The model output should be on the device.
    assert autocast_result.device.type == device.type
    # The model parameters should still be on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    # Remove the custom layers from the model.
    remove_custom_layers_from_model(model)

    # After removing the custom layers, the model should no longer be able to run inference on the device.
    with pytest.raises(RuntimeError):
        _ = model(x.to(device))

    # Run inference again on the CPU.
    after_result = model(x)
    assert after_result.device.type == "cpu"

    # The results from all inference runs should be the same.
    assert torch.allclose(autocast_result.to("cpu"), expected)
    assert torch.allclose(after_result, expected)