from typing import Any, Dict

import torch
from optimum.quanto.quantize import _quantize_submodule


def requantize(
    model: torch.nn.Module,
    state_dict: Dict[str, Any],
    quantization_map: Dict[str, Dict[str, str]],
    device: torch.device | None = None,
):
    """Re-apply a quantization map to `model`, then load quantized weights from `state_dict`.

    This function was initially copied from:
    https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/quantize.py#L101

    It was modified to remove the `freeze()` call. The `freeze()` call is very slow and unnecessary when the weights
    are about to be loaded from a state_dict anyway.

    TODO(ryand): Unless I'm overlooking something, this should be contributed upstream to the `optimum-quanto`
    library.
    """
    # `device` is resolved for signature compatibility with the upstream function, but it is no longer used now that
    # the `model.to(device)` call has been removed (see below).
    if device is None:
        device = next(model.parameters()).device
        if device.type == "meta":
            device = torch.device("cpu")

    # Quantize the model with parameters from the quantization map.
    for name, m in model.named_modules():
        qconfig = quantization_map.get(name, None)
        if qconfig is not None:
            weights = qconfig["weights"]
            if weights == "none":
                weights = None
            activations = qconfig["activations"]
            if activations == "none":
                activations = None
            _quantize_submodule(model, name, m, weights=weights, activations=activations)

    def move_tensor(t: torch.Tensor, device: str) -> torch.Tensor:
        # Meta tensors carry no data to copy; materialize an empty tensor of the same shape/dtype instead.
        if t.device.type == "meta":
            return torch.empty_like(t, device=device)
        return t.to(device)

    # Move model parameters and buffers to CPU before materializing quantized weights.
    for _, m in model.named_modules():
        for param_name, param in m.named_parameters(recurse=False):
            setattr(m, param_name, torch.nn.Parameter(move_tensor(param, "cpu")))
        for buffer_name, buffer in m.named_buffers(recurse=False):
            setattr(m, buffer_name, move_tensor(buffer, "cpu"))

    # The upstream implementation freezes the model and moves it to the target device here. Both steps are
    # intentionally skipped: `freeze()` is redundant (and slow) when the weights are immediately overwritten by
    # `load_state_dict()`.
    # freeze(model)
    # model.to(device)

    # Load the quantized model weights.
    model.load_state_dict(state_dict, strict=False)
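

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): round-trip
# a toy model through quantize/freeze on the producer side, then restore it
# with the `requantize()` above on the consumer side. This mirrors the
# serialization/reload flow documented by `optimum-quanto` and assumes its
# public helpers `quantize`, `freeze`, `qint8`, and `quantization_map`; the
# toy model is a hypothetical stand-in for real modeling code.
if __name__ == "__main__":
    from optimum.quanto import freeze, qint8, quantization_map, quantize

    def build_model() -> torch.nn.Module:
        # Hypothetical toy architecture; a real workflow would use its own modeling code.
        return torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))

    # Producer side: quantize and freeze a reference model, then capture the two
    # artifacts that would normally be serialized to disk (the state_dict and the
    # quantization map).
    src = build_model()
    quantize(src, weights=qint8)
    freeze(src)
    state_dict = src.state_dict()
    qmap = quantization_map(src)

    # Consumer side: rebuild the architecture on the meta device (no weight
    # allocation), then restore the quantized weights without a `freeze()` pass.
    with torch.device("meta"):
        dst = build_model()
    requantize(dst, state_dict, qmap)

    # Sanity check: the restored model should match the source model's outputs.
    x = torch.randn(2, 8)
    torch.testing.assert_close(dst(x), src(x))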