mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add a skip_torch_weight_init() context manager to improve model load times (from disk).
This commit is contained in:
parent
1f751f8c21
commit
58b56e9b1e
@ -31,6 +31,7 @@ import torch
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||||
|
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
|
||||||
|
|
||||||
from ..util.devices import choose_torch_device
|
from ..util.devices import choose_torch_device
|
||||||
from .models import BaseModelType, ModelBase, ModelType, SubModelType
|
from .models import BaseModelType, ModelBase, ModelType, SubModelType
|
||||||
@ -223,6 +224,7 @@ class ModelCache(object):
|
|||||||
# Load the model from disk and capture a memory snapshot before/after.
|
# Load the model from disk and capture a memory snapshot before/after.
|
||||||
start_load_time = time.time()
|
start_load_time = time.time()
|
||||||
snapshot_before = MemorySnapshot.capture()
|
snapshot_before = MemorySnapshot.capture()
|
||||||
|
with skip_torch_weight_init():
|
||||||
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
||||||
snapshot_after = MemorySnapshot.capture()
|
snapshot_after = MemorySnapshot.capture()
|
||||||
end_load_time = time.time()
|
end_load_time = time.time()
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _no_op(*args, **kwargs):
    """Stand-in for `reset_parameters()`: accept any arguments and do nothing."""
    pass


@contextmanager
def skip_torch_weight_init():
    """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
    to skip weight initialization.

    By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
    distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
    completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
    monkey-patches common torch layers to skip the weight initialization step.

    While the context is active, `reset_parameters` on each patched class is replaced with `_no_op`. On exit (including
    exit via an exception) every patched class is returned to the state it was found in: a class that defined its own
    `reset_parameters` gets that definition back, and a class that only inherited the method has the temporary override
    removed so that normal attribute lookup resolves to the inherited method again.

    Yields:
        None
    """
    torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]

    # Read each class's *own* namespace rather than using getattr(): that lets us tell "defined on this class" apart
    # from "inherited", which matters when restoring. The sentinel marks classes with no own definition.
    not_defined = object()
    saved_functions = [vars(m).get("reset_parameters", not_defined) for m in torch_modules]

    try:
        for torch_module in torch_modules:
            torch_module.reset_parameters = _no_op

        yield None
    finally:
        for torch_module, saved_function in zip(torch_modules, saved_functions):
            if saved_function is not not_defined:
                torch_module.reset_parameters = saved_function
            elif "reset_parameters" in vars(torch_module):
                # The class had no definition of its own before patching. Delete the override instead of pinning a
                # copy of the inherited function onto the class, which would shadow later changes to the base class.
                del torch_module.reset_parameters
|
@ -0,0 +1,44 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ["torch_module", "layer_args"],
    [
        (torch.nn.Linear, {"in_features": 10, "out_features": 20}),
        (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
    ],
)
def test_skip_torch_weight_init_linear(torch_module, layer_args):
    """Test the interactions between `skip_torch_weight_init()` and various torch modules.

    A layer is constructed before, during, and after the context manager, each time from the same RNG seed. The
    "during" layer must differ from the seeded result (its initialization was skipped), while the "after" layer must
    match the "before" layer (the original initialization was restored).
    """
    # Every construction below re-seeds with this one value so the before/after layers are directly comparable.
    seed = 123

    # Initialize a torch layer *before* applying `skip_torch_weight_init()`.
    reset_params_fn_before = torch_module.reset_parameters
    torch.manual_seed(seed)
    layer_before = torch_module(**layer_args)

    # Initialize a torch layer while `skip_torch_weight_init()` is applied.
    with skip_torch_weight_init():
        reset_params_fn_during = torch_module.reset_parameters
        torch.manual_seed(seed)
        layer_during = torch_module(**layer_args)

    # Initialize a torch layer *after* applying `skip_torch_weight_init()`.
    reset_params_fn_after = torch_module.reset_parameters
    torch.manual_seed(seed)
    layer_after = torch_module(**layer_args)

    # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
    assert reset_params_fn_during is _no_op
    assert not torch.allclose(layer_before.weight, layer_during.weight)
    assert not torch.allclose(layer_before.bias, layer_during.bias)

    # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
    assert reset_params_fn_before is reset_params_fn_after
    assert torch.allclose(layer_before.weight, layer_after.weight)
    assert torch.allclose(layer_before.bias, layer_after.bias)
|
Loading…
Reference in New Issue
Block a user