Skip torch.nn.Embedding.reset_parameters(...) when loading a text encoder model.

Ryan Dick 2023-11-02 11:03:16 -04:00 committed by Kent Keirsey
parent 6e7a3f0546
commit e391f3c9a8
2 changed files with 6 additions and 3 deletions
invokeai/backend/model_management
tests/backend/model_management

@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
     saved_functions = [m.reset_parameters for m in torch_modules]
     try:
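For readers skimming the diff, here is a minimal sketch of what `skip_torch_weight_init()` looks like after this change. The `torch_modules` and `saved_functions` lines are taken from the hunk above; the rest of the body, including the exact signature of the `_no_op` stub, is an assumption about the surrounding code rather than a verbatim copy of it.

```python
from contextlib import contextmanager

import torch


def _no_op(*args, **kwargs):
    pass


@contextmanager
def skip_torch_weight_init():
    """Monkey-patch common torch layers so reset_parameters() does nothing.

    Skipping initialization is safe when the freshly constructed weights will
    be overwritten by checkpoint weights loaded from disk anyway.
    """
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
    saved_functions = [m.reset_parameters for m in torch_modules]
    try:
        for torch_module in torch_modules:
            torch_module.reset_parameters = _no_op
        yield None
    finally:
        # Restore the original initializers even if model construction raises.
        for torch_module, saved_function in zip(torch_modules, saved_functions):
            torch_module.reset_parameters = saved_function
```

A caller would wrap model construction in the context manager, e.g. `with skip_torch_weight_init(): model = build_text_encoder(...)` (the builder name is illustrative). With this commit, any `torch.nn.Embedding` created inside the block, such as a text encoder's token embedding table, skips its normal-distribution initialization as well.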

@@ -11,6 +11,7 @@ from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
         (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
     ],
 )
 def test_skip_torch_weight_init_linear(torch_module, layer_args):
@@ -36,12 +37,14 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
     assert reset_params_fn_during == _no_op
     assert not torch.allclose(layer_before.weight, layer_during.weight)
-    assert not torch.allclose(layer_before.bias, layer_during.bias)
+    if hasattr(layer_before, "bias"):
+        assert not torch.allclose(layer_before.bias, layer_during.bias)
 
     # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
-    assert torch.allclose(layer_before.bias, layer_after.bias)
+    if hasattr(layer_before, "bias"):
+        assert torch.allclose(layer_before.bias, layer_after.bias)
 
 
 def test_skip_torch_weight_init_restores_base_class_behavior():
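The new `hasattr` guards are needed because, unlike `torch.nn.Linear` and the conv layers (which always define a `bias` attribute, even when it is set to `None`), `torch.nn.Embedding` defines no `bias` attribute at all. A quick check illustrates the difference:

```python
import torch

# Linear and conv layers always expose `bias`; Embedding never does.
print(hasattr(torch.nn.Linear(10, 10), "bias"))     # True
print(hasattr(torch.nn.Conv2d(3, 8, 3), "bias"))    # True
print(hasattr(torch.nn.Embedding(10, 10), "bias"))  # False
```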