From 61242bf86ac2e8d704699bd581326efcb45f62be Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 10 Oct 2023 10:05:50 -0400
Subject: [PATCH] Fix bug in skip_torch_weight_init() where the original
 behavior of torch.nn.Conv*d modules wasn't being restored correctly.

---
 .../model_load_optimizations.py               |  2 +-
 .../test_model_load_optimization.py           | 26 +++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/invokeai/backend/model_management/model_load_optimizations.py b/invokeai/backend/model_management/model_load_optimizations.py
index fe821c41b6..f835079213 100644
--- a/invokeai/backend/model_management/model_load_optimizations.py
+++ b/invokeai/backend/model_management/model_load_optimizations.py
@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
     saved_functions = [m.reset_parameters for m in torch_modules]
 
     try:
diff --git a/tests/backend/model_management/test_model_load_optimization.py b/tests/backend/model_management/test_model_load_optimization.py
index 385677ea3e..4273872069 100644
--- a/tests/backend/model_management/test_model_load_optimization.py
+++ b/tests/backend/model_management/test_model_load_optimization.py
@@ -42,3 +42,29 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
     assert torch.allclose(layer_before.bias, layer_after.bias)
+
+
+def test_skip_torch_weight_init_restores_base_class_behavior():
+    """Test that `skip_torch_weight_init()` correctly restores the original behavior of torch.nn.Conv*d modules. This
+    test was created to catch a previous bug where `reset_parameters` was being copied from the base `_ConvNd` class to
+    its child classes (like `Conv1d`).
+    """
+    with skip_torch_weight_init():
+        # There is no need to do anything while the context manager is applied; we're just testing that the original
+        # behavior is restored correctly.
+        pass
+
+    # Mock the behavior of another library that monkey-patches `torch.nn.modules.conv._ConvNd.reset_parameters` and
+    # expects it to affect all of the sub-classes (e.g. `torch.nn.Conv1d`, `torch.nn.Conv2d`, etc.).
+    called_monkey_patched_fn = False
+
+    def monkey_patched_fn(*args, **kwargs):
+        nonlocal called_monkey_patched_fn
+        called_monkey_patched_fn = True
+
+    saved_fn = torch.nn.modules.conv._ConvNd.reset_parameters
+    torch.nn.modules.conv._ConvNd.reset_parameters = monkey_patched_fn
+    _ = torch.nn.Conv1d(10, 20, 3)
+    torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn
+
+    assert called_monkey_patched_fn
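
Reviewer note (not part of the patch): a minimal standalone sketch of the
Python attribute-lookup behavior behind the bug, assuming a stock torch
install; the names `saved` and `patched` are illustrative only. Saving
`reset_parameters` from a `Conv*d` subclass and assigning it back pins a copy
of the inherited method in the subclass's `__dict__`, which from then on
shadows any monkey-patch applied to the base class `_ConvNd`:

    import torch

    # Conv1d does not define its own reset_parameters; attribute lookup falls
    # through to the base class _ConvNd via the MRO.
    assert "reset_parameters" not in torch.nn.Conv1d.__dict__

    # The buggy save/restore path: reading the attribute off the subclass
    # resolves to _ConvNd.reset_parameters, but writing it back stores a copy
    # directly in Conv1d.__dict__.
    saved = torch.nn.Conv1d.reset_parameters
    torch.nn.Conv1d.reset_parameters = saved
    assert "reset_parameters" in torch.nn.Conv1d.__dict__

    # A later patch to the base class no longer reaches Conv1d, because the
    # pinned subclass attribute wins the lookup.
    def patched(self):
        print("patched reset_parameters called")

    torch.nn.modules.conv._ConvNd.reset_parameters = patched
    _ = torch.nn.Conv1d(10, 20, 3)  # prints nothing; the patch is shadowed

    # Clean up: restore the base class and drop the shadowing copy. Patching
    # and restoring _ConvNd directly, as this commit does, avoids ever
    # creating that copy.
    torch.nn.modules.conv._ConvNd.reset_parameters = saved
    del torch.nn.Conv1d.reset_parameters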