https://github.com/invoke-ai/InvokeAI
Add torch module autocast utilities.
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py
@@ -0,0 +1,61 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)

# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that it is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OriginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.


def cast_to_device(t: T, to_device: torch.device) -> T:
    if t is None:
        return t

    if t.device.type != to_device.type:
        return t.to(to_device)
    return t


class CustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)


class CustomConv1d(torch.nn.Conv1d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomConv2d(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomGroupNorm(torch.nn.GroupNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


class CustomEmbedding(torch.nn.Embedding):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        return torch.nn.functional.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
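A quick sketch of what one of these wrappers does in isolation (not part of the diff; it assumes a CUDA device is available, and the in-place class swap shown here is the same mechanism apply_custom_layers_to_model uses for every mapped module):

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear

linear = torch.nn.Linear(8, 4)  # parameters are created on the CPU
linear.__class__ = CustomLinear  # in-place class swap; parameters and hooks are untouched
assert isinstance(linear, torch.nn.Linear)  # isinstance checks against the original type still pass

if torch.cuda.is_available():
    x = torch.randn(2, 8, device="cuda")
    y = linear(x)  # weight and bias are cast to the input's device for this call only
    assert y.device.type == "cuda"
    assert linear.weight.device.type == "cpu"  # the stored parameters never move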
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
@@ -0,0 +1,40 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
    CustomConv1d,
    CustomConv2d,
    CustomEmbedding,
    CustomGroupNorm,
    CustomLinear,
)

AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
    torch.nn.Linear: CustomLinear,
    torch.nn.Conv1d: CustomConv1d,
    torch.nn.Conv2d: CustomConv2d,
    torch.nn.GroupNorm: CustomGroupNorm,
    torch.nn.Embedding: CustomEmbedding,
}


def apply_custom_layers_to_model(model: torch.nn.Module):
    def apply_custom_layers(module: torch.nn.Module):
        override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None)
        if override_type is not None:
            module.__class__ = override_type

    # model.apply(...) calls apply_custom_layers(...) on each module in the model.
    model.apply(apply_custom_layers)


def remove_custom_layers_from_model(model: torch.nn.Module):
    # Invert AUTOCAST_MODULE_TYPE_MAPPING.
    original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}

    def remove_custom_layers(module: torch.nn.Module):
        override_type = original_module_type_mapping.get(type(module), None)
        if override_type is not None:
            module.__class__ = override_type

    # model.apply(...) calls remove_custom_layers(...) on each module in the model.
    model.apply(remove_custom_layers)
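A minimal round-trip sketch of the two helpers (not part of the diff; the Sequential model here is a stand-in, and any module whose type appears in AUTOCAST_MODULE_TYPE_MAPPING is swapped the same way):

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
    remove_custom_layers_from_model,
)

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.GroupNorm(4, 32))

apply_custom_layers_to_model(model)
print(type(model[0]).__name__, type(model[1]).__name__)  # CustomLinear CustomGroupNorm

remove_custom_layers_from_model(model)
print(type(model[0]).__name__, type(model[1]).__name__)  # Linear GroupNorm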
@@ -0,0 +1,60 @@
import pytest
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
    remove_custom_layers_from_model,
)
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule

mps_and_cuda = pytest.mark.parametrize(
    "device",
    [
        pytest.param(
            torch.device("cuda"), marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
        ),
        pytest.param(
            torch.device("mps"),
            marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS device"),
        ),
    ],
)


@mps_and_cuda
def test_torch_module_autocast(device: torch.device):
    model = DummyModule()
    # Model parameters should start off on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    # Run inference on the CPU.
    x = torch.randn(10, 10, device="cpu")
    expected = model(x)
    assert expected.device.type == "cpu"

    # Apply the custom layers to the model.
    apply_custom_layers_to_model(model)

    # Run the model on the device.
    autocast_result = model(x.to(device))

    # The model output should be on the device.
    assert autocast_result.device.type == device.type
    # The model parameters should still be on the CPU.
    assert all(p.device.type == "cpu" for p in model.parameters())

    # Remove the custom layers from the model.
    remove_custom_layers_from_model(model)

    # After removing the custom layers, the model should no longer be able to run inference on the device.
    with pytest.raises(RuntimeError):
        _ = model(x.to(device))

    # Run inference again on the CPU.
    after_result = model(x)

    assert after_result.device.type == "cpu"

    # The results from all inference runs should be the same.
    assert torch.allclose(autocast_result.to("cpu"), expected)
    assert torch.allclose(after_result, expected)
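The DummyModule fixture imported above is defined elsewhere in the test suite and is not part of this diff. A hypothetical stand-in consistent with how the test uses it (10-dimensional input and output, parameters initialized on the CPU) might look like:

import torch


class DummyModule(torch.nn.Module):
    # Hypothetical stand-in; the real fixture lives at tests/backend/model_manager/load/model_cache/dummy_module.py.
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.linear1(x))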