InvokeAI/invokeai/backend/model_manager/load/optimizations.py
psychedelicious 4191ca1a46 final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Updated tests and documentation

resolve conflict with seamless.py
2024-02-29 13:28:20 -05:00

32 lines
1.4 KiB
Python

from contextlib import contextmanager
from typing import Any, Generator
import torch
def _no_op(*args: Any, **kwargs: Any) -> None:
pass
@contextmanager
def skip_torch_weight_init() -> Generator[None, None, None]:
    """Temporarily turn `reset_parameters` into a no-op on common torch layer classes.

    `torch.nn.Linear`, the convolution layers built on `torch.nn.modules.conv._ConvNd`
    (`torch.nn.Conv1d`, etc.) and `torch.nn.Embedding` fill their weights from a random
    distribution inside `__init__`. That step can be slow, and it is pointless when the
    layer's weights are about to be loaded from a checkpoint on disk.

    While the context is active, `reset_parameters` on each of those classes is replaced
    with `_no_op`. The saved originals are reinstalled on exit, including when the body
    of the `with` statement raises.
    """
    patched_classes = (torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding)

    # Snapshot the original methods before anything is patched, so `finally` can put them back.
    # `hasattr(...) and ...` evaluates to False instead of raising for a class without the method;
    # the assert in the patch loop then stops before such a class is touched.
    originals = {cls: hasattr(cls, "reset_parameters") and cls.reset_parameters for cls in patched_classes}

    try:
        for cls in patched_classes:
            assert hasattr(cls, "reset_parameters")
            cls.reset_parameters = _no_op
        yield None
    finally:
        # Dicts preserve insertion order, so restoration visits the classes in the same order they were patched.
        for cls, original in originals.items():
            assert hasattr(cls, "reset_parameters")
            cls.reset_parameters = original