refactor(tensor): remove 8-bit groupwise support
dacorvo committed May 3, 2024
1 parent 503d256 commit 19d3e88
Showing 8 changed files with 13 additions and 71 deletions.
quanto/tensor/optimizers/symmetric_optimizer.py (1 addition, 6 deletions)

@@ -16,7 +16,6 @@
 
 import torch
 
-from ..core import group
 from .optimizer import Optimizer
 
 
@@ -25,13 +24,9 @@
 
 class SymmetricOptimizer(Optimizer):
 
-    def __call__(
-        self, base: torch.Tensor, bits: int, axis: Optional[int] = None, group_size: Optional[int] = None
-    ) -> torch.Tensor:
+    def __call__(self, base: torch.Tensor, bits: int, axis: Optional[int] = None) -> torch.Tensor:
         if axis not in [None, 0, -1]:
             raise ValueError("axis parameter must be None, 0 (first axis) or -1 (last axis)")
-        if group_size is not None:
-            base = group(base, axis=axis, group_size=group_size)
         scale = self.optimize(base, bits, axis)
         assert scale.dtype == base.dtype
         return scale
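With group_size gone, a SymmetricOptimizer computes the scale from just the tensor, the bit width, and an optional axis; callers no longer pre-group the tensor. A minimal sketch of the new calling convention, assuming the import path from this tree; the AbsmaxLikeOptimizer subclass is illustrative and not part of the commit:

import torch

from quanto.tensor.optimizers.symmetric_optimizer import SymmetricOptimizer


class AbsmaxLikeOptimizer(SymmetricOptimizer):
    """Illustrative subclass: derive the scale from the absolute maximum."""

    def optimize(self, base, bits, axis):
        qmax = 2 ** (bits - 1) - 1  # 127 for 8-bit symmetric ranges
        if axis is None:
            return base.abs().max() / qmax
        # Reduce over every dim except the quantization axis, keeping dims
        # so the scale stays broadcastable to the base tensor.
        dims = [d for d in range(base.ndim) if d != axis % base.ndim]
        return base.abs().amax(dim=dims, keepdim=True) / qmax


optimizer = AbsmaxLikeOptimizer()
weights = torch.randn(32, 64)
scale = optimizer(weights, bits=8, axis=0)  # no group_size argument anymore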
quanto/tensor/qactivation.py (1 addition, 1 deletion)

@@ -36,4 +36,4 @@ def quantize_activation(t: torch.Tensor, qtype: qtype, scale: torch.Tensor):
     """
     if scale.numel() != 1:
         raise ValueError("Parameter scale must be a scalar because activations can only be quantized per-tensor")
-    return SymmetricQuantizer.apply(t, qtype, None, None, scale)
+    return SymmetricQuantizer.apply(t, qtype, None, scale)
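Activations are always quantized per-tensor, so the quantizer now receives a single None axis before the scalar scale. A hedged usage sketch, assuming the qint8 qtype and the import paths implied by this tree:

import torch

from quanto.tensor.qactivation import quantize_activation
from quanto.tensor.qtype import qint8

act = torch.randn(8, 128)
scale = act.abs().max() / 127.0  # 0-dim tensor: one scale for the whole tensor
qact = quantize_activation(act, qtype=qint8, scale=scale)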
quanto/tensor/qbytes_ops.py (1 addition, 1 deletion)

@@ -295,7 +295,7 @@ def transpose2d(op, input):
     out_data = op(input._data)
     out_scale = input._scale
     out_axis = input.axis
-    # Manually reverse size and stride because we cannot trust the out_data shape when using group-wise quantization
+    # Manually reverse size and stride because we cannot trust the out_data shape
     dim0, dim1 = input.size()
     out_size = torch.Size([dim1, dim0])
     out_stride = input.stride()[::-1]
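The shortened comment only narrows the rationale: size and stride are still reversed by hand rather than read back from the transposed payload. A standalone illustration of that bookkeeping with plain tensors, outside quanto internals:

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
dim0, dim1 = x.size()
out_size = torch.Size([dim1, dim0])  # (4, 3)
out_stride = x.stride()[::-1]        # (1, 4): same storage, strides reversed
xt = torch.as_strided(x, out_size, out_stride)
assert torch.equal(xt, x.t())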
quanto/tensor/quantizers/symmetric.py (4 additions, 9 deletions)

@@ -15,7 +15,7 @@
 import torch
 from torch.autograd import Function
 
-from ..core import dtype_info, group
+from ..core import dtype_info
 from ..qbytes import QBytesTensor
 from ..qtype import qtype
 
@@ -27,15 +27,13 @@ class SymmetricQuantizer(Function):
     """A standard symmetric quantizer."""
 
     @staticmethod
-    def forward(ctx, base: torch.Tensor, qtype: qtype, axis: int, group_size: int, scale: torch.Tensor):
+    def forward(ctx, base: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tensor):
         size = base.size()
         stride = base.stride()
         # Sanity checks
         if axis is None:
             if scale.ndim > 0:
                 raise ValueError("Scale must be a scalar when quantizing per-tensor")
-            if group_size is not None:
-                raise ValueError("Group size can only be specified when quantizing per-axis")
         else:
             if base.ndim == 1:
                 raise ValueError("1D Tensors cannot be quantized per-axis")
@@ -46,11 +44,8 @@ def forward(ctx, base: torch.Tensor, qtype: qtype, axis: int, group_size: int, scale: torch.Tensor):
                 raise ValueError("QBytesTensor can only be quantized along the first or last axis.")
             if base.shape[axis] == 1:
                 raise ValueError(f"Cannot quantize Tensor of shape {base.shape} along axis {axis} of size 1")
-            if group_size is not None:
-                base = group(base, axis=axis, group_size=group_size)
-            else:
-                if torch.squeeze(scale).ndim > 1:
-                    raise ValueError("Quantizing along multiple axis is not supported")
+            if torch.squeeze(scale).ndim > 1:
+                raise ValueError("Quantizing along multiple axis is not supported")
             if scale.ndim != base.ndim:
                 raise ValueError(
                     "When quantizing per-axis, the scale must be broadcastable to the base (Tip: try to add missing dims of length zero)."
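After this change, SymmetricQuantizer.apply takes exactly four arguments, and a per-axis scale must be broadcastable to the base tensor. A hedged sketch under those rules, assuming the import paths from this tree:

import torch

from quanto.tensor.qtype import qint8
from quanto.tensor.quantizers.symmetric import SymmetricQuantizer

w = torch.randn(32, 64)
# Per-axis (axis=0) symmetric scale, kept broadcastable to w: shape (32, 1)
scale = w.abs().amax(dim=1, keepdim=True) / 127.0
qw = SymmetricQuantizer.apply(w, qint8, 0, scale)  # previously (w, qint8, 0, None, scale)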
quanto/tensor/qweight.py (4 additions, 2 deletions)

@@ -54,8 +54,10 @@ def quantize_weight(
     else:
         if not isinstance(optimizer, SymmetricOptimizer):
             raise ValueError("A SymmetricOptimizer is expected")
-        scale = optimizer(t, qtype.bits, axis, group_size)
-        return SymmetricQuantizer.apply(t, qtype, axis, group_size, scale)
+        if group_size is not None:
+            raise ValueError("group_size cannot be specified for 8-bit qtypes.")
+        scale = optimizer(t, qtype.bits, axis)
+        return SymmetricQuantizer.apply(t, qtype, axis, scale)
     if optimizer is None:
         optimizer = default_affine_optimizer
     else:
test/tensor/ops/test_qweight_dispatch.py (0 additions, 38 deletions)

This file was deleted.
test/tensor/quantizers/test_symmetric.py (2 additions, 2 deletions)

@@ -40,7 +40,7 @@
 def test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device):
     a = random_tensor(input_shape, dtype=dtype).to(device)
     scale = absmax_scale(a, qtype=qtype, axis=axis)
-    qa = SymmetricQuantizer.apply(a, qtype, axis, None, scale)
+    qa = SymmetricQuantizer.apply(a, qtype, axis, scale)
     assert isinstance(qa, QBytesTensor)
     assert qa.dtype == dtype
     assert qa.qtype == qtype
@@ -62,7 +62,7 @@ def test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device):
 def test_symmetric_quantize_float8(input_shape, dtype, qtype, axis, device):
     a = random_tensor(input_shape, dtype=dtype).to(device)
     scale = absmax_scale(a, qtype=qtype, axis=axis)
-    qa = SymmetricQuantizer.apply(a, qtype, axis, None, scale)
+    qa = SymmetricQuantizer.apply(a, qtype, axis, scale)
     assert isinstance(qa, QBytesTensor)
     assert qa.dtype == dtype
     assert qa.qtype == qtype
test/tensor/test_qbytestensor.py (0 additions, 12 deletions)

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import io
-from math import prod
 
 import pytest
 import torch
@@ -123,17 +122,6 @@ def test_qbytestensor_contiguous(axis, qtype, device):
     assert tqa.is_contiguous()
 
 
-@pytest.mark.parametrize("input_shape, group_size", [[(4, 6), 2], [(32, 64), 4]], ids=["small", "bigger"])
-def test_qbytestensor_quantize_transposed_groupwise(input_shape, group_size, device):
-    x = torch.tensor(range(prod(input_shape)), dtype=torch.float32).reshape(input_shape).to(device)
-    xt = x.t()
-    qx = quantize_weight(x, qtype=qint8, axis=0, group_size=group_size)
-    qxt = quantize_weight(xt, qtype=qint8, axis=-1, group_size=group_size)
-    dqx = qx.dequantize()
-    dqxt = qxt.dequantize()
-    assert torch.equal(dqx.t(), dqxt)
-
-
 def test_to_device(device):
     qa = random_qweight((32, 32), qtype=qint8, dtype=torch.float)
     qa = qa.to(device)
