Fix bug in skip_torch_weight_init() where the original behavior of torch.nn.Conv*d modules wasn't being restored correctly.
commit 61242bf86a
parent 58b56e9b1e
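The root cause is a Python attribute-lookup subtlety: `torch.nn.Conv1d`, `Conv2d`, and `Conv3d` inherit `reset_parameters` from the base class `torch.nn.modules.conv._ConvNd` rather than defining it themselves. Reading `Conv1d.reset_parameters` resolves to the base-class method, but assigning that value back onto `Conv1d` creates a new subclass attribute that shadows `_ConvNd.reset_parameters` from then on. A minimal standalone sketch of the pitfall (plain Python semantics, not InvokeAI code):

    import torch

    # Conv1d does not define reset_parameters itself; lookup falls through to _ConvNd.
    assert "reset_parameters" not in torch.nn.Conv1d.__dict__

    saved = torch.nn.Conv1d.reset_parameters  # resolves to _ConvNd.reset_parameters
    torch.nn.Conv1d.reset_parameters = saved  # naive "restore" writes a NEW attribute onto Conv1d

    # Conv1d now shadows the base class: later patches to _ConvNd.reset_parameters
    # no longer reach Conv1d instances.
    assert "reset_parameters" in torch.nn.Conv1d.__dict__
    del torch.nn.Conv1d.reset_parameters  # remove the shadowing attribute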
@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
     saved_functions = [m.reset_parameters for m in torch_modules]
 
     try:
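With the fix, the context manager saves and patches `reset_parameters` only on the classes that actually define the method, so every restore writes back to the class the function was read from. Only the changed line is visible in the hunk above; the rest of the body below is an assumption reconstructed from the visible context lines, a minimal sketch of the post-fix shape rather than the actual InvokeAI implementation:

    from contextlib import contextmanager

    import torch

    @contextmanager
    def skip_torch_weight_init():
        # Patch the shared base class _ConvNd once, instead of each Conv*d
        # subclass, so restoring never creates shadowing subclass attributes.
        torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
        saved_functions = [m.reset_parameters for m in torch_modules]
        try:
            for torch_module in torch_modules:
                torch_module.reset_parameters = lambda self, *args, **kwargs: None
            yield
        finally:
            # Each saved function goes back onto the class it was read from.
            for torch_module, saved_function in zip(torch_modules, saved_functions):
                torch_module.reset_parameters = saved_function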
@@ -42,3 +42,29 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
     assert torch.allclose(layer_before.bias, layer_after.bias)
+
+
+def test_skip_torch_weight_init_restores_base_class_behavior():
+    """Test that `skip_torch_weight_init()` correctly restores the original behavior of torch.nn.Conv*d modules. This
+    test was created to catch a previous bug where `reset_parameters` was being copied from the base `_ConvNd` class to
+    its child classes (like `Conv1d`).
+    """
+    with skip_torch_weight_init():
+        # There is no need to do anything while the context manager is applied; we're just testing that the original
+        # behavior is restored correctly.
+        pass
+
+    # Mock the behavior of another library that monkey-patches `torch.nn.modules.conv._ConvNd.reset_parameters` and
+    # expects it to affect all of the sub-classes (e.g. `torch.nn.Conv1d`, `torch.nn.Conv2d`, etc.).
+    called_monkey_patched_fn = False
+
+    def monkey_patched_fn(*args, **kwargs):
+        nonlocal called_monkey_patched_fn
+        called_monkey_patched_fn = True
+
+    saved_fn = torch.nn.modules.conv._ConvNd.reset_parameters
+    torch.nn.modules.conv._ConvNd.reset_parameters = monkey_patched_fn
+    _ = torch.nn.Conv1d(10, 20, 3)
+    torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn
+
+    assert called_monkey_patched_fn
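A note on why this test catches the bug: `_ConvNd.__init__` calls `self.reset_parameters()` during construction, so building `torch.nn.Conv1d(10, 20, 3)` exercises whatever `reset_parameters` is visible on the class at that moment. If `skip_torch_weight_init()` had left a stale copy of the method on the `Conv1d` subclass, that copy would shadow the base-class monkey patch, the mock would never run, and the final assertion would fail.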