From 58b56e9b1ec38bf1813e3ed79b34bccf334aef71 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Mon, 9 Oct 2023 14:12:56 -0400
Subject: [PATCH 1/3] Add a skip_torch_weight_init() context manager to improve model load times (from disk).

---
 .../backend/model_management/model_cache.py  |  4 +-
 .../model_load_optimizations.py              | 30 +++++++++++++
 .../test_model_load_optimization.py          | 44 +++++++++++++++++++
 3 files changed, 77 insertions(+), 1 deletion(-)
 create mode 100644 invokeai/backend/model_management/model_load_optimizations.py
 create mode 100644 tests/backend/model_management/test_model_load_optimization.py

diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 8cb6b55caf..9cf3852449 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -31,6 +31,7 @@ import torch
 
 import invokeai.backend.util.logging as logger
 from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
 
 from ..util.devices import choose_torch_device
 from .models import BaseModelType, ModelBase, ModelType, SubModelType
@@ -223,7 +224,8 @@ class ModelCache(object):
             # Load the model from disk and capture a memory snapshot before/after.
             start_load_time = time.time()
             snapshot_before = MemorySnapshot.capture()
-            model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
+            with skip_torch_weight_init():
+                model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
             snapshot_after = MemorySnapshot.capture()
             end_load_time = time.time()
diff --git a/invokeai/backend/model_management/model_load_optimizations.py b/invokeai/backend/model_management/model_load_optimizations.py
new file mode 100644
index 0000000000..fe821c41b6
--- /dev/null
+++ b/invokeai/backend/model_management/model_load_optimizations.py
@@ -0,0 +1,30 @@
+from contextlib import contextmanager
+
+import torch
+
+
+def _no_op(*args, **kwargs):
+    pass
+
+
+@contextmanager
+def skip_torch_weight_init():
+    """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
+    to skip weight initialization.
+
+    By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
+    distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
+    completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
+    monkey-patches common torch layers to skip the weight initialization step.
+    """
+    torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]
+    saved_functions = [m.reset_parameters for m in torch_modules]
+
+    try:
+        for torch_module in torch_modules:
+            torch_module.reset_parameters = _no_op
+
+        yield None
+    finally:
+        for torch_module, saved_function in zip(torch_modules, saved_functions):
+            torch_module.reset_parameters = saved_function
diff --git a/tests/backend/model_management/test_model_load_optimization.py b/tests/backend/model_management/test_model_load_optimization.py
new file mode 100644
index 0000000000..385677ea3e
--- /dev/null
+++ b/tests/backend/model_management/test_model_load_optimization.py
@@ -0,0 +1,44 @@
+import pytest
+import torch
+
+from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
+
+
+@pytest.mark.parametrize(
+    ["torch_module", "layer_args"],
+    [
+        (torch.nn.Linear, {"in_features": 10, "out_features": 20}),
+        (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+    ],
+)
+def test_skip_torch_weight_init_linear(torch_module, layer_args):
+    """Test the interactions between `skip_torch_weight_init()` and various torch modules."""
+    seed = 123
+
+    # Initialize a torch layer *before* applying `skip_torch_weight_init()`.
+    reset_params_fn_before = torch_module.reset_parameters
+    torch.manual_seed(seed)
+    layer_before = torch_module(**layer_args)
+
+    # Initialize a torch layer while `skip_torch_weight_init()` is applied.
+    with skip_torch_weight_init():
+        reset_params_fn_during = torch_module.reset_parameters
+        torch.manual_seed(123)
+        layer_during = torch_module(**layer_args)
+
+    # Initialize a torch layer *after* applying `skip_torch_weight_init()`.
+    reset_params_fn_after = torch_module.reset_parameters
+    torch.manual_seed(123)
+    layer_after = torch_module(**layer_args)
+
+    # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
+    assert reset_params_fn_during == _no_op
+    assert not torch.allclose(layer_before.weight, layer_during.weight)
+    assert not torch.allclose(layer_before.bias, layer_during.bias)
+
+    # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
+    assert reset_params_fn_before is reset_params_fn_after
+    assert torch.allclose(layer_before.weight, layer_after.weight)
+    assert torch.allclose(layer_before.bias, layer_after.bias)

From 61242bf86ac2e8d704699bd581326efcb45f62be Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 10 Oct 2023 10:05:50 -0400
Subject: [PATCH 2/3] Fix bug in skip_torch_weight_init() where the original behavior of torch.nn.Conv*d modules wasn't being restored correctly.

---
 .../model_load_optimizations.py              |  2 +-
 .../test_model_load_optimization.py          | 26 +++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/invokeai/backend/model_management/model_load_optimizations.py b/invokeai/backend/model_management/model_load_optimizations.py
index fe821c41b6..f835079213 100644
--- a/invokeai/backend/model_management/model_load_optimizations.py
+++ b/invokeai/backend/model_management/model_load_optimizations.py
@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
""" - torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d] + torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd] saved_functions = [m.reset_parameters for m in torch_modules] try: diff --git a/tests/backend/model_management/test_model_load_optimization.py b/tests/backend/model_management/test_model_load_optimization.py index 385677ea3e..4273872069 100644 --- a/tests/backend/model_management/test_model_load_optimization.py +++ b/tests/backend/model_management/test_model_load_optimization.py @@ -42,3 +42,29 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args): assert reset_params_fn_before is reset_params_fn_after assert torch.allclose(layer_before.weight, layer_after.weight) assert torch.allclose(layer_before.bias, layer_after.bias) + + +def test_skip_torch_weight_init_restores_base_class_behavior(): + """Test that `skip_torch_weight_init()` correctly restores the original behavior of torch.nn.Conv*d modules. This + test was created to catch a previous bug where `reset_parameters` was being copied from the base `_ConvNd` class to + its child classes (like `Conv1d`). + """ + with skip_torch_weight_init(): + # There is no need to do anything while the context manager is applied, we're just testing that the original + # behavior is restored correctly. + pass + + # Mock the behavior of another library that monkey patches `torch.nn.modules.conv._ConvNd.reset_parameters` and + # expects it to affect all of the sub-classes (e.g. `torch.nn.Conv1D`, `torch.nn.Conv2D`, etc.). + called_monkey_patched_fn = False + + def monkey_patched_fn(*args, **kwargs): + nonlocal called_monkey_patched_fn + called_monkey_patched_fn = True + + saved_fn = torch.nn.modules.conv._ConvNd.reset_parameters + torch.nn.modules.conv._ConvNd.reset_parameters = monkey_patched_fn + _ = torch.nn.Conv1d(10, 20, 3) + torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn + + assert called_monkey_patched_fn == True From f3c138a208debc1a96ac5706c0118f41f7722047 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 10 Oct 2023 10:06:53 -0400 Subject: [PATCH 3/3] (minor) Fix Flake8. --- tests/backend/model_management/test_model_load_optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/backend/model_management/test_model_load_optimization.py b/tests/backend/model_management/test_model_load_optimization.py index 4273872069..43f007e972 100644 --- a/tests/backend/model_management/test_model_load_optimization.py +++ b/tests/backend/model_management/test_model_load_optimization.py @@ -67,4 +67,4 @@ def test_skip_torch_weight_init_restores_base_class_behavior(): _ = torch.nn.Conv1d(10, 20, 3) torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn - assert called_monkey_patched_fn == True + assert called_monkey_patched_fn