diff --git a/invokeai/backend/requantize.py b/invokeai/backend/requantize.py index 5f506f487d..aae85bed7c 100644 --- a/invokeai/backend/requantize.py +++ b/invokeai/backend/requantize.py @@ -3,19 +3,21 @@ from typing import Any, Dict import torch from optimum.quanto.quantize import _quantize_submodule -# def custom_freeze(model: torch.nn.Module): -# for name, m in model.named_modules(): -# if isinstance(m, QModuleMixin): -# m.weight = -# m.freeze() - def requantize( model: torch.nn.Module, state_dict: Dict[str, Any], quantization_map: Dict[str, Dict[str, str]], - device: torch.device = None, + device: torch.device | None = None, ): + """This function was initially copied from: + https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/quantize.py#L101 + + The function was modified to remove the `freeze()` call. The `freeze()` call is very slow and unnecessary when the + weights are about to be loaded from a state_dict. + + TODO(ryand): Unless I'm overlooking something, this should be contributed upstream to the `optimum-quanto` library. + """ if device is None: device = next(model.parameters()).device if device.type == "meta": @@ -45,6 +47,7 @@ def requantize( setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu"))) for name, param in m.named_buffers(recurse=False): setattr(m, name, move_tensor(param, "cpu")) + # Freeze model and move to target device # freeze(model) # model.to(device)