from contextlib import contextmanager

import torch


def _no_op(*args, **kwargs):
    """Stand-in for `reset_parameters` that accepts any arguments and does nothing."""
    pass


@contextmanager
def skip_torch_weight_init():
    """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
    to skip weight initialization.

    By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
    distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
    completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
    monkey-patches common torch layers to skip the weight initialization step.
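
    Example (a minimal sketch; `MyModel` and the checkpoint path are illustrative placeholders, not part of this
    module):

        with skip_torch_weight_init():
            model = MyModel()  # layer weights are left uninitialized here
        model.load_state_dict(torch.load("checkpoint.pt"))  # the real weights come from the checkpoint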
    """
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
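    # Save the original reset_parameters functions so they can be restored when the context exits.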
    saved_functions = [m.reset_parameters for m in torch_modules]

    try:
        for torch_module in torch_modules:
            torch_module.reset_parameters = _no_op

        yield None
    finally:
        for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
            torch_module.reset_parameters = saved_function