Commit

add slicing
SunMarc committed Mar 7, 2024
1 parent ee341dd commit b1f0de5
Showing 8 changed files with 86 additions and 49 deletions.
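Background for the change (an editorial sketch, not part of the commit): pack_weights stores sub-byte values several per byte, so torch.ops.quanto.unpack can come back with more rows than the tensor had before packing. The unpacked_shape argument added throughout this commit lets udqmm slice the unpacked weights back to their pre-packing row count before dequantizing. A minimal illustration, assuming quanto is installed and that importing quanto.library registers the custom ops:

    import torch
    from quanto.library import ops  # noqa: F401  (registers torch.ops.quanto.*)
    from quanto.tensor.core import group
    from quanto.tensor.packed import pack_weights

    bits = 4
    weights = torch.randint(0, 2**bits, (32, 48), dtype=torch.uint8)
    grouped_weights = group(weights, axis=0, group_size=8)
    packed_weights = pack_weights(grouped_weights, bits)

    # Unpacking works on the packed storage, so it may return padded rows;
    # slicing to the pre-packing row count restores the grouped shape.
    unpacked = torch.ops.quanto.unpack(packed_weights, bits)
    unpacked = unpacked[: grouped_weights.shape[0]]
    assert unpacked.shape == grouped_weights.shape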
30 changes: 21 additions & 9 deletions bench/library/benchmark.py
@@ -15,14 +15,26 @@ def get_udqmm_bench(input_dtype, device, bits):
     input = torch.rand([128, 128], dtype=input_dtype).to(device)
     weight = torch.randint(-127, 127, [128, 128], dtype=torch.int8).to(device)
 
-    input_shape = weight.shape
-    grouped_weights = group(weight, axis=0, group_size=int(input_shape[-1] / 4))
+    orig_shape = weight.shape
+    grouped_weights = group(weight, axis=0, group_size=int(orig_shape[-1] / 4))
     scale = torch.ones((1, grouped_weights.shape[1]), dtype=input_dtype, device=device) * 0.5
+    zeropoint = torch.randint(
+        torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max, (1, grouped_weights.shape[1]), dtype=torch.int8
+    ).to(device)
 
     packed_weights = pack_weights(grouped_weights, bits)
 
     def bench_fn():
-        return torch.ops.quanto.udqmm(input, packed_weights, scale, bits)
+        return torch.ops.quanto.udqmm(
+            input,
+            packed_weights,
+            scale,
+            zeropoint,
+            axis=0,
+            bits=bits,
+            orig_shape=orig_shape,
+            unpacked_shape=grouped_weights.shape,
+        )
 
     return bench_fn

@@ -109,12 +121,12 @@ def elapsed_time(self, other):


 GET_BENCH_FUNCTIONS = {
-    "dqmm_w8a16": lambda device: get_dqmm_bench(torch.float16, device),
-    "quantize_symmetric_fp32_int8_per_tensor": lambda device: get_quantize_symmetric_bench(
-        torch.float32, torch.int8, False, device
-    ),
-    "unpack_2bit": lambda device: get_unpack_bench(2, device),
-    "unpack_4bit": lambda device: get_unpack_bench(4, device),
+    # "dqmm_w8a16": lambda device: get_dqmm_bench(torch.float16, device),
+    # "quantize_symmetric_fp32_int8_per_tensor": lambda device: get_quantize_symmetric_bench(
+    #     torch.float32, torch.int8, False, device
+    # ),
+    # "unpack_2bit": lambda device: get_unpack_bench(2, device),
+    # "unpack_4bit": lambda device: get_unpack_bench(4, device),
     "udqmm_4bit": lambda device: get_udqmm_bench(torch.float16, device, 4),
 }

3 changes: 2 additions & 1 deletion quanto/library/ext/cpp/__init__.py
@@ -53,5 +53,6 @@ def udqmm_cpp(
     axis: int,
     bits: int,
     orig_shape: torch.Size,
+    unpacked_shape: torch.Size,
 ):
-    return ext().udqmm(input, weights, scale, zeropoint, axis, bits, orig_shape)
+    return ext().udqmm(input, weights, scale, zeropoint, axis, bits, orig_shape, unpacked_shape)
9 changes: 6 additions & 3 deletions quanto/library/ext/cpp/udqmm.cpp
@@ -6,12 +6,15 @@

 using namespace std;
 
-torch::Tensor udqmm(torch::Tensor &input, torch::Tensor &weights, torch::Tensor& scale, torch::Tensor& zeropoint, int axis, int bits, torch::IntArrayRef orig_shape) {
+torch::Tensor udqmm(torch::Tensor &input, torch::Tensor &weights, torch::Tensor &scale, torch::Tensor &zeropoint, int axis, int bits, torch::IntArrayRef orig_shape, torch::IntArrayRef unpacked_shape) {
+    TORCH_CHECK(zeropoint.dtype() == torch::kInt8, "zeropoint must have dtype: torch.int8");
 
     torch::Tensor unpacked_weights = unpack(weights, bits);
-    torch::Tensor dq_output = (unpacked_weights.to(torch::kInt8) - zeropoint.to(torch::kInt8)) * scale;
+    // slice along the first dim from index 0 to unpacked_shape[0]
+    unpacked_weights = unpacked_weights.slice(0, 0, unpacked_shape[0]);
+    torch::Tensor dq_output = (unpacked_weights.to(torch::kInt8) - zeropoint).to(scale.dtype()) * scale;
 
     torch::Tensor ungrouped_output;
 
     // Ungroup TODO: put in its own function
     if (dq_output.sizes() == orig_shape){
         ungrouped_output = dq_output;
2 changes: 1 addition & 1 deletion quanto/library/ext/cpp/udqmm.h
@@ -1,3 +1,3 @@
 #include <torch/extension.h>
 
-torch::Tensor udqmm(torch::Tensor &input, torch::Tensor &weights, torch::Tensor& scale, torch::Tensor& zeropoint, int axis, int bits, torch::IntArrayRef orig_shape);
+torch::Tensor udqmm(torch::Tensor &input, torch::Tensor &weights, torch::Tensor &scale, torch::Tensor &zeropoint, int axis, int bits, torch::IntArrayRef orig_shape, torch::IntArrayRef unpacked_shape);
2 changes: 1 addition & 1 deletion quanto/library/ops.py
@@ -58,5 +58,5 @@ def impl(*args, **kwargs):
 define("unpack", "(Tensor self, int bits) -> Tensor")
 define(
     "udqmm",
-    "(Tensor input, Tensor weight, Tensor scales, Tensor zeropoint, int axis, int bits, Any orig_shape) -> Tensor",
+    "(Tensor input, Tensor weight, Tensor scales, Tensor zeropoint, int axis, int bits, Any orig_shape, Any unpacked_shape) -> Tensor",
 )
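For reference, a sketch of calling the re-defined op with the new argument (an illustration, not part of the diff). disable_extensions is the helper the tests below use to force the pure-Python fallback in quanto/library/python/udqmm.py instead of the compiled extension; the tensor names are assumed to be prepared as in get_udqmm_bench above:

    from contextlib import nullcontext

    from quanto.library import disable_extensions

    use_ext = False  # False exercises the Python fallback, True the extension
    with nullcontext() if use_ext else disable_extensions():
        output = torch.ops.quanto.udqmm(
            input,
            packed_weights,
            scale,
            zeropoint,
            axis=0,
            bits=bits,
            orig_shape=orig_shape,
            unpacked_shape=grouped_weights.shape,
        )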
7 changes: 5 additions & 2 deletions quanto/library/python/udqmm.py
@@ -12,9 +12,12 @@ def udqmm(
     axis: int,
     bits: int,
     orig_shape: torch.Size,
+    unpacked_shape: torch.Size,
 ):
     unpacked_weights = torch.ops.quanto.unpack(weights, bits)
-    shifted_weights = unpacked_weights.to(torch.int8) - zeropoint.to(torch.int8)
-    scaled_weights = shifted_weights * scale
+    # TODO: we should probably add this to unpack, with an unpacked_shape arg.
+    unpacked_weights_resized = unpacked_weights[: unpacked_shape[0]]
+    shifted_weights = unpacked_weights_resized.to(torch.int8) - zeropoint.to(torch.int8)
+    scaled_weights = shifted_weights.to(scale.dtype) * scale
     ungrouped_weights = ungroup(scaled_weights, axis, orig_shape)
     return torch.ops.aten.mm(input, ungrouped_weights)
4 changes: 3 additions & 1 deletion quanto/tensor/qbitstensor.py
@@ -134,6 +134,8 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
         elif op.overloadpacket is torch.ops.aten.mm:
             input = args[0]
             t = args[1]
-            return torch.ops.quanto.udqmm(input, t._data, t._scale, t._zeropoint, t._axis, t._bits, t.shape)
+            return torch.ops.quanto.udqmm(
+                input, t._data, t._scale, t._zeropoint, t._axis, t._bits, t.shape, t._data.shape
+            )
         args, kwargs = pytree.tree_map_only(QBitsTensor, lambda x: x.qtensor(), (args, kwargs or {}))
         return op(*args, **kwargs)
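Read as a standalone helper (the name mm_via_udqmm is made up here for illustration; this is not code from the commit), the dispatch above now also forwards the packed data's shape, so the op knows how many rows to keep after unpacking; qw stands for any QBitsTensor, however it was produced:

    def mm_via_udqmm(x, qw):
        # mirrors the aten.mm branch of __torch_dispatch__ shown above
        return torch.ops.quanto.udqmm(
            x, qw._data, qw._scale, qw._zeropoint, qw._axis, qw._bits,
            qw.shape,        # orig_shape: the logical, ungrouped weight shape
            qw._data.shape,  # unpacked_shape: the shape the unpacked weights are sliced back to
        )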
78 changes: 47 additions & 31 deletions test/library/test_mm.py
@@ -6,18 +6,21 @@

 from quanto.library import disable_extensions
 from quanto.tensor.core import group, ungroup
-from quanto.tensor.packed import pack_weights, unpack_weights
+from quanto.tensor.packed import pack_weights
 
 
 @pytest.mark.parametrize("input_shape", [[10, 32], [32, 32]])
 @pytest.mark.parametrize("output_features", [48, 64])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
-def test_dqmm(input_shape, output_features, dtype, device):
+@pytest.mark.parametrize("use_ext", [True, False], ids=["ext", "no-ext"])
+def test_dqmm(input_shape, output_features, dtype, device, use_ext):
     input = random_tensor(input_shape, dtype=dtype).to(device)
     other = torch.randint(-127, 127, (input_shape[-1], output_features), dtype=torch.int8).to(device)
     other_scale = random_tensor((output_features,), dtype=dtype).to(device)
-    output = torch.ops.quanto.dqmm(input, other, other_scale)
-    expected = torch.ops.aten.mm(input, other * other_scale)
+    context = nullcontext() if use_ext else disable_extensions()
+    with context:
+        output = torch.ops.quanto.dqmm(input, other, other_scale)
+        expected = torch.ops.aten.mm(input, other * other_scale)
     assert torch.equal(expected, output)


@@ -27,25 +30,33 @@ def test_dqmm(input_shape, output_features, dtype, device):
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 @pytest.mark.parametrize("use_ext", [True, False], ids=["ext", "no-ext"])
 def test_packed_udqmm(input_shape, output_features, dtype, device, bits, use_ext):
-    input = random_tensor(input_shape, dtype=dtype).to(device)
-
     qmax = 2**bits
-    a = torch.randint(0, qmax, (input_shape[-1], output_features), dtype=torch.uint8).to(device)
+    input = random_tensor(input_shape, dtype=dtype).to(device)
+    weights = torch.randint(0, qmax, (input_shape[-1], output_features), dtype=torch.uint8).to(device)
+    packed_weights = pack_weights(weights, bits)
 
-    packed_a = pack_weights(a, bits)
-    unpacked_weights = unpack_weights(packed_a, bits)
-    other_scale = random_tensor((output_features,), dtype=dtype).to(device)
-    other_zeropoint = torch.randint(
+    scale = random_tensor((output_features,), dtype=dtype).to(device)
+    zeropoint = torch.randint(
         torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max, (input_shape[-1], output_features), dtype=torch.int8
     ).to(device)
 
     context = nullcontext() if use_ext else disable_extensions()
     with context:
         output = torch.ops.quanto.udqmm(
-            input, packed_a, other_scale, other_zeropoint, axis=0, bits=bits, orig_shape=a.shape
-        )
-        expected = torch.ops.aten.mm(
-            input, (unpacked_weights.to(torch.int8) - other_zeropoint.to(torch.int8)) * other_scale
+            input,
+            packed_weights,
+            scale,
+            zeropoint,
+            axis=0,
+            bits=bits,
+            orig_shape=weights.shape,
+            unpacked_shape=weights.shape,
         )
+
+    unpacked_weights = torch.ops.quanto.unpack(packed_weights, bits)
+    # TODO: We should probably combine it with unpack
+    unpacked_weights = unpacked_weights[: weights.shape[0]]
+    expected = torch.ops.aten.mm(input, (unpacked_weights.to(torch.int8) - zeropoint.to(torch.int8)) * scale)
     assert torch.equal(expected, output)


@@ -57,30 +68,35 @@ def test_packed_udqmm(input_shape, output_features, dtype, device, bits, use_ext
 def test_grouped_udqmm(input_shape, output_features, dtype, device, bits, use_ext):
     input = random_tensor(input_shape, dtype=dtype).to(device)
     qmax = 2**bits
 
     weights = torch.randint(0, qmax, (input_shape[-1], output_features), dtype=torch.uint8).to(device)
-    grouped_weights = group(weights, axis=0, group_size=int(input_shape[-1] / 4))
-    output_shape = grouped_weights.shape
-
-    packed_weights = pack_weights(grouped_weights, bits)
-    unpacked_weights = unpack_weights(packed_weights, bits)
-
-    other_scale = random_tensor((1, output_shape[1]), dtype=dtype).to(device)
-    other_zeropoint = torch.randint(
+    grouped_weights = group(weights, axis=0, group_size=int(input_shape[-1] / 4))
+    scale = random_tensor((1, grouped_weights.shape[1]), dtype=dtype).to(device)
+    zeropoint = torch.randint(
         torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max, grouped_weights.shape, dtype=torch.int8
     ).to(device)
 
+    packed_weights = pack_weights(grouped_weights, bits)
+
     context = nullcontext() if use_ext else disable_extensions()
     with context:
         output = torch.ops.quanto.udqmm(
-            input, packed_weights, other_scale, other_zeropoint, axis=0, bits=bits, orig_shape=weights.shape
-        )
-        expected = torch.ops.aten.mm(
             input,
-            ungroup(
-                (unpacked_weights.to(torch.int8) - other_zeropoint.to(torch.int8)) * other_scale,
-                axis=0,
-                orig_shape=weights.shape,
-            ),
+            packed_weights,
+            scale,
+            zeropoint,
+            axis=0,
+            bits=bits,
+            orig_shape=weights.shape,
+            unpacked_shape=grouped_weights.shape,
         )
+    unpacked_weights = torch.ops.quanto.unpack(packed_weights, bits)
+    # TODO: We should probably combine it with unpack
+    unpacked_weights = unpacked_weights[: grouped_weights.shape[0]]
+    ungrouped_weights = ungroup(
+        (unpacked_weights.to(torch.int8) - zeropoint.to(torch.int8)) * scale,
+        axis=0,
+        orig_shape=weights.shape,
+    )
+    expected = torch.ops.aten.mm(input, ungrouped_weights)
     assert torch.equal(expected, output)
