Make quantized loading fast.

Ryan Dick 2024-08-09 16:39:43 +00:00
parent 4181ab654b
commit d23ad1818d


@@ -1,14 +1,13 @@
 from typing import Any, Dict
 
 import torch
-from optimum.quanto.nn import QModuleMixin
-from optimum.quanto.quantize import _quantize_submodule, freeze
-
-
-def custom_freeze(model: torch.nn.Module):
-    for name, m in model.named_modules():
-        if isinstance(m, QModuleMixin):
-            m.freeze()
+from optimum.quanto.quantize import _quantize_submodule
+
+# def custom_freeze(model: torch.nn.Module):
+#     for name, m in model.named_modules():
+#         if isinstance(m, QModuleMixin):
+#             m.weight =
+#             m.freeze()
 
 
 def requantize(
@@ -47,8 +46,8 @@ def requantize(
         for name, param in m.named_buffers(recurse=False):
             setattr(m, name, move_tensor(param, "cpu"))
 
     # Freeze model and move to target device
-    freeze(model)
-    model.to(device)
+    # freeze(model)
+    # model.to(device)
 
     # Load the quantized model weights
     model.load_state_dict(state_dict, strict=False)