InvokeAI/invokeai/backend/quantization/requantize.py

57 lines
2.1 KiB
Python
Raw Normal View History

2024-08-09 16:23:37 +00:00
from typing import Any, Dict
import torch
2024-08-09 16:39:43 +00:00
from optimum.quanto.quantize import _quantize_submodule
2024-08-09 16:23:37 +00:00
def requantize(
model: torch.nn.Module,
state_dict: Dict[str, Any],
quantization_map: Dict[str, Dict[str, str]],
device: torch.device | None = None,
2024-08-09 16:23:37 +00:00
):
"""This function was initially copied from:
https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/quantize.py#L101
The function was modified to remove the `freeze()` call. The `freeze()` call is very slow and unnecessary when the
weights are about to be loaded from a state_dict.
TODO(ryand): Unless I'm overlooking something, this should be contributed upstream to the `optimum-quanto` library.
"""
2024-08-09 16:23:37 +00:00
if device is None:
device = next(model.parameters()).device
if device.type == "meta":
device = torch.device("cpu")
# Quantize the model with parameters from the quantization map
for name, m in model.named_modules():
qconfig = quantization_map.get(name, None)
if qconfig is not None:
weights = qconfig["weights"]
if weights == "none":
weights = None
activations = qconfig["activations"]
if activations == "none":
activations = None
_quantize_submodule(model, name, m, weights=weights, activations=activations)
# Move model parameters and buffers to CPU before materializing quantized weights
for name, m in model.named_modules():
def move_tensor(t, device):
if t.device.type == "meta":
return torch.empty_like(t, device=device)
return t.to(device)
for name, param in m.named_parameters(recurse=False):
setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
for name, param in m.named_buffers(recurse=False):
setattr(m, name, move_tensor(param, "cpu"))
2024-08-09 16:23:37 +00:00
# Freeze model and move to target device
2024-08-09 16:39:43 +00:00
# freeze(model)
# model.to(device)
2024-08-09 16:23:37 +00:00
# Load the quantized model weights
model.load_state_dict(state_dict, strict=False)