from contextlib import contextmanager

import torch


def _no_op(*args, **kwargs):
    pass


@contextmanager
def skip_torch_weight_init():
    """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d,
    etc.) to skip weight initialization.

    By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
    distribution) when `__init__` is called. This weight initialization step can take a significant amount of time,
    and is completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context
    manager monkey-patches common torch layers to skip the weight initialization step.
    """
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
    saved_functions = [m.reset_parameters for m in torch_modules]

    try:
        # Patch each module class so that reset_parameters() does nothing during __init__.
        for torch_module in torch_modules:
            torch_module.reset_parameters = _no_op

        yield None
    finally:
        # Restore the original reset_parameters implementations, even if an exception was raised.
        for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
            torch_module.reset_parameters = saved_function
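

# Illustrative usage sketch (not part of the original module): constructing a large layer inside the
# context manager leaves its weights uninitialized, which is only safe when they will be overwritten
# immediately, e.g. by loading a checkpoint. The layer size below is an arbitrary placeholder.
if __name__ == "__main__":
    import time

    start = time.perf_counter()
    with skip_torch_weight_init():
        layer = torch.nn.Linear(8192, 8192)  # reset_parameters() is a no-op here, so no random init runs
    elapsed = time.perf_counter() - start
    print(f"Created layer in {elapsed:.3f}s with weight initialization skipped.")

    # In real use, the uninitialized weights would be replaced right away, e.g.:
    #     layer.load_state_dict(state_dict_loaded_from_disk)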