Mirror of https://github.com/invoke-ai/InvokeAI, synced 2025-07-26 05:17:55 +00:00
First pass at making custom layer patches work with weights streamed from the CPU to the GPU.
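In the hunks below, the patched forward paths stop assuming that the module's parameters already live on the execution device: the original weight/bias are cast to the input's device at forward time and handed to the patch-aggregation helper, so sidecar patches keep working when the base weights are streamed from the CPU. The following minimal sketch illustrates just the streaming idea with standard PyTorch; `cast_to_device` here is a stand-in with assumed behaviour (a non-blocking `.to()`), and `StreamedLinear` is not an InvokeAI class.

import torch


def cast_to_device(t: torch.Tensor | None, device: torch.device) -> torch.Tensor | None:
    # Assumed behaviour of the real helper: pass None through, otherwise copy to the target device.
    return t if t is None else t.to(device, non_blocking=True)


class StreamedLinear(torch.nn.Linear):
    """Toy layer whose parameters stay on the CPU and are streamed to the input's device per call."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)


if __name__ == "__main__":
    layer = StreamedLinear(8, 4)  # parameters remain on the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(2, 8, device=device)
    print(layer(x).device)  # the output follows the input's device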
@@ -4,17 +4,26 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
-from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
-    add_nullable_tensors,
-)


 class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
-        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
-        weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
-        bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
-        return self._conv_forward(input, weight, bias)
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+
+        # Prepare the original parameters for the patch aggregation.
+        orig_params = {"weight": weight, "bias": bias}
+        # Filter out None values.
+        orig_params = {k: v for k, v in orig_params.items() if v is not None}
+
+        aggregated_param_residuals = self._aggregate_patch_parameters(
+            patches_and_weights=self._patches_and_weights,
+            orig_params=orig_params,
+            device=input.device,
+        )
+        return self._conv_forward(
+            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
+        )

     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
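A quick, self-contained check of the device-streaming part of the hunk above (only the streaming and the `orig_params` filtering are reproduced; the `_aggregate_patch_parameters` call is omitted because it depends on the project's patch objects): a Conv1d whose parameters stay on the CPU still produces an on-device output once its weight and bias are cast to `input.device`.

import torch
import torch.nn.functional as F

conv = torch.nn.Conv1d(4, 8, kernel_size=3)  # parameters remain on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(1, 4, 16, device=device)

# Stream the original parameters to the input's device.
weight = conv.weight.to(device)
bias = conv.bias.to(device) if conv.bias is not None else None

# Mirror the diff's orig_params preparation: drop parameters that don't exist (bias may be None).
orig_params = {k: v for k, v in {"weight": weight, "bias": bias}.items() if v is not None}

out = F.conv1d(x, orig_params["weight"], orig_params.get("bias"))
print(out.device, out.shape)  # device matches the input; shape (1, 8, 14)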
@@ -4,17 +4,26 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
-from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
-    add_nullable_tensors,
-)


 class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
-        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
-        weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
-        bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
-        return self._conv_forward(input, weight, bias)
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+
+        # Prepare the original parameters for the patch aggregation.
+        orig_params = {"weight": weight, "bias": bias}
+        # Filter out None values.
+        orig_params = {k: v for k, v in orig_params.items() if v is not None}
+
+        aggregated_param_residuals = self._aggregate_patch_parameters(
+            patches_and_weights=self._patches_and_weights,
+            orig_params=orig_params,
+            device=input.device,
+        )
+        return self._conv_forward(
+            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
+        )

     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
@@ -16,10 +16,12 @@ class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin):
         assert isinstance(patch, SetParameterLayer)
         assert patch.param_name == "scale"

+        scale = cast_to_device(patch.weight, x.device)
+
         # Apply the patch.
         # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
         # be handled.
-        return torch.nn.functional.rms_norm(x, patch.weight.shape, patch.weight, eps=1e-6)
+        return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)

     def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
         scale = cast_to_device(self.scale, x.device)
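For reference, a tiny runnable check of the `torch.nn.functional.rms_norm` call used above (illustrative values; `rms_norm` requires a recent PyTorch, 2.4 or later as far as I know). The point of the hunk is simply that the replacement scale supplied by the `SetParameterLayer` patch must live on the same device as `x`, hence the `cast_to_device` before the call.

import torch

x = torch.randn(2, 16)
scale = torch.ones(16)  # stand-in for the patch's replacement scale, already on x's device
y = torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)
print(y.shape)  # torch.Size([2, 16])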
@@ -1,3 +1,5 @@
+import copy
+
 import torch

 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
@@ -55,6 +57,10 @@ def autocast_linear_forward_sidecar_patches(
     # Then, apply layers for which we have optimized implementations.
     unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
     for patch, patch_weight in patches_and_weights:
+        # Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
+        patch = copy.copy(patch)
+        patch.to(input.device)
+
         if isinstance(patch, FluxControlLoRALayer):
             # Note that we use the original input here, not the sliced input.
             output += linear_lora_forward(orig_input, patch, patch_weight)
@@ -67,7 +73,14 @@

     # Finally, apply any remaining patches.
     if len(unprocessed_patches_and_weights) > 0:
-        aggregated_param_residuals = orig_module._aggregate_patch_parameters(unprocessed_patches_and_weights)
+        # Prepare the original parameters for the patch aggregation.
+        orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
+        # Filter out None values.
+        orig_params = {k: v for k, v in orig_params.items() if v is not None}
+
+        aggregated_param_residuals = orig_module._aggregate_patch_parameters(
+            unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
+        )
         output += torch.nn.functional.linear(
             input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
         )
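The `copy.copy(patch)` / `patch.to(...)` idiom in the hunks above is worth calling out: a shallow copy duplicates only the patch object, so rebinding its tensors to another device leaves the cached CPU-resident patch untouched. Below is an illustrative stand-in (assumption: like the real patches, `to()` rebinds the tensor attributes rather than mutating them in place); `DummyPatch` is not an InvokeAI class.

import copy

import torch


class DummyPatch:
    def __init__(self, weight: torch.Tensor):
        self.weight = weight

    def to(self, device: torch.device) -> None:
        # Rebinds the attribute; the tensor the original patch points at is not modified.
        self.weight = self.weight.to(device)


cpu_patch = DummyPatch(torch.randn(4, 4))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

on_device = copy.copy(cpu_patch)  # new wrapper object, initially sharing the same tensor
on_device.to(device)

print(cpu_patch.weight.device)  # still cpu: the cached patch is untouched
print(on_device.weight.device)  # cuda (or cpu when no GPU is available)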
@@ -1,3 +1,5 @@
+import copy
+
 import torch

 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
@@ -34,15 +36,23 @@ class CustomModuleMixin:
         return len(self._patches_and_weights)

     def _aggregate_patch_parameters(
-        self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
-    ) -> dict[str, torch.Tensor]:
+        self,
+        patches_and_weights: list[tuple[BaseLayerPatch, float]],
+        orig_params: dict[str, torch.Tensor],
+        device: torch.device | None = None,
+    ):
         """Helper function that aggregates the parameters from all patches into a single dict."""
         params: dict[str, torch.Tensor] = {}

         for patch, patch_weight in patches_and_weights:
+            if device is not None:
+                # Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
+                patch = copy.copy(patch)
+                patch.to(device)
+
             # TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original
             # parameters, this might fail or return incorrect results.
-            layer_params = patch.get_parameters(dict(self.named_parameters(recurse=False)), weight=patch_weight)  # type: ignore
+            layer_params = patch.get_parameters(orig_params, weight=patch_weight)

             for param_name, param_weight in layer_params.items():
                 if param_name not in params:
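To make the new helper signature concrete, here is a standalone approximation that runs without InvokeAI. Assumptions are flagged in the comments: `ToyPatch` is a stand-in for `BaseLayerPatch`, and the accumulation step guesses that the truncated tail of the hunk simply sums parameters contributed by more than one patch.

import copy

import torch


class ToyPatch:
    """Stand-in for BaseLayerPatch: contributes a scaled residual for the 'weight' parameter."""

    def __init__(self, residual: torch.Tensor):
        self.residual = residual

    def to(self, device: torch.device) -> None:
        self.residual = self.residual.to(device)

    def get_parameters(self, orig_params: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        # A real patch would derive this from orig_params (e.g. a LoRA product); the toy one just scales.
        return {"weight": self.residual * weight}


def aggregate_patch_parameters(
    patches_and_weights: list[tuple[ToyPatch, float]],
    orig_params: dict[str, torch.Tensor],
    device: torch.device | None = None,
) -> dict[str, torch.Tensor]:
    params: dict[str, torch.Tensor] = {}
    for patch, patch_weight in patches_and_weights:
        if device is not None:
            patch = copy.copy(patch)  # don't move the caller's (possibly CPU-cached) patch
            patch.to(device)
        for name, value in patch.get_parameters(orig_params, weight=patch_weight).items():
            # Assumed accumulation rule: sum contributions across patches.
            params[name] = value if name not in params else params[name] + value
    return params


orig = {"weight": torch.zeros(4, 4)}
patches = [(ToyPatch(torch.ones(4, 4)), 0.5), (ToyPatch(torch.ones(4, 4)), 0.25)]
print(aggregate_patch_parameters(patches, orig)["weight"][0, 0])  # tensor(0.7500)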
@@ -389,40 +389,44 @@ def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest):


 @parameterize_cuda_and_mps
-# def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest):
-#     patches, input = patch_under_test
+def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest):
+    """Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and
+    when the layer is on the CPU and the patches are autocasted to the device.
+    """
+    patches, input = patch_under_test

-#     # Build the base layer under test.
-#     layer = torch.nn.Linear(32, 64)
+    # Build the base layer under test.
+    layer = torch.nn.Linear(32, 64)

-#     # Move the layer and input to the device.
-#     layer_to_device_via_state_dict(layer, device)
-#     input = input.to(torch.device(device))
+    # Move the layer and input to the device.
+    layer_to_device_via_state_dict(layer, device)
+    input = input.to(torch.device(device))

-#     # Wrap the original layer in a custom layer and add the patch to it.
-#     custom_layer = wrap_single_custom_layer(layer)
-#     for patch, weight in patches:
-#         patch.to(torch.device(device))
-#         custom_layer.add_patch(patch, weight)
+    # Wrap the original layer in a custom layer and add the patch to it.
+    custom_layer = wrap_single_custom_layer(layer)
+    for patch, weight in patches:
+        patch.to(torch.device(device))
+        custom_layer.add_patch(patch, weight)

-#     # Run inference with the custom layer on the device.
-#     expected_output = custom_layer(input)
+    # Run inference with the custom layer on the device.
+    expected_output = custom_layer(input)

-#     # Move the custom layer to the CPU.
-#     layer_to_device_via_state_dict(custom_layer, "cpu")
+    # Move the custom layer to the CPU.
+    layer_to_device_via_state_dict(custom_layer, "cpu")

-#     # Move the patches to the CPU.
-#     custom_layer.clear_patches()
-#     for patch, weight in patches:
-#         patch.to(torch.device("cpu"))
-#         custom_layer.add_patch(patch, weight)
+    # Move the patches to the CPU.
+    custom_layer.clear_patches()
+    for patch, weight in patches:
+        patch.to(torch.device("cpu"))
+        custom_layer.add_patch(patch, weight)

-#     # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to
-#     # the device.
-#     autocast_output = custom_layer(input)
-#     assert autocast_output.device.type == device
+    # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to
+    # the device.
+    autocast_output = custom_layer(input)
+    assert autocast_output.device.type == device

-#     assert torch.allclose(expected_output, autocast_output, atol=1e-6)
+    # Assert that the outputs with and without autocasting are the same.
+    assert torch.allclose(expected_output, autocast_output, atol=1e-6)


 @pytest.fixture(