fix: adjust _convert_weight_to_int4pack_cpu input weights for pytorch>=2.5 #286

Merged (1 commit) on Aug 20, 2024
fix: adjust _convert_weight_to_int4pack_cpu input weights for pytorch>=2.5

Fixes: #274

PyTorch 2.5 changed the expected input weights of _convert_weight_to_int4pack_cpu
from [n][k] int32 to [n][k / 2] uint8, i.e. two 4-bit values packed per byte.
This commit adjusts the quanto code accordingly.

See: pytorch/pytorch#129940
See: pytorch/pytorch@6f662e9
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
dvrogozh committed Aug 16, 2024
commit 37c83b578dff5f244956f21351afcd02438601ca
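
For illustration only (not part of the commit): a minimal sketch of the packing the new code paths apply, assuming 4-bit values already in the range 0..15, an even k, and the nibble order used in the diff (even column in the high nibble, odd column in the low nibble). The tensor shape and variable names below are made up for the example.

    import torch

    # Illustrative [n][k] int32 tensor of 4-bit values (0..15); shape is arbitrary, k must be even
    w_int32 = torch.randint(0, 16, (4, 8), dtype=torch.int32)

    # Pack two 4-bit values per byte: even columns into the high nibble,
    # odd columns into the low nibble -> [n][k / 2] uint8, the layout
    # expected by _convert_weight_to_int4pack_cpu in PyTorch >= 2.5
    w_uint8 = (w_int32[:, ::2] << 4 | w_int32[:, 1::2]).to(torch.uint8)
    assert w_uint8.shape == (4, 4) and w_uint8.dtype == torch.uint8

    # Round-trip check: unpack the nibbles and compare with the original values
    w_back = torch.empty_like(w_int32)
    w_back[:, ::2] = (w_uint8 >> 4).to(torch.int32)
    w_back[:, 1::2] = (w_uint8 & 0x0F).to(torch.int32)
    assert torch.equal(w_back, w_int32)
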
7 changes: 6 additions & 1 deletion bench/torch_kernels/test_weight_int4pack_mm.py
@@ -16,6 +16,7 @@
 import timeit

 import torch
+from packaging import version


 def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
@@ -90,7 +91,11 @@ def avg_time(f, it):
 B = torch.rand([3200, 4800], dtype=dtype, device=device)
 group_size = 128
 B_int32, B_scale_and_zeros = _group_quantize_tensor(B, n_bit=4, q_group_size=group_size)
-B_packed = torch._convert_weight_to_int4pack(B_int32, innerKTiles=2)
+if version.parse(torch.__version__).release >= version.parse("2.5.0").release:
+    B_uint8 = (B_int32[::, ::2] << 4 | B_int32[::, 1::2]).to(torch.uint8)
+    B_packed = torch._convert_weight_to_int4pack(B_uint8, innerKTiles=2)
+else:
+    B_packed = torch._convert_weight_to_int4pack(B_int32, innerKTiles=2)

 # Check quantized mm is close to float mm
 qout = torch._weight_int4pack_mm(A, B_packed, group_size, B_scale_and_zeros)
7 changes: 6 additions & 1 deletion optimum/quanto/tensor/qbits/tinygemm/packed.py
@@ -16,6 +16,7 @@
 from copy import copy

 import torch
+from packaging import version
 from torch.utils import _pytree as pytree

@@ -53,7 +54,11 @@ def pack(cls, t):
"""
inner_ktiles = 2
t = t.to(torch.int32).contiguous()
data = torch._convert_weight_to_int4pack(t, innerKTiles=inner_ktiles)
if version.parse(torch.__version__).release >= version.parse("2.5.0").release:
t_uint8 = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)
else:
data = torch._convert_weight_to_int4pack(t, innerKTiles=inner_ktiles)
# We need to store size and stride to make sure the unpacked data has the correct shape
return TinyGemmPackedTensor(data, t.size(), t.stride())