Fix bug in skip_torch_weight_init() where the original behavior of torch.nn.Conv*d modules wasn't being restored correctly.

Ryan Dick
2023-10-10 10:05:50 -04:00
parent 58b56e9b1e
commit 61242bf86a
2 changed files with 27 additions and 1 deletion

@@ -17,7 +17,7 @@ def skip_torch_weight_init():
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
monkey-patches common torch layers to skip the weight initialization step.
"""
-torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]
+torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
saved_functions = [m.reset_parameters for m in torch_modules]
try:
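
For reference, a minimal sketch of the monkey-patching pattern this hunk belongs to, assuming the usual try/finally restore loop (everything outside the shown lines is an assumption, not the project's exact code):

```python
import contextlib

import torch


@contextlib.contextmanager
def skip_torch_weight_init():
    """Temporarily make reset_parameters() a no-op on common torch layers.

    Sketch only; the upstream function may differ in detail.
    """
    # Conv1d/Conv2d/Conv3d all inherit reset_parameters() from _ConvNd, so patching
    # the shared base class covers every conv variant with a single entry, and the
    # saved function is restored onto the same class it was read from.
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
    saved_functions = [m.reset_parameters for m in torch_modules]
    try:
        for module in torch_modules:
            module.reset_parameters = lambda self, *args, **kwargs: None  # skip init
        yield
    finally:
        # Put the original reset_parameters implementations back.
        for module, saved in zip(torch_modules, saved_functions):
            module.reset_parameters = saved
```

A layer constructed inside the with-block is left uninitialized and is expected to receive its weights from load_state_dict() afterwards, rather than from PyTorch's default initializers.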