First pass at making custom layer patches work with weights streamed from the CPU to the GPU.

Ryan Dick
2024-12-29 06:51:30 +00:00
parent 6d49ee839c
commit a8bef59699
6 changed files with 92 additions and 45 deletions
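All six files apply the same pattern: layer weights stay on the CPU and are cast to the input's device at forward time, and sidecar patches are now streamed the same way (shallow-copied and cast per forward pass, so the CPU originals are never modified). Below is a minimal, illustrative sketch of that pattern; aside from cast_to_device, the names are hypothetical and not the actual InvokeAI classes.

# Illustrative sketch (not the actual InvokeAI code) of the streaming pattern in this
# commit: weights live on the CPU and are cast to the input's device on each forward
# pass; patches are shallow-copied before being moved so the CPU originals stay put.
import copy

import torch


def cast_to_device(t: torch.Tensor | None, device: torch.device) -> torch.Tensor | None:
    # Assumed behaviour of the real cast_to_device helper: no-op when already on device.
    if t is None:
        return None
    return t if t.device == device else t.to(device)


class StreamedLinear(torch.nn.Linear):
    """Hypothetical stand-in for a custom layer that holds sidecar patches."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._patches_and_weights = []  # (patch, weight) pairs, kept on the CPU

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        for patch, patch_weight in self._patches_and_weights:
            patch = copy.copy(patch)  # shallow copy so the CPU original is untouched
            patch.to(input.device)    # cast only the copy to the input's device
            # ...add the patch's weight/bias residuals, scaled by patch_weight...
        return torch.nn.functional.linear(input, weight, bias)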

View File

@@ -4,17 +4,26 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
-from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
-    add_nullable_tensors,
-)
 
 
 class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
-        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
-        weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
-        bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
-        return self._conv_forward(input, weight, bias)
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+
+        # Prepare the original parameters for the patch aggregation.
+        orig_params = {"weight": weight, "bias": bias}
+        # Filter out None values.
+        orig_params = {k: v for k, v in orig_params.items() if v is not None}
+
+        aggregated_param_residuals = self._aggregate_patch_parameters(
+            patches_and_weights=self._patches_and_weights,
+            orig_params=orig_params,
+            device=input.device,
+        )
+        return self._conv_forward(
+            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
+        )
 
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)

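For context, the add_nullable_tensors helper whose import is removed above is presumably element-wise addition that tolerates None on either side. The following is only a guess at its shape, based on its name and call sites; the real implementation is not shown in this diff.

import torch


def add_nullable_tensors(a: torch.Tensor | None, b: torch.Tensor | None) -> torch.Tensor | None:
    # Assumed behaviour (not the real InvokeAI helper): add two tensors, treating None
    # as "no contribution".
    if a is None:
        return b
    if b is None:
        return a
    return a + b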
View File

@@ -4,17 +4,26 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
-from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
-    add_nullable_tensors,
-)
 
 
 class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
-        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
-        weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
-        bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
-        return self._conv_forward(input, weight, bias)
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+
+        # Prepare the original parameters for the patch aggregation.
+        orig_params = {"weight": weight, "bias": bias}
+        # Filter out None values.
+        orig_params = {k: v for k, v in orig_params.items() if v is not None}
+
+        aggregated_param_residuals = self._aggregate_patch_parameters(
+            patches_and_weights=self._patches_and_weights,
+            orig_params=orig_params,
+            device=input.device,
+        )
+        return self._conv_forward(
+            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
+        )
 
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)

View File

@@ -16,10 +16,12 @@ class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin):
         assert isinstance(patch, SetParameterLayer)
         assert patch.param_name == "scale"
 
+        scale = cast_to_device(patch.weight, x.device)
+
         # Apply the patch.
         # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
         # be handled.
-        return torch.nn.functional.rms_norm(x, patch.weight.shape, patch.weight, eps=1e-6)
+        return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)
 
     def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         scale = cast_to_device(self.scale, x.device)

View File

@@ -1,3 +1,5 @@
+import copy
+
 import torch
 
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
@@ -55,6 +57,10 @@ def autocast_linear_forward_sidecar_patches(
     # Then, apply layers for which we have optimized implementations.
     unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
     for patch, patch_weight in patches_and_weights:
+        # Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
+        patch = copy.copy(patch)
+        patch.to(input.device)
+
         if isinstance(patch, FluxControlLoRALayer):
             # Note that we use the original input here, not the sliced input.
             output += linear_lora_forward(orig_input, patch, patch_weight)
@@ -67,7 +73,14 @@ def autocast_linear_forward_sidecar_patches(
     # Finally, apply any remaining patches.
     if len(unprocessed_patches_and_weights) > 0:
-        aggregated_param_residuals = orig_module._aggregate_patch_parameters(unprocessed_patches_and_weights)
+        # Prepare the original parameters for the patch aggregation.
+        orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
+        # Filter out None values.
+        orig_params = {k: v for k, v in orig_params.items() if v is not None}
+
+        aggregated_param_residuals = orig_module._aggregate_patch_parameters(
+            unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
+        )
         output += torch.nn.functional.linear(
             input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
         )

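The shallow-copy-then-to() step above relies on the patch's to() rebinding its tensor attributes rather than mutating tensors in place, so the CPU-resident original keeps its own references. A toy illustration of that assumption follows; ToyPatch is hypothetical and not the real BaseLayerPatch API.

# Toy illustration (not InvokeAI code): why a shallow copy is enough when to() rebinds
# tensor attributes instead of mutating them in place.
import copy

import torch


class ToyPatch:
    def __init__(self, delta: torch.Tensor):
        self.delta = delta

    def to(self, device: torch.device) -> None:
        # Rebinds the attribute; the tensor held by the original patch is unchanged.
        self.delta = self.delta.to(device)


cpu_patch = ToyPatch(torch.zeros(4))
gpu_copy = copy.copy(cpu_patch)
gpu_copy.to(torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
assert cpu_patch.delta.device.type == "cpu"  # the original stays on the CPU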
View File

@@ -1,3 +1,5 @@
+import copy
+
 import torch
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
@@ -34,15 +36,23 @@ class CustomModuleMixin:
         return len(self._patches_and_weights)
 
     def _aggregate_patch_parameters(
-        self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
-    ) -> dict[str, torch.Tensor]:
+        self,
+        patches_and_weights: list[tuple[BaseLayerPatch, float]],
+        orig_params: dict[str, torch.Tensor],
+        device: torch.device | None = None,
+    ):
         """Helper function that aggregates the parameters from all patches into a single dict."""
         params: dict[str, torch.Tensor] = {}
 
         for patch, patch_weight in patches_and_weights:
+            if device is not None:
+                # Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
+                patch = copy.copy(patch)
+                patch.to(device)
+
             # TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original
             # parameters, this might fail or return incorrect results.
-            layer_params = patch.get_parameters(dict(self.named_parameters(recurse=False)), weight=patch_weight)  # type: ignore
+            layer_params = patch.get_parameters(orig_params, weight=patch_weight)
 
             for param_name, param_weight in layer_params.items():
                 if param_name not in params:

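The hunk above is cut off before the aggregation loop. For reference, here is a minimal standalone sketch of what aggregating parameter residuals into a single dict amounts to, under the assumption that residuals for the same parameter name are simply summed; the real method body is not fully shown here.

# Minimal standalone sketch of parameter-residual aggregation (assumption: residuals
# for the same parameter name are summed; this is not the exact InvokeAI code).
import torch


def aggregate(residual_dicts: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    params: dict[str, torch.Tensor] = {}
    for layer_params in residual_dicts:
        for param_name, param_weight in layer_params.items():
            if param_name not in params:
                params[param_name] = param_weight
            else:
                params[param_name] = params[param_name] + param_weight
    return params


# Example: two patches both contribute a "weight" residual; the contributions are summed.
print(aggregate([{"weight": torch.ones(2, 2)}, {"weight": torch.full((2, 2), 0.5)}]))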
View File

@@ -389,40 +389,44 @@ def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest):
 
 
 @parameterize_cuda_and_mps
-# def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest):
-#     patches, input = patch_under_test
+def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest):
+    """Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and
+    when the layer is on the CPU and the patches are autocasted to the device.
+    """
+    patches, input = patch_under_test
 
-#     # Build the base layer under test.
-#     layer = torch.nn.Linear(32, 64)
+    # Build the base layer under test.
+    layer = torch.nn.Linear(32, 64)
 
-#     # Move the layer and input to the device.
-#     layer_to_device_via_state_dict(layer, device)
-#     input = input.to(torch.device(device))
+    # Move the layer and input to the device.
+    layer_to_device_via_state_dict(layer, device)
+    input = input.to(torch.device(device))
 
-#     # Wrap the original layer in a custom layer and add the patch to it.
-#     custom_layer = wrap_single_custom_layer(layer)
-#     for patch, weight in patches:
-#         patch.to(torch.device(device))
-#         custom_layer.add_patch(patch, weight)
+    # Wrap the original layer in a custom layer and add the patch to it.
+    custom_layer = wrap_single_custom_layer(layer)
+    for patch, weight in patches:
+        patch.to(torch.device(device))
+        custom_layer.add_patch(patch, weight)
 
-#     # Run inference with the custom layer on the device.
-#     expected_output = custom_layer(input)
+    # Run inference with the custom layer on the device.
+    expected_output = custom_layer(input)
 
-#     # Move the custom layer to the CPU.
-#     layer_to_device_via_state_dict(custom_layer, "cpu")
+    # Move the custom layer to the CPU.
+    layer_to_device_via_state_dict(custom_layer, "cpu")
 
-#     # Move the patches to the CPU.
-#     custom_layer.clear_patches()
-#     for patch, weight in patches:
-#         patch.to(torch.device("cpu"))
-#         custom_layer.add_patch(patch, weight)
+    # Move the patches to the CPU.
+    custom_layer.clear_patches()
+    for patch, weight in patches:
+        patch.to(torch.device("cpu"))
+        custom_layer.add_patch(patch, weight)
 
-#     # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to
-#     # the device.
-#     autocast_output = custom_layer(input)
-#     assert autocast_output.device.type == device
+    # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to
+    # the device.
+    autocast_output = custom_layer(input)
+    assert autocast_output.device.type == device
 
-#     assert torch.allclose(expected_output, autocast_output, atol=1e-6)
+    # Assert that the outputs with and without autocasting are the same.
+    assert torch.allclose(expected_output, autocast_output, atol=1e-6)
 
 
 @pytest.fixture(