Skip torch.nn.Embedding.reset_parameters(...) when loading a text encoder model.

This commit is contained in:
Ryan Dick
2023-11-02 11:03:16 -04:00
committed by Kent Keirsey
parent 6e7a3f0546
commit e391f3c9a8
2 changed files with 6 additions and 3 deletions

View File

@@ -17,7 +17,7 @@ def skip_torch_weight_init():
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
monkey-patches common torch layers to skip the weight initialization step.
"""
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
saved_functions = [m.reset_parameters for m in torch_modules]
try: