From fceef1b0e85ad0285b496951987095a4fb8c3c71 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Fri, 7 Jun 2024 15:01:12 +0100 Subject: [PATCH] update wc_reference_data.yaml --- nncf/torch/quantization/layers.py | 7 +++--- nncf/torch/quantization/quantize_functions.py | 22 ++++++++++++++----- .../post_training/data/wc_reference_data.yaml | 2 +- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index 213d226926e..b84d325526f 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -46,7 +46,8 @@ from nncf.torch.quantization.quantize_functions import ExportQuantizeToONNXQuantDequant from nncf.torch.quantization.quantize_functions import TuneRange from nncf.torch.quantization.quantize_functions import asymmetric_quantize -from nncf.torch.quantization.quantize_functions import decompress +from nncf.torch.quantization.quantize_functions import decompress_asymmetric +from nncf.torch.quantization.quantize_functions import decompress_symmetric from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high from nncf.torch.quantization.quantize_functions import symmetric_quantize from nncf.torch.return_types import maybe_get_values_from_torch_return_type @@ -1061,7 +1062,7 @@ def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: self.result_dtype = result_dtype def forward(self, x): - result = decompress(x, self._scale, self._zero_point) + result = decompress_asymmetric(x, self._scale, self._zero_point) result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result return result @@ -1081,6 +1082,6 @@ def __init__(self, scale: torch.Tensor, result_dtype: torch.dtype = None): self.result_dtype = result_dtype def forward(self, x): - result = decompress(x, self._scale) + result = decompress_symmetric(x, self._scale) result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result return result diff --git a/nncf/torch/quantization/quantize_functions.py b/nncf/torch/quantization/quantize_functions.py index ad7e8c03ca6..9b4055c4586 100644 --- a/nncf/torch/quantization/quantize_functions.py +++ b/nncf/torch/quantization/quantize_functions.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any import torch @@ -249,9 +249,9 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any: @register_operator() -def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor] = None) -> torch.Tensor: +def decompress_asymmetric(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor: """ - Decompress the input tensor. + Decompress the asymmetrically quantized input tensor. :param input: An input tensor :param scale: A scale tensor @@ -259,7 +259,19 @@ def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: Optional[to :return: The decompressed tensor """ input = input.type(dtype=scale.dtype) - if zero_point is not None: - input -= zero_point + decompressed_input = (input - zero_point) * scale + return decompressed_input + + +@register_operator() +def decompress_symmetric(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Decompress the symmetrically quantized input tensor. + + :param input: An input tensor + :param scale: A scale tensor + :return: The decompressed tensor + """ + input = input.type(dtype=scale.dtype) decompressed_input = input * scale return decompressed_input diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml index 6827c3adce1..9682783a708 100644 --- a/tests/post_training/data/wc_reference_data.yaml +++ b/tests/post_training/data/wc_reference_data.yaml @@ -24,5 +24,5 @@ tinyllama_int8_data_free_backend_TORCH: num_int8: 312 tinyllama_data_aware_gptq_backend_OV: metric_value: 0.83387 - num_int4: 188 + num_int4: 94 num_int8: 124