Fix bug in skip_torch_weight_init() where the original behavior of torch.nn.Conv*d modules wasn't being restored correctly.

Ryan Dick 2023-10-10 10:05:50 -04:00
parent 58b56e9b1e
commit 61242bf86a
2 changed files with 27 additions and 1 deletion


@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
     saved_functions = [m.reset_parameters for m in torch_modules]
     try:
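For readers without the full file in front of them, here is a minimal sketch of how a context manager like this can be structured. Only the `torch_modules` and `saved_functions` lines and the `try:` are visible in the hunk above; the patch/restore logic below is an assumption for illustration, not the committed implementation.

from contextlib import contextmanager

import torch


@contextmanager
def skip_torch_weight_init_sketch():
    # Patch the shared base class (_ConvNd) instead of each Conv1d/Conv2d/Conv3d
    # subclass. Saving `Conv1d.reset_parameters` and assigning it back later would
    # pin the inherited method directly onto the subclass, shadowing the base class.
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
    saved_functions = [m.reset_parameters for m in torch_modules]
    try:
        for m in torch_modules:
            # Replace the initializer with a no-op so layer construction skips weight init.
            m.reset_parameters = lambda self, *args, **kwargs: None
        yield
    finally:
        # Restore the original functions on exactly the classes they were read from.
        for module, saved in zip(torch_modules, saved_functions):
            module.reset_parameters = saved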


@@ -42,3 +42,29 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
     assert torch.allclose(layer_before.bias, layer_after.bias)
+
+
+def test_skip_torch_weight_init_restores_base_class_behavior():
+    """Test that `skip_torch_weight_init()` correctly restores the original behavior of torch.nn.Conv*d modules. This
+    test was created to catch a previous bug where `reset_parameters` was being copied from the base `_ConvNd` class to
+    its child classes (like `Conv1d`).
+    """
+    with skip_torch_weight_init():
+        # There is no need to do anything while the context manager is applied; we're just testing that the original
+        # behavior is restored correctly.
+        pass
+
+    # Mock the behavior of another library that monkey-patches `torch.nn.modules.conv._ConvNd.reset_parameters` and
+    # expects the patch to affect all of the sub-classes (e.g. `torch.nn.Conv1d`, `torch.nn.Conv2d`, etc.).
+    called_monkey_patched_fn = False
+
+    def monkey_patched_fn(*args, **kwargs):
+        nonlocal called_monkey_patched_fn
+        called_monkey_patched_fn = True
+
+    saved_fn = torch.nn.modules.conv._ConvNd.reset_parameters
+    torch.nn.modules.conv._ConvNd.reset_parameters = monkey_patched_fn
+    _ = torch.nn.Conv1d(10, 20, 3)
+    torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn
+
+    assert called_monkey_patched_fn == True
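As background on why the pre-fix behavior broke this test (an illustration, not part of the commit): reading `Conv1d.reset_parameters` returns the function inherited from `_ConvNd`, but assigning that function back onto `Conv1d` creates an attribute directly on the subclass, which then shadows any later monkey-patch of `_ConvNd.reset_parameters`.

import torch

# Attribute lookup on the subclass falls through to the base class, so the
# "saved" function is really _ConvNd.reset_parameters.
saved = torch.nn.Conv1d.reset_parameters
assert saved is torch.nn.modules.conv._ConvNd.reset_parameters
assert "reset_parameters" not in torch.nn.Conv1d.__dict__

# "Restoring" it onto the subclass pins a copy there, so later patches of
# _ConvNd.reset_parameters no longer reach Conv1d; this is the bug the commit fixes.
torch.nn.Conv1d.reset_parameters = saved
assert "reset_parameters" in torch.nn.Conv1d.__dict__

# Clean up the shadowing attribute so Conv1d inherits from _ConvNd again.
del torch.nn.Conv1d.reset_parameters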