diff --git a/invokeai/backend/model_management/model_load_optimizations.py b/invokeai/backend/model_management/model_load_optimizations.py
index f835079213..8dc8a8793e 100644
--- a/invokeai/backend/model_management/model_load_optimizations.py
+++ b/invokeai/backend/model_management/model_load_optimizations.py
@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
     saved_functions = [m.reset_parameters for m in torch_modules]
 
     try:
diff --git a/tests/backend/model_management/test_model_load_optimization.py b/tests/backend/model_management/test_model_load_optimization.py
index 43f007e972..a4fe1dd597 100644
--- a/tests/backend/model_management/test_model_load_optimization.py
+++ b/tests/backend/model_management/test_model_load_optimization.py
@@ -11,6 +11,7 @@ from invokeai.backend.model_management.model_load_optimizations import _no_op, s
         (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
     ],
 )
 def test_skip_torch_weight_init_linear(torch_module, layer_args):
@@ -36,12 +37,14 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
     assert reset_params_fn_during == _no_op
     assert not torch.allclose(layer_before.weight, layer_during.weight)
-    assert not torch.allclose(layer_before.bias, layer_during.bias)
+    if hasattr(layer_before, "bias"):
+        assert not torch.allclose(layer_before.bias, layer_during.bias)
 
     # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
-    assert torch.allclose(layer_before.bias, layer_after.bias)
+    if hasattr(layer_before, "bias"):
+        assert torch.allclose(layer_before.bias, layer_after.bias)
 
 
 def test_skip_torch_weight_init_restores_base_class_behavior():
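
For context (not part of the diff): the docstring above explains that `skip_torch_weight_init()` monkey-patches `reset_parameters` on common torch layers so construction skips weight initialization. The sketch below illustrates how the context manager is expected to behave for `torch.nn.Embedding` after this change, assuming `invokeai` is importable; the checkpoint path shown is hypothetical.

```python
import torch

from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init

with skip_torch_weight_init():
    # With this change, Embedding.reset_parameters() is a no-op inside the context
    # manager, so the weights are left as allocated by torch.empty() instead of being
    # drawn from N(0, 1). This saves time when the weights will be overwritten anyway.
    embedding = torch.nn.Embedding(num_embeddings=50000, embedding_dim=768)

# Outside the context manager, normal initialization behavior is restored.
embedding_initialized = torch.nn.Embedding(num_embeddings=50000, embedding_dim=768)

# Typical follow-up: overwrite the uninitialized weights from a checkpoint.
# state_dict = torch.load("embedding_checkpoint.pt")  # hypothetical path
# embedding.load_state_dict(state_dict)
```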