Add a skip_torch_weight_init() context manager to improve model load times (from disk).

Ryan Dick
2023-10-09 14:12:56 -04:00
parent 1f751f8c21
commit 58b56e9b1e
3 changed files with 77 additions and 1 deletion

invokeai/backend/model_management/model_cache.py

@@ -31,6 +31,7 @@ import torch
 import invokeai.backend.util.logging as logger
 from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
 from ..util.devices import choose_torch_device
 from .models import BaseModelType, ModelBase, ModelType, SubModelType
@@ -223,7 +224,8 @@ class ModelCache(object):
             # Load the model from disk and capture a memory snapshot before/after.
             start_load_time = time.time()
             snapshot_before = MemorySnapshot.capture()
-            model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
+            with skip_torch_weight_init():
+                model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
             snapshot_after = MemorySnapshot.capture()
             end_load_time = time.time()
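
For intuition on the speedup, here is a small standalone timing sketch (not part of this commit; the layer size and count are arbitrary) comparing module construction with and without the context manager:

import time

import torch

from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init


def build_layers():
    # Construction cost is dominated by reset_parameters() for large layers.
    return [torch.nn.Linear(4096, 4096) for _ in range(8)]


start = time.time()
build_layers()
print(f"with weight init:    {time.time() - start:.3f}s")

start = time.time()
with skip_torch_weight_init():
    build_layers()  # reset_parameters() is a no-op inside this block
print(f"without weight init: {time.time() - start:.3f}s")

The parameter tensors are still allocated either way, so memory use is unchanged; only the random fill is skipped.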

invokeai/backend/model_management/model_load_optimizations.py

@@ -0,0 +1,30 @@
from contextlib import contextmanager

import torch


def _no_op(*args, **kwargs):
    pass


@contextmanager
def skip_torch_weight_init():
    """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
    to skip weight initialization.

    By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
    distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
    completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
    monkey-patches common torch layers to skip the weight initialization step.
    """
    torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]
    saved_functions = [m.reset_parameters for m in torch_modules]

    try:
        for torch_module in torch_modules:
            torch_module.reset_parameters = _no_op

        yield None
    finally:
        for torch_module, saved_function in zip(torch_modules, saved_functions):
            torch_module.reset_parameters = saved_function