Add LayerNorm to list of modules optimized by skip_torch_weight_init()

commit 5d41157404
parent 8db4ba252a
Author: Ryan Dick
Date: 2024-04-04 17:15:05 -04:00


@@ -17,7 +17,7 @@ def skip_torch_weight_init() -> Generator[None, None, None]:
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding, torch.nn.LayerNorm]
     saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
     try:
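
For context, the monkey-patching that this hunk extends can be sketched roughly as below. This is an illustrative reconstruction based only on the visible diff lines, not the exact upstream implementation: the no-op lambda and the restore logic in the finally block are assumptions.

from contextlib import contextmanager
from typing import Generator

import torch


@contextmanager
def skip_torch_weight_init() -> Generator[None, None, None]:
    """Temporarily replace reset_parameters() on common torch layers with a no-op so
    that constructing a layer skips random weight initialization. Useful when the
    weights will immediately be overwritten by checkpoint weights loaded from disk."""
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding, torch.nn.LayerNorm]
    # Save the original initializers so they can be restored afterwards.
    saved_functions = [m.reset_parameters for m in torch_modules]
    try:
        for m in torch_modules:
            m.reset_parameters = lambda self, *args, **kwargs: None  # skip weight init
        yield
    finally:
        # Restore the original reset_parameters(), even if an exception was raised.
        for m, fn in zip(torch_modules, saved_functions):
            m.reset_parameters = fn

Typical usage would be to construct the model inside the context manager and then load checkpoint weights, e.g. `with skip_torch_weight_init(): model = MyModel()` followed by `model.load_state_dict(state_dict)`.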