Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Skip torch.nn.Embedding.reset_parameters(...) when loading a text encoder model.
parent 6e7a3f0546
commit e391f3c9a8
invokeai/backend/model_management

@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
     saved_functions = [m.reset_parameters for m in torch_modules]

     try:
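For readers without the full file, the context manager this hunk touches can be reconstructed roughly as follows. This is a hedged sketch based only on the fragments visible in the diff (the docstring, torch_modules, saved_functions, and the try: block); the patch/restore loop bodies are assumptions, not the verbatim upstream code.

from contextlib import contextmanager

import torch


def _no_op(*args, **kwargs):
    # Replacement for reset_parameters: do nothing, leaving the layer's weights
    # uninitialized so checkpoint weights loaded from disk can overwrite them.
    pass


@contextmanager
def skip_torch_weight_init():
    """Monkey-patch common torch layers to skip the weight initialization step."""
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
    saved_functions = [m.reset_parameters for m in torch_modules]
    try:
        # Patch each layer class so constructing it skips random initialization.
        for torch_module in torch_modules:
            torch_module.reset_parameters = _no_op
        yield
    finally:
        # Restore the original reset_parameters functions, even if an error occurred.
        for torch_module, saved_function in zip(torch_modules, saved_functions):
            torch_module.reset_parameters = saved_function

A typical call site would wrap model construction, e.g. `with skip_torch_weight_init(): model = build_model(...)`, followed by a `model.load_state_dict(...)` call that supplies the real weights; `build_model` here is purely illustrative.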
tests/backend/model_management

@@ -11,6 +11,7 @@ from invokeai.backend.model_management.model_load_optimizations import _no_op, s
         (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
     ],
 )
 def test_skip_torch_weight_init_linear(torch_module, layer_args):
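To make the assertions in the next hunk easier to follow, here is a hedged reconstruction of the test that this parametrization feeds. The variable names (layer_before, layer_during, layer_after, reset_params_fn_*) come from the visible assertions; the seeding and construction details are assumptions rather than the exact upstream test.

import pytest
import torch

from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init


@pytest.mark.parametrize(
    ["torch_module", "layer_args"],
    [
        # Entries above line 11 of the real file are not visible in this diff.
        (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
    ],
)
def test_skip_torch_weight_init_linear(torch_module, layer_args):
    seed = 123

    # Construct the layer normally: reset_parameters runs.
    reset_params_fn_before = torch_module.reset_parameters
    torch.manual_seed(seed)
    layer_before = torch_module(**layer_args)

    # Construct the layer inside the context manager: reset_parameters is patched to _no_op.
    with skip_torch_weight_init():
        reset_params_fn_during = torch_module.reset_parameters
        torch.manual_seed(seed)
        layer_during = torch_module(**layer_args)

    # Construct the layer again after the context manager exits: normal behavior restored.
    reset_params_fn_after = torch_module.reset_parameters
    torch.manual_seed(seed)
    layer_after = torch_module(**layer_args)

    # The assertions modified in the hunk below compare these layers and functions.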
@@ -36,12 +37,14 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
     assert reset_params_fn_during == _no_op
     assert not torch.allclose(layer_before.weight, layer_during.weight)
-    assert not torch.allclose(layer_before.bias, layer_during.bias)
+    if hasattr(layer_before, "bias"):
+        assert not torch.allclose(layer_before.bias, layer_during.bias)

     # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
-    assert torch.allclose(layer_before.bias, layer_after.bias)
+    if hasattr(layer_before, "bias"):
+        assert torch.allclose(layer_before.bias, layer_after.bias)


 def test_skip_torch_weight_init_restores_base_class_behavior():
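The new hasattr() guards are needed because torch.nn.Embedding, unlike the Linear and Conv layers already covered by the test, has no bias parameter, so the previously unconditional bias comparison would raise an AttributeError for the new parametrization. A minimal check:

import torch

print(hasattr(torch.nn.Linear(10, 20), "bias"))      # True
print(hasattr(torch.nn.Embedding(10, 10), "bias"))   # False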