# Largely based on https://github.com/city96/ComfyUI-GGUF

from typing import Callable, Optional, Union

import gguf
import torch

# should not be a Set until this is resolved: https://github.com/pytorch/pytorch/issues/145761
TORCH_COMPATIBLE_QTYPES = [None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]

# K Quants #
QK_K = 256
K_SCALE_SIZE = 12


def get_scale_min(scales: torch.Tensor):
    """Unpack the packed 6-bit sub-block scales and mins used by the K-quant formats (Q4_K, Q5_K)."""
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))


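# Example (illustrative names only): given a uint8 tensor `packed_scales` of shape
# (n_blocks, K_SCALE_SIZE) taken from Q4_K/Q5_K blocks,
#
#   sc, mn = get_scale_min(packed_scales)
#
# returns two (n_blocks, 8) tensors holding the eight 6-bit sub-block scales and mins.
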
# Legacy Quants #
def dequantize_blocks_Q8_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return d * x


def dequantize_blocks_Q5_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))

    qs = ql | (qh << 4)
    return (d * qs) + m


def dequantize_blocks_Q5_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return d * qs


def dequantize_blocks_Q4_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qs = split_block_dims(blocks, 2, 2)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qs = (qs & 0x0F).reshape(n_blocks, -1)

    return (d * qs) + m


def dequantize_blocks_Q4_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return d * qs


def dequantize_blocks_BF16(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    # bf16 shares its sign/exponent layout with the upper 16 bits of fp32, so shifting the
    # raw 16-bit pattern left by 16 and reinterpreting the result yields the fp32 value.
    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


def dequantize_blocks_Q6_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    (
        ql,
        qh,
        scales,
        d,
    ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q5_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = ql | (qh << 4)

    return (d * q - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q4_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q3_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)

    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 2, 1)
    )
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
        [0, 2, 4, 6], device=d.device, dtype=torch.uint8
    ).reshape((1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = scales.to(torch.int8) - 32

    dl = (d * scales).reshape((n_blocks, 16, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = ql.to(torch.int8) - (qh << 2).to(torch.int8)

    return (dl * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q2_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # (n_blocks, 16, 1)
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))

    shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))

    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml

    return qs.reshape((n_blocks, -1))


DEQUANTIZE_FUNCTIONS: dict[
    gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
] = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}


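# Hypothetical sketch (illustrative names) of calling one of the block dequantizers in
# the table above directly: a Q8_0 block is 34 bytes, a 2-byte fp16 scale followed by
# 32 int8 values, so for a uint8 tensor `raw_blocks` of shape (n_blocks, 34):
#
#   block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q8_0]
#   values = DEQUANTIZE_FUNCTIONS[gguf.GGMLQuantizationType.Q8_0](raw_blocks, block_size, type_size, torch.float32)
#   # `values` has shape (n_blocks, 32)
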
def is_torch_compatible(tensor: Optional[torch.Tensor]):
    return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES


def is_quantized(tensor: torch.Tensor):
    return not is_torch_compatible(tensor)


def dequantize(
    data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
):
    """
    Dequantize a GGUF-quantized tensor back to a usable shape/dtype.
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = DEQUANTIZE_FUNCTIONS[qtype]

    rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
    return blocks.reshape(oshape)


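# Hypothetical usage sketch (not part of this module's API): shows how `dequantize` might
# be applied to a tensor read with the gguf package's GGUFReader. The helper name and the
# assumption that a ReaderTensor exposes `tensor_type`, `shape`, and `data` fields are
# illustrative, not guaranteed by this file.
def _example_dequantize_gguf_tensor(reader_tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Illustrative only: convert a gguf ReaderTensor into a dequantized torch.Tensor."""
    qtype = reader_tensor.tensor_type
    # GGUF records dimensions in reverse order relative to torch's convention.
    oshape = torch.Size(tuple(int(d) for d in reversed(reader_tensor.shape)))
    data = torch.from_numpy(reader_tensor.data.copy())  # copy: the reader may hand back a read-only memmap
    if qtype in TORCH_COMPATIBLE_QTYPES:
        return data.reshape(oshape).to(dtype)
    return dequantize(data, qtype, oshape, dtype=dtype)

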
def to_uint32(x: torch.Tensor) -> torch.Tensor:
    # Reassemble little-endian 32-bit values from groups of four bytes.
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks: torch.Tensor, *args):
    # Split each block's bytes into the field widths given by `args`, with the
    # remainder of the block as the final field.
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


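# Example (illustrative): a Q5_0 block is 22 bytes, so split_block_dims(blocks, 2, 4)
# yields the 2-byte fp16 scale, the 4 bytes of packed high bits, and the remaining
# 16 bytes of low 4-bit values.
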
PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]