mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix bug in skip_torch_weight_init() where the original behavior of torch.nn.Conv*d modules wasn't being restored correctly.
This commit is contained in:
parent
58b56e9b1e
commit
61242bf86a
@ -17,7 +17,7 @@ def skip_torch_weight_init():
|
||||
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
|
||||
monkey-patches common torch layers to skip the weight initialization step.
|
||||
"""
|
||||
torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]
|
||||
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
|
||||
saved_functions = [m.reset_parameters for m in torch_modules]
|
||||
|
||||
try:
|
||||
|
@ -42,3 +42,29 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
|
||||
assert reset_params_fn_before is reset_params_fn_after
|
||||
assert torch.allclose(layer_before.weight, layer_after.weight)
|
||||
assert torch.allclose(layer_before.bias, layer_after.bias)
|
||||
|
||||
|
||||
def test_skip_torch_weight_init_restores_base_class_behavior():
|
||||
"""Test that `skip_torch_weight_init()` correctly restores the original behavior of torch.nn.Conv*d modules. This
|
||||
test was created to catch a previous bug where `reset_parameters` was being copied from the base `_ConvNd` class to
|
||||
its child classes (like `Conv1d`).
|
||||
"""
|
||||
with skip_torch_weight_init():
|
||||
# There is no need to do anything while the context manager is applied, we're just testing that the original
|
||||
# behavior is restored correctly.
|
||||
pass
|
||||
|
||||
# Mock the behavior of another library that monkey patches `torch.nn.modules.conv._ConvNd.reset_parameters` and
|
||||
# expects it to affect all of the sub-classes (e.g. `torch.nn.Conv1D`, `torch.nn.Conv2D`, etc.).
|
||||
called_monkey_patched_fn = False
|
||||
|
||||
def monkey_patched_fn(*args, **kwargs):
|
||||
nonlocal called_monkey_patched_fn
|
||||
called_monkey_patched_fn = True
|
||||
|
||||
saved_fn = torch.nn.modules.conv._ConvNd.reset_parameters
|
||||
torch.nn.modules.conv._ConvNd.reset_parameters = monkey_patched_fn
|
||||
_ = torch.nn.Conv1d(10, 20, 3)
|
||||
torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn
|
||||
|
||||
assert called_monkey_patched_fn == True
|
||||
|
Loading…
Reference in New Issue
Block a user