Skip torch.nn.Embedding.reset_parameters(...) when loading a text encoder model.

Author: Ryan Dick, 2023-11-02 11:03:16 -04:00 (committed by Kent Keirsey)
Parent: 6e7a3f0546
Commit: e391f3c9a8
2 changed files with 6 additions and 3 deletions


@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
     saved_functions = [m.reset_parameters for m in torch_modules]
     try:
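For readers unfamiliar with this helper, the fragments above fit the following shape. This is a minimal sketch assembled from the visible lines, assuming a module-level _no_op stand-in (it is imported by the test file below); the exact body of the InvokeAI implementation may differ.

from contextlib import contextmanager

import torch


def _no_op(*args, **kwargs):
    pass


@contextmanager
def skip_torch_weight_init():
    """Skip the random weight initialization of common torch layers.

    Initializing weights is wasted work when they are immediately overwritten by a
    checkpoint loaded from disk, so reset_parameters() is temporarily replaced with
    a no-op on the patched classes.
    """
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
    saved_functions = [m.reset_parameters for m in torch_modules]
    try:
        # Patch the classes so any layer constructed inside this context skips init.
        for torch_module in torch_modules:
            torch_module.reset_parameters = _no_op
        yield
    finally:
        # Restore the original initializers even if the load raises.
        for torch_module, saved_function in zip(torch_modules, saved_functions):
            torch_module.reset_parameters = saved_function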


@@ -11,6 +11,7 @@ from invokeai.backend.model_management.model_load_optimizations import _no_op, s
         (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
     ],
 )
 def test_skip_torch_weight_init_linear(torch_module, layer_args):
@@ -36,11 +37,13 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
     assert reset_params_fn_during == _no_op
     assert not torch.allclose(layer_before.weight, layer_during.weight)
-    assert not torch.allclose(layer_before.bias, layer_during.bias)
+    if hasattr(layer_before, "bias"):
+        assert not torch.allclose(layer_before.bias, layer_during.bias)

     # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
-    assert torch.allclose(layer_before.bias, layer_after.bias)
+    if hasattr(layer_before, "bias"):
+        assert torch.allclose(layer_before.bias, layer_after.bias)
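As a usage illustration (a hypothetical snippet, not part of this commit): text encoders such as CLIP contain torch.nn.Embedding layers for their token and position embeddings, which is why the class is added to the patched list. Constructing the model inside the context manager skips their random initialization before the checkpoint weights are copied in; the Hugging Face model name below is only an example.

from transformers import CLIPTextModel

from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init

# Layers built inside the context manager (Linear, Conv*, and now Embedding) skip
# reset_parameters(); their values come entirely from the checkpoint loaded from disk.
with skip_torch_weight_init():
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")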