Add LayerNorm to list of modules optimized by skip_torch_weight_init()

This commit is contained in:
Ryan Dick 2024-04-04 17:15:05 -04:00
parent 8db4ba252a
commit 5d41157404

View File

@ -17,7 +17,7 @@ def skip_torch_weight_init() -> Generator[None, None, None]:
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
monkey-patches common torch layers to skip the weight initialization step.
"""
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding, torch.nn.LayerNorm]
saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
try: