From 19d3e887ee686040c7348288fe4a99fe5dfda059 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 3 May 2024 15:50:15 +0200 Subject: [PATCH] refactor(tensor): remove 8-bit groupwise support --- .../tensor/optimizers/symmetric_optimizer.py | 7 +--- quanto/tensor/qactivation.py | 2 +- quanto/tensor/qbytes_ops.py | 2 +- quanto/tensor/quantizers/symmetric.py | 13 ++----- quanto/tensor/qweight.py | 6 ++- test/tensor/ops/test_qweight_dispatch.py | 38 ------------------- test/tensor/quantizers/test_symmetric.py | 4 +- test/tensor/test_qbytestensor.py | 12 ------ 8 files changed, 13 insertions(+), 71 deletions(-) delete mode 100644 test/tensor/ops/test_qweight_dispatch.py diff --git a/quanto/tensor/optimizers/symmetric_optimizer.py b/quanto/tensor/optimizers/symmetric_optimizer.py index 33e5513e..3c990fbf 100644 --- a/quanto/tensor/optimizers/symmetric_optimizer.py +++ b/quanto/tensor/optimizers/symmetric_optimizer.py @@ -16,7 +16,6 @@ import torch -from ..core import group from .optimizer import Optimizer @@ -25,13 +24,9 @@ class SymmetricOptimizer(Optimizer): - def __call__( - self, base: torch.Tensor, bits: int, axis: Optional[int] = None, group_size: Optional[int] = None - ) -> torch.Tensor: + def __call__(self, base: torch.Tensor, bits: int, axis: Optional[int] = None) -> torch.Tensor: if axis not in [None, 0, -1]: raise ValueError("axis parameter must be None, 0 (first axis) or -1 (last axis)") - if group_size is not None: - base = group(base, axis=axis, group_size=group_size) scale = self.optimize(base, bits, axis) assert scale.dtype == base.dtype return scale diff --git a/quanto/tensor/qactivation.py b/quanto/tensor/qactivation.py index a7bf596c..a55b11d0 100644 --- a/quanto/tensor/qactivation.py +++ b/quanto/tensor/qactivation.py @@ -36,4 +36,4 @@ def quantize_activation(t: torch.Tensor, qtype: qtype, scale: torch.Tensor): """ if scale.numel() != 1: raise ValueError("Parameter scale must be a scalar because activations can only be quantized per-tensor") - return SymmetricQuantizer.apply(t, qtype, None, None, scale) + return SymmetricQuantizer.apply(t, qtype, None, scale) diff --git a/quanto/tensor/qbytes_ops.py b/quanto/tensor/qbytes_ops.py index 459a348a..47377ae7 100644 --- a/quanto/tensor/qbytes_ops.py +++ b/quanto/tensor/qbytes_ops.py @@ -295,7 +295,7 @@ def transpose2d(op, input): out_data = op(input._data) out_scale = input._scale out_axis = input.axis - # Manually reverse size and stride because we cannot trust the out_data shape when using group-wise quantization + # Manually reverse size and stride because we cannot trust the out_data shape dim0, dim1 = input.size() out_size = torch.Size([dim1, dim0]) out_stride = input.stride()[::-1] diff --git a/quanto/tensor/quantizers/symmetric.py b/quanto/tensor/quantizers/symmetric.py index db59dc2e..47907601 100644 --- a/quanto/tensor/quantizers/symmetric.py +++ b/quanto/tensor/quantizers/symmetric.py @@ -15,7 +15,7 @@ import torch from torch.autograd import Function -from ..core import dtype_info, group +from ..core import dtype_info from ..qbytes import QBytesTensor from ..qtype import qtype @@ -27,15 +27,13 @@ class SymmetricQuantizer(Function): """A standard symmetric quantizer.""" @staticmethod - def forward(ctx, base: torch.Tensor, qtype: qtype, axis: int, group_size: int, scale: torch.Tensor): + def forward(ctx, base: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tensor): size = base.size() stride = base.stride() # Sanity checks if axis is None: if scale.ndim > 0: raise ValueError("Scale must be a scalar when quantizing per-tensor") - if group_size is not None: - raise ValueError("Group size can only be specified when quantizing per-axis") else: if base.ndim == 1: raise ValueError("1D Tensors cannot be quantized per-axis") @@ -46,11 +44,8 @@ def forward(ctx, base: torch.Tensor, qtype: qtype, axis: int, group_size: int, s raise ValueError("QBytesTensor can only be quantized along the first or last axis.") if base.shape[axis] == 1: raise ValueError(f"Cannot quantize Tensor of shape {base.shape} along axis {axis} of size 1") - if group_size is not None: - base = group(base, axis=axis, group_size=group_size) - else: - if torch.squeeze(scale).ndim > 1: - raise ValueError("Quantizing along multiple axis is not supported") + if torch.squeeze(scale).ndim > 1: + raise ValueError("Quantizing along multiple axis is not supported") if scale.ndim != base.ndim: raise ValueError( "When quantizing per-axis, the scale must be broadcastable to the base (Tip: try to add missing dims of length zero)." diff --git a/quanto/tensor/qweight.py b/quanto/tensor/qweight.py index c28f5e75..381d3335 100644 --- a/quanto/tensor/qweight.py +++ b/quanto/tensor/qweight.py @@ -54,8 +54,10 @@ def quantize_weight( else: if not isinstance(optimizer, SymmetricOptimizer): raise ValueError("A SymmetricOptimizer is expected") - scale = optimizer(t, qtype.bits, axis, group_size) - return SymmetricQuantizer.apply(t, qtype, axis, group_size, scale) + if group_size is not None: + raise ValueError("group_size cannot be specified for 8-bit qtypes.") + scale = optimizer(t, qtype.bits, axis) + return SymmetricQuantizer.apply(t, qtype, axis, scale) if optimizer is None: optimizer = default_affine_optimizer else: diff --git a/test/tensor/ops/test_qweight_dispatch.py b/test/tensor/ops/test_qweight_dispatch.py deleted file mode 100644 index 73ae8eca..00000000 --- a/test/tensor/ops/test_qweight_dispatch.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import torch -from helpers import random_qweight - -from quanto import qint8 - - -@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"]) -@pytest.mark.parametrize( - "group_size", - [None, 2], - ids=["channel-wise", "group-wise"], -) -def test_qweight_transpose_2d(axis, group_size, device): - input_shape = (4, 6) - qinputs = random_qweight(input_shape, qint8, axis=axis, group_size=group_size).to(device) - qtransposed = qinputs.t() - assert qtransposed.qtype == qinputs.qtype - if axis == -1: - assert qtransposed.axis == 0 - elif axis == 0: - assert qtransposed.axis == -1 - assert qtransposed.shape == input_shape[::-1] - assert torch.equal(qtransposed.dequantize(), qinputs.dequantize().t()) diff --git a/test/tensor/quantizers/test_symmetric.py b/test/tensor/quantizers/test_symmetric.py index 1341f041..56ce580c 100644 --- a/test/tensor/quantizers/test_symmetric.py +++ b/test/tensor/quantizers/test_symmetric.py @@ -40,7 +40,7 @@ def test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device): a = random_tensor(input_shape, dtype=dtype).to(device) scale = absmax_scale(a, qtype=qtype, axis=axis) - qa = SymmetricQuantizer.apply(a, qtype, axis, None, scale) + qa = SymmetricQuantizer.apply(a, qtype, axis, scale) assert isinstance(qa, QBytesTensor) assert qa.dtype == dtype assert qa.qtype == qtype @@ -62,7 +62,7 @@ def test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device): def test_symmetric_quantize_float8(input_shape, dtype, qtype, axis, device): a = random_tensor(input_shape, dtype=dtype).to(device) scale = absmax_scale(a, qtype=qtype, axis=axis) - qa = SymmetricQuantizer.apply(a, qtype, axis, None, scale) + qa = SymmetricQuantizer.apply(a, qtype, axis, scale) assert isinstance(qa, QBytesTensor) assert qa.dtype == dtype assert qa.qtype == qtype diff --git a/test/tensor/test_qbytestensor.py b/test/tensor/test_qbytestensor.py index c491d32e..1e24cb37 100644 --- a/test/tensor/test_qbytestensor.py +++ b/test/tensor/test_qbytestensor.py @@ -13,7 +13,6 @@ # limitations under the License. import io -from math import prod import pytest import torch @@ -123,17 +122,6 @@ def test_qbytestensor_contiguous(axis, qtype, device): assert tqa.is_contiguous() -@pytest.mark.parametrize("input_shape, group_size", [[(4, 6), 2], [(32, 64), 4]], ids=["small", "bigger"]) -def test_qbytestensor_quantize_transposed_groupwise(input_shape, group_size, device): - x = torch.tensor(range(prod(input_shape)), dtype=torch.float32).reshape(input_shape).to(device) - xt = x.t() - qx = quantize_weight(x, qtype=qint8, axis=0, group_size=group_size) - qxt = quantize_weight(xt, qtype=qint8, axis=-1, group_size=group_size) - dqx = qx.dequantize() - dqxt = qxt.dequantize() - assert torch.equal(dqx.t(), dqxt) - - def test_to_device(device): qa = random_qweight((32, 32), qtype=qint8, dtype=torch.float) qa = qa.to(device)