Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Add docs to the requantize(...) function explaining why it was copied from optimum-quanto.
commit 38c2e7801f
parent d11dc6ddd0
@@ -3,19 +3,21 @@ from typing import Any, Dict
 import torch
 from optimum.quanto.quantize import _quantize_submodule
 
 
-# def custom_freeze(model: torch.nn.Module):
-#     for name, m in model.named_modules():
-#         if isinstance(m, QModuleMixin):
-#             m.weight =
-#             m.freeze()
-
 def requantize(
     model: torch.nn.Module,
     state_dict: Dict[str, Any],
     quantization_map: Dict[str, Dict[str, str]],
-    device: torch.device = None,
+    device: torch.device | None = None,
 ):
+    """This function was initially copied from:
+    https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/quantize.py#L101
+
+    The function was modified to remove the `freeze()` call. The `freeze()` call is very slow and unnecessary when the
+    weights are about to be loaded from a state_dict.
+
+    TODO(ryand): Unless I'm overlooking something, this should be contributed upstream to the `optimum-quanto` library.
+    """
     if device is None:
         device = next(model.parameters()).device
     if device.type == "meta":
@@ -45,6 +47,7 @@ def requantize(
             setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
         for name, param in m.named_buffers(recurse=False):
             setattr(m, name, move_tensor(param, "cpu"))
 
     # Freeze model and move to target device
     # freeze(model)
     # model.to(device)
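For context, `requantize` is the load-side half of optimum-quanto's documented serialize/reload workflow. Below is a minimal sketch of that workflow, not part of this commit: the two-layer model is a stand-in, and `quantize`, `freeze`, `quantization_map`, and `requantize` are taken from the public `optimum.quanto` API (the patched copy in this file has the same signature).

import torch
from optimum.quanto import freeze, qint8, quantization_map, quantize, requantize

# Quantize a model once, then serialize the two artifacts requantize() needs:
# the quantized state dict and the per-submodule quantization map.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
quantize(model, weights=qint8)
freeze(model)
state_dict = model.state_dict()
qmap = quantization_map(model)  # records which weight/activation qtypes each submodule uses

# Later, e.g. in a fresh process: rebuild the bare skeleton, then requantize.
# requantize() re-wraps the mapped submodules and loads the state dict, which
# is why the freeze() call inside it was redundant and could be dropped here.
fresh = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
requantize(fresh, state_dict=state_dict, quantization_map=qmap)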
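The `move_tensor` helper referenced in the second hunk is a closure defined earlier in `requantize`; its body is not shown in this diff. Reconstructed from the optimum-quanto source linked in the docstring, it is roughly:

def move_tensor(t: torch.Tensor, device: str) -> torch.Tensor:
    # "meta" tensors carry only shape/dtype metadata and have no storage,
    # so .to() cannot move them; allocate uninitialized storage instead.
    # The real values are filled in afterwards by model.load_state_dict().
    if t.device.type == "meta":
        return torch.empty_like(t, device=device)
    return t.to(device)

This is also why skipping `freeze()` is safe here: every parameter and buffer these loops materialize is about to be overwritten from the state dict, so any work `freeze()` did on the placeholder values would be thrown away.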