mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Updated tests and documentation
- Resolve conflict with seamless.py
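The first bullet replaces a monolithic loader with a registry. A rough, generic sketch of that loader-registry pattern is below; the class and method names (ModelLoaderRegistry.register, ModelLoaderRegistry.get) and the example loader are illustrative assumptions for this note, not InvokeAI's actual API.

# Illustrative sketch of a loader-registry pattern; names are assumptions, not InvokeAI's API.
from typing import Callable, Dict, Type


class ModelLoaderRegistry:
    """Maps a model-config type to the loader class that knows how to load it."""

    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, config_type: str) -> Callable[[Type], Type]:
        # Decorator that records the loader class for a given config type.
        def decorator(loader_cls: Type) -> Type:
            cls._registry[config_type] = loader_cls
            return loader_cls

        return decorator

    @classmethod
    def get(cls, config_type: str) -> Type:
        # Look up the loader class registered for this config type.
        return cls._registry[config_type]


@ModelLoaderRegistry.register("lora")
class LoRALoader:
    def load(self, path: str) -> None:
        print(f"loading LoRA from {path}")


# Callers resolve a loader through the registry instead of a hard-coded dispatch table.
ModelLoaderRegistry.get("lora")().load("/models/example.safetensors")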
@@ -1,16 +1,16 @@
 from contextlib import contextmanager
+from typing import Any, Generator
 
 import torch
 
 
-def _no_op(*args, **kwargs):
+def _no_op(*args: Any, **kwargs: Any) -> None:
     pass
 
 
 @contextmanager
-def skip_torch_weight_init():
-    """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
-    to skip weight initialization.
+def skip_torch_weight_init() -> Generator[None, None, None]:
+    """Monkey patch several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) to skip weight initialization.
 
     By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
     distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
@@ -18,13 +18,14 @@ def skip_torch_weight_init():
     monkey-patches common torch layers to skip the weight initialization step.
     """
     torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
-    saved_functions = [m.reset_parameters for m in torch_modules]
+    saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
 
     try:
         for torch_module in torch_modules:
+            assert hasattr(torch_module, "reset_parameters")
             torch_module.reset_parameters = _no_op
 
         yield None
     finally:
         for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
+            assert hasattr(torch_module, "reset_parameters")
             torch_module.reset_parameters = saved_function
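A minimal usage sketch of the context manager changed in this diff is below. The import path is an assumption (it depends on where this PR places the module); the timing printout simply illustrates why skipping `reset_parameters` matters when weights will be loaded from a checkpoint anyway.

# Usage sketch for skip_torch_weight_init(); import path is an assumption, adjust as needed.
import time

import torch

from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init

start = time.perf_counter()
_ = torch.nn.Linear(8192, 8192)  # default path: weights are initialized inside __init__
print(f"with weight init:    {time.perf_counter() - start:.3f}s")

start = time.perf_counter()
with skip_torch_weight_init():
    layer = torch.nn.Linear(8192, 8192)  # reset_parameters is a no-op here, so construction is fast
print(f"without weight init: {time.perf_counter() - start:.3f}s")

# The uninitialized weights are expected to be overwritten afterwards,
# e.g. via layer.load_state_dict(...), which is why skipping initialization is safe.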