Add docs to the requantize(...) function explaining why it was copied from optimum-quanto.

This commit is contained in:
Ryan Dick 2024-08-21 18:19:47 +00:00
parent d11dc6ddd0
commit 38c2e7801f

View File

@ -3,19 +3,21 @@ from typing import Any, Dict
import torch import torch
from optimum.quanto.quantize import _quantize_submodule from optimum.quanto.quantize import _quantize_submodule
# def custom_freeze(model: torch.nn.Module):
# for name, m in model.named_modules():
# if isinstance(m, QModuleMixin):
# m.weight =
# m.freeze()
def requantize( def requantize(
model: torch.nn.Module, model: torch.nn.Module,
state_dict: Dict[str, Any], state_dict: Dict[str, Any],
quantization_map: Dict[str, Dict[str, str]], quantization_map: Dict[str, Dict[str, str]],
device: torch.device = None, device: torch.device | None = None,
): ):
"""This function was initially copied from:
https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/quantize.py#L101
The function was modified to remove the `freeze()` call. The `freeze()` call is very slow and unnecessary when the
weights are about to be loaded from a state_dict.
TODO(ryand): Unless I'm overlooking something, this should be contributed upstream to the `optimum-quanto` library.
"""
if device is None: if device is None:
device = next(model.parameters()).device device = next(model.parameters()).device
if device.type == "meta": if device.type == "meta":
@ -45,6 +47,7 @@ def requantize(
setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu"))) setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
for name, param in m.named_buffers(recurse=False): for name, param in m.named_buffers(recurse=False):
setattr(m, name, move_tensor(param, "cpu")) setattr(m, name, move_tensor(param, "cpu"))
# Freeze model and move to target device # Freeze model and move to target device
# freeze(model) # freeze(model)
# model.to(device) # model.to(device)