Improve model load times from disk: skip unnecessary weight init (#4840)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Have you updated all relevant documentation?

- [x] Yes
- [ ] No

## Description

This PR optimizes the time to load models from disk. In my local testing, SDXL text_encoder_2 models saw the greatest improvement:

- Before change, load time (disk to cpu): 14 secs
- After change, load time (disk to cpu): 4 secs

See the in-code documentation for an explanation of how this speedup is achieved.

## Related Tickets & Documents

This change was previously proposed on the HF transformers repo, but did not get any traction: https://github.com/huggingface/transformers/issues/18505#issue-1330728188

## QA Instructions, Screenshots, Recordings

I don't expect any adverse effects, but the new context manager is applied while loading **all** models, so it would make sense to exercise everything.

## Added/updated tests?

- [x] Yes
- [ ] No
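A rough, illustrative sketch of why skipping weight initialization helps (the layer size and timing harness below are made up, and only standard PyTorch is assumed): `torch.nn` layers fill their parameter tensors with random values in `__init__`, and that work is thrown away as soon as checkpoint weights are loaded over them.

```python
import time

import torch


def _no_op(*args, **kwargs):
    pass


def build_large_layer() -> torch.nn.Linear:
    # A single large Linear layer stands in for a text encoder's weight matrices.
    return torch.nn.Linear(8192, 8192)


# Normal construction: Linear.__init__ calls reset_parameters(), which fills the
# weight and bias tensors from a Kaiming-uniform distribution.
start = time.time()
_ = build_large_layer()
print(f"with weight init:    {time.time() - start:.3f}s")

# Patched construction: reset_parameters() becomes a no-op, so the parameter
# tensors are allocated but never filled. This is only safe when a checkpoint's
# state_dict will overwrite every parameter immediately afterwards.
saved_fn = torch.nn.Linear.reset_parameters
torch.nn.Linear.reset_parameters = _no_op
try:
    start = time.time()
    _ = build_large_layer()
    print(f"without weight init: {time.time() - start:.3f}s")
finally:
    torch.nn.Linear.reset_parameters = saved_fn
```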
Commit: 462c1d4c9b
Changes to the `ModelCache` module:

```diff
@@ -31,6 +31,7 @@ import torch
 
 import invokeai.backend.util.logging as logger
 from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
 
 from ..util.devices import choose_torch_device
 from .models import BaseModelType, ModelBase, ModelType, SubModelType
@@ -223,7 +224,8 @@ class ModelCache(object):
             # Load the model from disk and capture a memory snapshot before/after.
             start_load_time = time.time()
             snapshot_before = MemorySnapshot.capture()
-            model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
+            with skip_torch_weight_init():
+                model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
             snapshot_after = MemorySnapshot.capture()
             end_load_time = time.time()
```
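Because `skip_torch_weight_init()` patches the layer classes themselves, it also covers layers constructed inside library code called from `get_model()`, which is why wrapping that single call is sufficient. A minimal sketch of that behavior, with a made-up helper standing in for the real model loader (only the import of the new module is taken from the diff above):

```python
import torch

from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init


def build_model_somewhere_deep_in_a_library() -> torch.nn.Module:
    # Hypothetical stand-in for a third-party loader that constructs many
    # Linear/Conv layers internally before their weights are loaded from disk.
    return torch.nn.ModuleList([torch.nn.Linear(1024, 1024) for _ in range(4)])


# Every Linear/Conv layer constructed anywhere in the process while the context
# manager is active skips reset_parameters(), including layers built inside
# library code.
with skip_torch_weight_init():
    model = build_model_somewhere_deep_in_a_library()
```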
New file `invokeai/backend/model_management/model_load_optimizations.py`:

```diff
@@ -0,0 +1,30 @@
+from contextlib import contextmanager
+
+import torch
+
+
+def _no_op(*args, **kwargs):
+    pass
+
+
+@contextmanager
+def skip_torch_weight_init():
+    """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
+    to skip weight initialization.
+
+    By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
+    distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
+    completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
+    monkey-patches common torch layers to skip the weight initialization step.
+    """
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
+    saved_functions = [m.reset_parameters for m in torch_modules]
+
+    try:
+        for torch_module in torch_modules:
+            torch_module.reset_parameters = _no_op
+
+        yield None
+    finally:
+        for torch_module, saved_function in zip(torch_modules, saved_functions):
+            torch_module.reset_parameters = saved_function
```
New test file for `skip_torch_weight_init()`:

```diff
@@ -0,0 +1,70 @@
+import pytest
+import torch
+
+from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
+
+
+@pytest.mark.parametrize(
+    ["torch_module", "layer_args"],
+    [
+        (torch.nn.Linear, {"in_features": 10, "out_features": 20}),
+        (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+    ],
+)
+def test_skip_torch_weight_init_linear(torch_module, layer_args):
+    """Test the interactions between `skip_torch_weight_init()` and various torch modules."""
+    seed = 123
+
+    # Initialize a torch layer *before* applying `skip_torch_weight_init()`.
+    reset_params_fn_before = torch_module.reset_parameters
+    torch.manual_seed(seed)
+    layer_before = torch_module(**layer_args)
+
+    # Initialize a torch layer while `skip_torch_weight_init()` is applied.
+    with skip_torch_weight_init():
+        reset_params_fn_during = torch_module.reset_parameters
+        torch.manual_seed(123)
+        layer_during = torch_module(**layer_args)
+
+    # Initialize a torch layer *after* applying `skip_torch_weight_init()`.
+    reset_params_fn_after = torch_module.reset_parameters
+    torch.manual_seed(123)
+    layer_after = torch_module(**layer_args)
+
+    # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
+    assert reset_params_fn_during == _no_op
+    assert not torch.allclose(layer_before.weight, layer_during.weight)
+    assert not torch.allclose(layer_before.bias, layer_during.bias)
+
+    # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
+    assert reset_params_fn_before is reset_params_fn_after
+    assert torch.allclose(layer_before.weight, layer_after.weight)
+    assert torch.allclose(layer_before.bias, layer_after.bias)
+
+
+def test_skip_torch_weight_init_restores_base_class_behavior():
+    """Test that `skip_torch_weight_init()` correctly restores the original behavior of torch.nn.Conv*d modules. This
+    test was created to catch a previous bug where `reset_parameters` was being copied from the base `_ConvNd` class to
+    its child classes (like `Conv1d`).
+    """
+    with skip_torch_weight_init():
+        # There is no need to do anything while the context manager is applied, we're just testing that the original
+        # behavior is restored correctly.
+        pass
+
+    # Mock the behavior of another library that monkey patches `torch.nn.modules.conv._ConvNd.reset_parameters` and
+    # expects it to affect all of the sub-classes (e.g. `torch.nn.Conv1D`, `torch.nn.Conv2D`, etc.).
+    called_monkey_patched_fn = False
+
+    def monkey_patched_fn(*args, **kwargs):
+        nonlocal called_monkey_patched_fn
+        called_monkey_patched_fn = True
+
+    saved_fn = torch.nn.modules.conv._ConvNd.reset_parameters
+    torch.nn.modules.conv._ConvNd.reset_parameters = monkey_patched_fn
+    _ = torch.nn.Conv1d(10, 20, 3)
+    torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn
+
+    assert called_monkey_patched_fn
```