From 70908987c729f5e3730ed067d6b7cc7d7fa1ac4c Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Mon, 29 Jan 2024 13:44:06 +0000 Subject: [PATCH 01/10] Represent symmetrically quantized weights in signed data type --- .../weights_compression/Usage.md | 4 +- .../weight_compression/openvino_backend.py | 42 ++-- .../weight_compression/torch_backend.py | 24 +- .../weight_compression/weight_lowering.py | 36 ++- nncf/torch/quantization/layers.py | 24 +- ...egerModel_compressed_weights_int8_sym.json | 221 +++++++++--------- .../quantization/test_weights_compression.py | 70 +++--- tests/torch/ptq/test_weights_compression.py | 20 +- 8 files changed, 232 insertions(+), 209 deletions(-) diff --git a/docs/usage/post_training_compression/weights_compression/Usage.md b/docs/usage/post_training_compression/weights_compression/Usage.md index a9e5289e0db..a27bb2bb50a 100644 --- a/docs/usage/post_training_compression/weights_compression/Usage.md +++ b/docs/usage/post_training_compression/weights_compression/Usage.md @@ -9,7 +9,7 @@ The Weights Compression algorithm is aimed at compressing the weights of the mod #### Supported modes By default, weights are compressed asymmetrically to 8-bit integer data type - "INT8_ASYM" mode. -OpenVINO backend also supports 3 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM and NF4. The primary precision in case of INT4_SYM mode is unsigned 4-bit integer and weights are quantized to it [symmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#symmetric-quantization) with a fixed zero point equals to 8. In case of INT4_ASYM mode - also unsigned 4-bit integer, but weight are quantized to it [asymmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point. +OpenVINO backend also supports 3 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM and NF4. The primary precision in case of INT4_SYM mode is signed 4-bit integer and weights are quantized to it [symmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#symmetric-quantization) without zero point. In case of INT4_ASYM mode - unsigned 4-bit integer, but weight are quantized to it [asymmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point. All 4-bit modes have a grouped quantization support, when small group of weights (e.g. 128) in the channel dimension share quantization parameters (scale). All embeddings, convolutions and last linear layers are always compressed to 8-bit integer data type. To quantize embeddings and last linear layers to 4-bit, use `all_layers=True`. Percent of the rest layers compressed to 4-bit can be configured by "ratio" parameter. E.g. ratio=0.9 means 90% of layers compressed to the corresponding 4-bit data type and the rest to 8-bit asymmetric integer data type. @@ -484,7 +484,7 @@ Here is the perplexity and accuracy with data-free and data-aware mixed-precisio - The algorithm is supported for OpenVINO and PyTorch models. - The compression applies in-place. - The compressed model is not trainable. -- INT8_SYM, INT4_SYM, INT4_ASYM and NF4 modes, grouped quantization and mixed precision selection is available for OpenVINO backend only. +- INT4_SYM, INT4_ASYM and NF4 modes, grouped quantization and mixed precision selection is available for OpenVINO backend only. - NF4 support is experimental - models quantized to nf4 should not be faster models quantized to 8-bit integer. #### Additional resources diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index c565c8a603d..859112a395e 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -20,6 +20,8 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.utils import get_reduction_axes from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.tensor.definitions import TensorDataType +from nncf.experimental.tensor.functions import count_nonzero from nncf.experimental.tensor.tensor import Tensor from nncf.openvino.graph.metatypes import openvino_metatypes as om from nncf.openvino.graph.model_transformer import OVModelTransformer @@ -134,17 +136,14 @@ def transform_model( compression_config = wc_params.compression_config if compression_config.mode == CompressWeightsMode.NF4: compression_dtype = ov.Type.nf4 - elif compression_config.mode in [ - CompressWeightsMode.INT8_ASYM, - CompressWeightsMode.INT8_SYM, - CompressWeightsMode.INT8, - CompressWeightsMode.INT4_ASYM, - CompressWeightsMode.INT4_SYM, - ]: - if compression_config.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT4_SYM]: - compression_dtype = ov.Type.u4 - else: - compression_dtype = ov.Type.u8 + elif compression_config.mode == CompressWeightsMode.INT4_SYM: + compression_dtype = ov.Type.i4 + elif compression_config.mode == CompressWeightsMode.INT4_ASYM: + compression_dtype = ov.Type.u4 + elif compression_config.mode == CompressWeightsMode.INT8_SYM: + compression_dtype = ov.Type.i8 + elif compression_config.mode == CompressWeightsMode.INT8_ASYM: + compression_dtype = ov.Type.u8 else: raise ValueError(f"{compression_config.mode.value} is not supported.") @@ -176,15 +175,18 @@ def transform_model( ) converted_const = opset.convert(compressed_const, ov.Type.f16) if compressed_weight.zero_point is not None: - zero_point_const = opset.constant( - compressed_weight.zero_point.data, - dtype=compression_dtype, - name=f"{const_node_name}/zero_point", - ) - converted_zero_point = opset.convert(zero_point_const, ov.Type.f16) - converted_const = opset.subtract( - converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract" - ) + if compressed_weight.tensor.dtype == TensorDataType.int8: + assert count_nonzero(compressed_weight.zero_point.data) == 0 + else: + zero_point_const = opset.constant( + compressed_weight.zero_point.data, + dtype=compression_dtype, + name=f"{const_node_name}/zero_point", + ) + converted_zero_point = opset.convert(zero_point_const, ov.Type.f16) + converted_const = opset.subtract( + converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract" + ) scale_const = opset.constant( compressed_weight.scale.data, dtype=ov.Type.f16, name=f"{const_node_name}/scale" diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index 4aa139177df..f80fd3a7864 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -23,6 +23,7 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.tensor.definitions import TensorDataType +from nncf.experimental.tensor.functions import count_nonzero from nncf.experimental.tensor.tensor import Tensor from nncf.parameters import CompressWeightsMode from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend @@ -38,7 +39,8 @@ from nncf.torch.model_graph_manager import split_const_name from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.quantization.layers import WeightsDecompressor +from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import SymmetricWeightsDecompressor from nncf.torch.tensor_statistics.collectors import get_raw_stat_collector @@ -211,7 +213,11 @@ def transform_model( compressed_weight.scale = compressed_weight.scale.astype(dtype=TensorDataType.float16) # pack compressed tensor - packed_tensor = compressed_weight.tensor.astype(TensorDataType.uint8) + if compression_config.mode == CompressWeightsMode.INT8_SYM: + dtype = TensorDataType.int8 + else: + dtype = TensorDataType.uint8 + packed_tensor = compressed_weight.tensor.astype(dtype) # sets compressed tensor compressed_parameter = torch.nn.Parameter(packed_tensor.data, requires_grad=False) @@ -225,13 +231,15 @@ def transform_model( if id(param) == id(weight): setattr(c_module, name, compressed_parameter) - # pack zero point tensor - packed_zero_point = compressed_weight.zero_point.astype(TensorDataType.uint8) - # creates weight decompressor - decompressor = WeightsDecompressor( - compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.dtype - ) + if compression_config.mode == CompressWeightsMode.INT8_SYM: + assert count_nonzero(compressed_weight.zero_point) == 0 + decompressor = SymmetricWeightsDecompressor(compressed_weight.scale.data, result_dtype=weight.dtype) + else: + packed_zero_point = compressed_weight.zero_point.astype(dtype) + decompressor = AsymmetricWeightsDecompressor( + compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.dtype + ) # registry weight decompression module in the model decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}" diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index c3b7ed40c84..050467269df 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -222,25 +222,22 @@ def calculate_integer_quantization_params( assert mode != CompressWeightsMode.NF4, "The function supports integer quantization only" num_bits = config.num_bits - level_low = 0 - level_high = 2**num_bits - 1 - if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) if mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]: + level_low = 0 + level_high = 2**num_bits - 1 min_values = fns.min(weight, axis=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2] max_values = fns.max(weight, axis=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2] scale, zero_point = calculate_scale_zero_point( min_values, max_values, level_low, level_high, narrow_range=False ) else: - level_low_sym = -(2 ** (num_bits - 1)) - level_high_sym = 2 ** (num_bits - 1) - 1 - + level_high = 2 ** (num_bits - 1) - 1 scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2] - scale = scale / level_high_sym - zero_point = fns.as_tensor_like(scale, [-level_low_sym]).astype(TensorDataType.int32) + scale = scale / level_high + zero_point = fns.zeros_like(scale).astype(TensorDataType.int32) eps = fns.finfo(scale).eps # NOTE: adding machine epsilon to avoid division by zero scale = fns.where(fns.abs(scale) < eps, eps, scale) @@ -258,7 +255,7 @@ def calculate_quantized_weight( :param scale: Scale tensor used for quantization. :param zero_point: Zero point tensor used for quantization. :param config: Weight compression configuration. - :return: Quantized weight tensor of uint8 type. + :return: Quantized weight tensor of uint8 or int8 type. """ if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) @@ -266,12 +263,13 @@ def calculate_quantized_weight( scale = scale.astype(TensorDataType.float32) num_bits = config.num_bits - - level_low = 0 - level_high = 2**num_bits - 1 + asym_quant = config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM] + dtype = TensorDataType.uint8 if asym_quant else TensorDataType.int8 + level_low = 0 if asym_quant else -(2 ** (num_bits - 1)) + level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1 compressed_weights = fns.round(weight / scale + zero_point.astype(weight.dtype)) - compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(TensorDataType.uint8) + compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype) return compressed_weights @@ -285,14 +283,14 @@ def do_integer_quantization( """ The method quantizes the given weights to integer data type in accordance with the compression config. The config defines a quantization mode: - INT8_SYM mode refers to unsigned int8 symmetric weight compression with a fixed zero point equals to 128 - - quantization to [0, 255] range. + INT8_SYM mode refers to signed int8 symmetric weight compression without zero point - + quantization to [-128, 127] range. INT8_ASYM mode refers to unsigned int8 asymmetric weight compression with a typical non-fixed zero-point - quantization to [0, 255] range. INT4_ASYM mode refers to unsigned int4 asymmetric weight compression with a typical non-fixed zero-point - quantization to [0, 15] range. - INT4_SYM mode refers to unsigned int4 symmetric weight compression with a fixed zero point equals to 8 - - quantization to [0, 15] range. + INT4_SYM mode refers to signed int4 symmetric weight compression without zero point - + quantization to [-8, 7] range. NF4 mode requires a dedicated procedure and it is not supported in this method. One of the parameter of compression config is a group size. Quantization is per-channel, if group size equals to -1, otherwise it's per-group, i.e. group size number of weights in the channel dimension share quantization parameters @@ -303,8 +301,8 @@ def do_integer_quantization( :param config: Information on how to compress (quantize) a specific weight. :param precomputed_scale: Precomputed scale. :param precomputed_zero_point: Precomputed zero point. - :return: The compressed weights tensor of uint8 type, scale tensor of float32 type and - zero point tensor of int32 type that was used for its quantization. + :return: The compressed weights tensor of uint8 (asymmetric mode) or int8 (symmetric mode) type, + scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization. """ mode = config.mode assert mode != CompressWeightsMode.NF4, "The function supports integer quantization only" diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index 4b463600bc5..ec5e08d73d7 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -1044,9 +1044,9 @@ def get_scale_shape(input_shape: List[int], is_weights: bool, per_channel: bool, return get_per_channel_scale_shape(input_shape, is_weights, channel_idx) -class WeightsDecompressor(nn.Module): +class AsymmetricWeightsDecompressor(nn.Module): """ - Applies decompression of compressed weights in the forward pass + Applies asymmetric decompression of compressed weights in the forward pass """ def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: torch.dtype = None): @@ -1064,3 +1064,23 @@ def forward(self, x): result = decompress(x, self._scale, self._zero_point) result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result return result + + +class SymmetricWeightsDecompressor(nn.Module): + """ + Applies symmetric decompression of compressed weights in the forward pass + """ + + def __init__(self, scale: torch.Tensor, result_dtype: torch.dtype = None): + """ + :param scale: A scale in quantization scheme + :param result_dtype: (Optional) A data type that result should be cast to + """ + super().__init__() + self.register_buffer("_scale", scale) + + def forward(self, x): + zero_point = torch.zeros_like(self._scale) + result = decompress(x, self._scale, zero_point) + result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result + return result diff --git a/tests/openvino/native/data/2024.1/reference_scales/IntegerModel_compressed_weights_int8_sym.json b/tests/openvino/native/data/2024.1/reference_scales/IntegerModel_compressed_weights_int8_sym.json index 2e58bf43fc2..777d418696d 100644 --- a/tests/openvino/native/data/2024.1/reference_scales/IntegerModel_compressed_weights_int8_sym.json +++ b/tests/openvino/native/data/2024.1/reference_scales/IntegerModel_compressed_weights_int8_sym.json @@ -2,63 +2,60 @@ "matmul_2_data": { "compressed_weight": [ [ - 182, - 152, - 200, - 255, - 165, - 136, - 193 - ], - [ - 155, - 140, - 206, - 168, - 219, - 155, - 255 - ], - [ - 177, - 142, - 212, - 251, - 187, - 255, - 195 - ], - [ - 182, - 207, - 255, - 249, - 187, - 225, - 191 - ], - [ - 200, - 235, - 184, - 228, - 225, - 255, - 144 - ], - [ - 222, - 248, - 253, - 130, - 240, - 255, - 252 + 54, + 24, + 72, + 127, + 37, + 8, + 65 + ], + [ + 27, + 12, + 78, + 40, + 91, + 27, + 127 + ], + [ + 49, + 14, + 84, + 123, + 59, + 127, + 67 + ], + [ + 54, + 79, + 127, + 121, + 59, + 97, + 63 + ], + [ + 72, + 107, + 56, + 100, + 97, + 127, + 16 + ], + [ + 94, + 120, + 125, + 2, + 112, + 127, + 124 ] ], - "zero_point": [ - 128 - ], "scale": [ [ 0.0062713623046875 @@ -83,57 +80,54 @@ "matmul_1_data": { "compressed_weight": [ [ - 185, - 208, - 133, - 152, - 255, - 251 + 57, + 80, + 5, + 24, + 127, + 123 ], [ - 206, - 177, - 255, - 253, - 215, - 211 + 78, + 49, + 127, + 125, + 87, + 83 ], [ - 249, - 196, - 152, - 255, - 220, - 183 + 121, + 68, + 24, + 127, + 92, + 55 ], [ - 194, - 249, - 255, - 177, - 206, - 172 + 66, + 121, + 127, + 49, + 78, + 44 ], [ - 213, - 176, - 184, - 255, - 160, - 217 + 85, + 48, + 56, + 127, + 32, + 89 ], [ - 140, - 249, - 242, - 163, - 255, - 136 + 12, + 121, + 114, + 35, + 127, + 8 ] ], - "zero_point": [ - 128 - ], "scale": [ [ 0.005279541015625 @@ -158,33 +152,30 @@ "gather_2_data": { "compressed_weight": [ [ - 217, - 166, - 134, - 130, - 241, - 255 + 89, + 38, + 6, + 2, + 113, + 127 ], [ - 210, - 227, - 202, - 255, - 239, - 128 + 82, + 99, + 74, + 127, + 111, + 0 ], [ - 254, - 133, - 235, - 154, - 255, - 208 + 126, + 5, + 107, + 26, + 127, + 80 ] ], - "zero_point": [ - 128 - ], "scale": [ [ 0.0071868896484375 diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index dbc82013df5..804baf608a4 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -74,41 +74,41 @@ def get_next_node(node): def check_int8_node(op: ov.Node, mode: CompressWeightsMode = CompressWeightsMode.INT8_ASYM): - assert op.get_element_type() == ov.Type(np.uint8) + dtype = ov.Type.u8 if mode == CompressWeightsMode.INT8_ASYM else ov.Type.i8 + assert op.get_element_type() == dtype compressed_weight = get_const_value(op) + stats = {"compressed_weight": compressed_weight} convert_node = get_next_node(op) assert convert_node.get_type_name() == "Convert" - sub_node = get_next_node(convert_node) - assert sub_node.get_type_name() == "Subtract" + if mode == CompressWeightsMode.INT8_ASYM: + sub_node = get_next_node(convert_node) + assert sub_node.get_type_name() == "Subtract" - convert_node = sub_node.input_value(1).get_node() - assert convert_node.get_type_name() == "Convert" + convert_node = sub_node.input_value(1).get_node() + assert convert_node.get_type_name() == "Convert" - zero_point_node = convert_node.input_value(0).get_node() - zero_point = get_const_value(zero_point_node) - if mode == CompressWeightsMode.INT8_SYM: - assert list(zero_point_node.shape) == [1] - else: + zero_point_node = convert_node.input_value(0).get_node() + zero_point = get_const_value(zero_point_node) + stats["zero_point"] = zero_point reduced_weight_shape = list(op.shape) reduced_weight_shape[-1] = 1 assert list(zero_point_node.shape) == reduced_weight_shape + mul_node = get_next_node(sub_node) + else: + mul_node = get_next_node(convert_node) - mul_node = get_next_node(sub_node) assert mul_node.get_type_name() == "Multiply" scale_node = mul_node.input_value(1).get_node() scale = get_const_value(scale_node) - - return { - "compressed_weight": compressed_weight, - "zero_point": zero_point, - "scale": scale, - } + stats["scale"] = scale + return stats def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int = 7): - assert op.get_element_type() == ov.Type.u4 + dtype = ov.Type.u4 if mode == CompressWeightsMode.INT4_ASYM else ov.Type.i4 + assert op.get_element_type() == dtype weight_shape = op.shape # NOTE: get_const_value doesn't work for 4-bit types assert list(weight_shape)[-1] == group_size @@ -118,20 +118,20 @@ def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int = convert_node = get_next_node(op) assert convert_node.get_type_name() == "Convert" - sub_node = get_next_node(convert_node) - assert sub_node.get_type_name() == "Subtract" + if mode == CompressWeightsMode.INT4_ASYM: + sub_node = get_next_node(convert_node) + assert sub_node.get_type_name() == "Subtract" - convert_node = sub_node.input_value(1).get_node() - assert convert_node.get_type_name() == "Convert" + convert_node = sub_node.input_value(1).get_node() + assert convert_node.get_type_name() == "Convert" - zero_point_node = convert_node.input_value(0).get_node() - assert zero_point_node.get_element_type() == ov.Type.u4 - if mode == CompressWeightsMode.INT4_SYM: - assert list(zero_point_node.shape) == [1] - else: + zero_point_node = convert_node.input_value(0).get_node() + assert zero_point_node.get_element_type() == dtype assert list(zero_point_node.shape) == reduced_weight_shape + mul_node = get_next_node(sub_node) + else: + mul_node = get_next_node(convert_node) - mul_node = get_next_node(sub_node) assert mul_node.get_type_name() == "Multiply" scale_node = mul_node.input_value(1).get_node() assert list(scale_node.shape) == reduced_weight_shape @@ -282,7 +282,7 @@ def test_gather_in_4_bit_if_all_layers_with_data(metric): for node_name in int4_reference_node_names: node = nodes_map[node_name] assert node.get_type_name() == "Constant" - assert node.get_element_type() == ov.Type.u4 + assert node.get_element_type() == ov.Type.i4 def test_gather_can_be_8_bit_if_all_layers_without_data(): @@ -359,7 +359,7 @@ def test_gather_can_be_4_bit_if_all_layers_without_data(): for node_name in int4_reference_node_names: node = nodes_map[node_name] assert node.get_type_name() == "Constant" - assert node.get_element_type() == ov.Type.u4 + assert node.get_element_type() == ov.Type.i4 @pytest.mark.parametrize("metric", ALL_SENSITIVITY_METRICS) @@ -380,7 +380,7 @@ def test_gather_in_8_bit_if_not_all_layers(metric): for node_name in int8_reference_node_names: node = nodes_map[node_name] assert node.get_type_name() == "Constant" - assert node.get_element_type() == ov.Type.u8 + assert node.get_element_type() == ov.Type.i8 MAX_BASELINE_SCORE = 1 / np.finfo(np.float32).eps @@ -445,9 +445,9 @@ def test_quantize_Gather_with_multiple_reduction_axes_if_mode_4bit(mode, all_lay @pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM)) def test_shared_gather(mode): weight_name_vs_type = { - "gather_2_data": ov.Type(np.uint8), - "shared_data": ov.Type(np.uint8), - "matmul_1_data": ov.Type.u4, + "gather_2_data": ov.Type.u8, + "shared_data": ov.Type.u8, + "matmul_1_data": ov.Type.i4 if mode == CompressWeightsMode.INT4_SYM else ov.Type.u4, } model = GatherAndMatmulShareData().ov_model compressed_model = compress_weights(model, mode, group_size=3) @@ -638,7 +638,7 @@ def test_weight_compress_with_ignored_scope(ignored_scope, num_compressed): if ( op.get_type_name() == "Constant" and op.get_friendly_name() in ref_compressed_weights - and op.get_element_type() == ov.Type(np.uint8) + and op.get_element_type() == ov.Type.u8 ): act_num += 1 assert act_num == num_compressed diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index b7de194e6c7..569e6cc1f13 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -28,7 +28,7 @@ ALL_SENSITIVITY_METRICS = DATA_BASED_SENSITIVITY_METRICS + (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR,) -SUPPORTED_MODES = (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM) +SUPPORTED_MODES = (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM) UNSUPPORTED_MODES = ( CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM, @@ -106,12 +106,14 @@ def forward(self, input_): return x -def test_compress_weights(): +@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +def test_compress_weights(mode): model = ShortTransformer(5, 10) + dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 input_ids = torch.randint(0, 10, (5,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model) + compressed_model = compress_weights(wrapped_model, mode=mode) n_compressed_weights = 0 n_target_modules = 0 @@ -119,7 +121,7 @@ def test_compress_weights(): for _, module in compressed_model.named_children(): if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): n_target_modules += 1 - if module.weight.dtype in [torch.uint8, torch.int8]: + if module.weight.dtype == dtype: n_compressed_weights += 1 assert n_compressed_weights == n_target_modules @@ -158,12 +160,14 @@ def test_compress_weights_conv(): assert n_compressed_weights == n_target_modules -def test_compress_shared_weights(mocker): +@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +def test_compress_shared_weights(mocker, mode): model = ShortTransformer(5, 10, share_weights=True) + dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 input_ids = torch.randint(0, 10, (5,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model) + compressed_model = compress_weights(wrapped_model, mode=mode) n_compressed_weights = 0 n_target_modules = 0 @@ -171,7 +175,7 @@ def test_compress_shared_weights(mocker): for _, module in compressed_model.named_children(): if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): n_target_modules += 1 - if module.weight.dtype in [torch.uint8, torch.int8]: + if module.weight.dtype == dtype: n_compressed_weights += 1 assert n_compressed_weights == n_target_modules @@ -215,7 +219,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): @pytest.mark.parametrize("mode", UNSUPPORTED_MODES) -def test_raise_error_with_not_int8_asym(mode): +def test_raise_error_with_not_int8(mode): dummy_torch_model = EmptyModel() dummy_input = torch.Tensor() wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True) From 1d4be3713f175e828c46e304ddd4418654cb8811 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Wed, 24 Apr 2024 15:30:08 +0100 Subject: [PATCH 02/10] Fix tests --- nncf/torch/quantization/layers.py | 1 + .../quantization/test_weights_compression.py | 23 ++++++++----------- tests/torch/ptq/test_weights_compression.py | 6 ++--- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index ec5e08d73d7..ab16a48cb3d 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -1078,6 +1078,7 @@ def __init__(self, scale: torch.Tensor, result_dtype: torch.dtype = None): """ super().__init__() self.register_buffer("_scale", scale) + self.result_dtype = result_dtype def forward(self, x): zero_point = torch.zeros_like(self._scale) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 804baf608a4..3ca449bba9b 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -277,7 +277,7 @@ def test_gather_in_4_bit_if_all_layers_with_data(metric): sensitivity_metric=metric, dataset=dataset, ) - int4_reference_node_names = ["gather_2_data", "gather_2_data/zero_point"] + int4_reference_node_names = ["gather_2_data"] nodes_map = {op.get_friendly_name(): op for op in compressed_model.get_ordered_ops()} for node_name in int4_reference_node_names: node = nodes_map[node_name] @@ -306,17 +306,13 @@ def test_gather_can_be_8_bit_if_all_layers_without_data(): def test_conv_in_8_bit_if_mode_8bit(mode): model = WeightsModel().ov_model compressed_model = compress_weights(model, mode=mode) - int8_reference_node_names = [ - "conv_weights_0", - "conv_weights_0/zero_point", - "conv_weights_1", - "conv_weights_1/zero_point", - ] + int8_reference_node_names = ["conv_weights_0", "conv_weights_1"] nodes_map = {op.get_friendly_name(): op for op in compressed_model.get_ordered_ops()} + dtype = ov.Type.u8 if mode == CompressWeightsMode.INT8_ASYM else ov.Type.i8 for node_name in int8_reference_node_names: node = nodes_map[node_name] assert node.get_type_name() == "Constant" - assert node.get_element_type() == ov.Type.u8 + assert node.get_element_type() == dtype @pytest.mark.parametrize("all_layers", (True, False)) @@ -339,9 +335,9 @@ def test_conv_in_8_bit_if_mode_4bit(all_layers): ]: assert ov.Type.u8 == op.get_element_type() elif op.get_friendly_name() in ["weights_1", "weights_1/zero_point"]: - assert ov.Type.u4 == op.get_element_type() + assert ov.Type.i4 == op.get_element_type() elif op.get_friendly_name() in ["weights_0", "weights_0/zero_point"]: - dtype = ov.Type.u4 if all_layers else ov.Type.u8 + dtype = ov.Type.i4 if all_layers else ov.Type.u8 assert dtype == op.get_element_type() @@ -354,7 +350,7 @@ def test_gather_can_be_4_bit_if_all_layers_without_data(): group_size=1, all_layers=True, ) - int4_reference_node_names = ["gather_2_data", "gather_2_data/zero_point"] + int4_reference_node_names = ["gather_2_data"] nodes_map = {op.get_friendly_name(): op for op in compressed_model.get_ordered_ops()} for node_name in int4_reference_node_names: node = nodes_map[node_name] @@ -380,7 +376,7 @@ def test_gather_in_8_bit_if_not_all_layers(metric): for node_name in int8_reference_node_names: node = nodes_map[node_name] assert node.get_type_name() == "Constant" - assert node.get_element_type() == ov.Type.i8 + assert node.get_element_type() == ov.Type.u8 MAX_BASELINE_SCORE = 1 / np.finfo(np.float32).eps @@ -427,9 +423,10 @@ def test_data_based_criterion(mode, ref_scores, ref_act_scores, mocker): def test_quantize_Gather_with_multiple_reduction_axes_in_8bit(mode): model = GatherWithTwoReductionAxes().ov_model compressed_model = compress_weights(model, mode=mode) + dtype = ov.Type.u8 if mode == CompressWeightsMode.INT8_ASYM else ov.Type.i8 for op in compressed_model.get_ordered_ops(): if op.get_type_name() == "Constant" and op.get_friendly_name() == "gather_1_data": - assert op.get_element_type() == ov.Type.u8 + assert op.get_element_type() == dtype @pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM)) diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index 569e6cc1f13..a4d165c9077 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -17,7 +17,7 @@ from nncf import SensitivityMetric from nncf.quantization import compress_weights from nncf.torch import wrap_model -from nncf.torch.quantization.layers import WeightsDecompressor +from nncf.torch.quantization.layers import SymmetricWeightsDecompressor DATA_BASED_SENSITIVITY_METRICS = ( SensitivityMetric.HESSIAN_INPUT_ACTIVATION, @@ -132,11 +132,11 @@ def test_compress_weights_functional_model(): input_ids = torch.randint(0, 10, [1, 3, 300, 300]) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model) + compressed_model = compress_weights(wrapped_model, mode=CompressWeightsMode.INT8_SYM) n_compressed_weights = 0 for layer in compressed_model.nncf.external_op.values(): - if isinstance(layer, WeightsDecompressor): + if isinstance(layer, SymmetricWeightsDecompressor): n_compressed_weights += 1 assert n_compressed_weights == 4 From 9fb58623aea07b736bd5c05fcdd826d5667e184b Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Thu, 16 May 2024 12:08:31 +0100 Subject: [PATCH 03/10] minor fixes --- .../weights_compression/Usage.md | 2 +- nncf/parameters.py | 4 ++-- .../algorithms/weight_compression/algorithm.py | 4 ++-- nncf/quantization/quantize_model.py | 6 +++--- tests/torch/ptq/test_weights_compression.py | 11 ++++++++--- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/docs/usage/post_training_compression/weights_compression/Usage.md b/docs/usage/post_training_compression/weights_compression/Usage.md index a27bb2bb50a..8d2f56143f5 100644 --- a/docs/usage/post_training_compression/weights_compression/Usage.md +++ b/docs/usage/post_training_compression/weights_compression/Usage.md @@ -9,7 +9,7 @@ The Weights Compression algorithm is aimed at compressing the weights of the mod #### Supported modes By default, weights are compressed asymmetrically to 8-bit integer data type - "INT8_ASYM" mode. -OpenVINO backend also supports 3 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM and NF4. The primary precision in case of INT4_SYM mode is signed 4-bit integer and weights are quantized to it [symmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#symmetric-quantization) without zero point. In case of INT4_ASYM mode - unsigned 4-bit integer, but weight are quantized to it [asymmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point. +OpenVINO backend also supports 3 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM and NF4. The primary precision in case of INT4_SYM mode is signed 4-bit integer and weights are quantized to it [symmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#symmetric-quantization) without zero point. In case of INT4_ASYM mode - unsigned 4-bit integer and weight are quantized to it [asymmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point. All 4-bit modes have a grouped quantization support, when small group of weights (e.g. 128) in the channel dimension share quantization parameters (scale). All embeddings, convolutions and last linear layers are always compressed to 8-bit integer data type. To quantize embeddings and last linear layers to 4-bit, use `all_layers=True`. Percent of the rest layers compressed to 4-bit can be configured by "ratio" parameter. E.g. ratio=0.9 means 90% of layers compressed to the corresponding 4-bit data type and the rest to 8-bit asymmetric integer data type. diff --git a/nncf/parameters.py b/nncf/parameters.py index 3cfeb246668..00cf487b9d4 100644 --- a/nncf/parameters.py +++ b/nncf/parameters.py @@ -68,13 +68,13 @@ class CompressWeightsMode(StrEnum): """ Defines a mode for weight compression. :param INT8_SYM: Stands for 8-bit integer symmetric quantization of all weights. - Weights are quantized symmetrically with a fixed zero point equals to 128. + Weights are quantized symmetrically without zero point. https://github.com/openvinotoolkit/nncf/blob/develop/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#symmetric-quantization :param INT8_ASYM: The same as INT8_SYM mode, but weights are quantized to a primary precision asymmetrically with a typical non-fixed zero point. https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization :param INT4_SYM: Stands for a mixed-precision weights quantization with 4-bit integer as a primary precision. - Weights are quantized to a primary precision symmetrically with a fixed zero point equals to 8. + Weights are quantized to a primary precision symmetrically without zero point. All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM, by default. All others are quantized whether to 4-bit integer or to a backup precision depending on criteria and the given ratio. diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 4609df0a463..8fb6e2c1b9b 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -70,11 +70,11 @@ def __init__( """ :param mode: Defines a mode for weight compression. INT8_SYM stands for 8-bit integer symmetric quantization of all weights. - Weights are quantized symmetrically with a fixed zero point equals to 128. + Weights are quantized symmetrically without zero point. INT8_ASYM is the same as INT8_SYM mode, but weights are quantized to a primary precision asymmetrically with a typical non-fixed zero point. INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision. - Weights are quantized to a primary precision symmetrically with a fixed zero point equals to 8. + Weights are quantized to a primary precision symmetrically without zero point. All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM, by default. All others are quantized whether to 4-bit integer or to a backup precision depending on criteria and the given ratio. diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index fc6f5c07cde..46426d77c21 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -348,11 +348,11 @@ def compress_weights( :param model: A model to be compressed. :type model: TModel :param mode: Defines a mode for weight compression. - INT8_SYM stands for 8-bit integer symmetric quantization of all weights. + INT8_SYM stands for 8-bit integer symmetric quantization of all weights without zero point. INT8_ASYM is the same as INT8_SYM mode, but weights are quantized to a primary precision asymmetrically with a typical non-fixed zero point. INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision. - Weights are quantized to a primary precision symmetrically with a fixed zero point equals to 8. + Weights are quantized to a primary precision symmetrically without zero point. All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM, by default. All others are quantized whether to 4-bit integer or to a backup precision depending on criteria and the given ratio. @@ -393,7 +393,7 @@ def compress_weights( """ if mode == CompressWeightsMode.INT8: warning_deprecated( - "`CompressWeightsMode.INT8` is deprecated." "Please, use `CompressWeightsMode.INT8_ASYM` as value instead." + "`CompressWeightsMode.INT8` is deprecated. Please, use `CompressWeightsMode.INT8_ASYM` as value instead." ) mode = CompressWeightsMode.INT8_ASYM diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index a4d165c9077..30c704e5435 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -17,6 +17,7 @@ from nncf import SensitivityMetric from nncf.quantization import compress_weights from nncf.torch import wrap_model +from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import SymmetricWeightsDecompressor DATA_BASED_SENSITIVITY_METRICS = ( @@ -127,16 +128,20 @@ def test_compress_weights(mode): assert n_compressed_weights == n_target_modules -def test_compress_weights_functional_model(): +@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +def test_compress_weights_functional_model(mode): model = FunctionalModel() + decompressor_type = ( + SymmetricWeightsDecompressor if mode == CompressWeightsMode.INT8_SYM else AsymmetricWeightsDecompressor + ) input_ids = torch.randint(0, 10, [1, 3, 300, 300]) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model, mode=CompressWeightsMode.INT8_SYM) + compressed_model = compress_weights(wrapped_model, mode=mode) n_compressed_weights = 0 for layer in compressed_model.nncf.external_op.values(): - if isinstance(layer, SymmetricWeightsDecompressor): + if isinstance(layer, decompressor_type): n_compressed_weights += 1 assert n_compressed_weights == 4 From 2e00a3ae35bb4a2c7f45af83be3cca7aa801e0ce Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Sun, 19 May 2024 23:27:12 +0100 Subject: [PATCH 04/10] Update wc conformance tests --- tests/post_training/data/wc_reference_data.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml index ae59fdfb3fb..9f1e74e62b7 100644 --- a/tests/post_training/data/wc_reference_data.yaml +++ b/tests/post_training/data/wc_reference_data.yaml @@ -1,22 +1,22 @@ tinyllama_data_free_backend_OV: metric_value: 0.72057 - num_int4: 228 + num_int4: 114 num_int8: 84 tinyllama_data_aware_backend_OV: metric_value: 0.83853 - num_int4: 188 + num_int4: 94 num_int8: 124 tinyllama_data_aware_awq_stateful_backend_OV: metric_value: 0.85259 - num_int4: 188 + num_int4: 94 num_int8: 124 tinyllama_data_aware_awq_scale_estimation_backend_OV: metric_value: 0.8404 - num_int4: 188 + num_int4: 94 num_int8: 124 tinyllama_data_aware_awq_scale_estimation_stateful_backend_OV: metric_value: 0.8404 - num_int4: 188 + num_int4: 94 num_int8: 124 tinyllama_int8_data_free_backend_TORCH: metric_value: 0.95624 From 38fd9732e84686ea4db156a5de6c4926e33066a2 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Wed, 5 Jun 2024 13:59:46 +0100 Subject: [PATCH 05/10] fix test --- .../algorithms/weight_compression/weight_lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index 050467269df..e519a2727cc 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -237,7 +237,7 @@ def calculate_integer_quantization_params( level_high = 2 ** (num_bits - 1) - 1 scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2] scale = scale / level_high - zero_point = fns.zeros_like(scale).astype(TensorDataType.int32) + zero_point = fns.as_tensor_like(scale, [0]).astype(TensorDataType.int32) eps = fns.finfo(scale).eps # NOTE: adding machine epsilon to avoid division by zero scale = fns.where(fns.abs(scale) < eps, eps, scale) From 6cc5a06ee7a815d5d642740fb2c2668b8506d514 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Thu, 6 Jun 2024 12:42:03 +0100 Subject: [PATCH 06/10] Apply comments --- .../algorithms/weight_compression/gptq.py | 14 ++-- .../weight_compression/mixed_precision.py | 3 +- .../weight_compression/openvino_backend.py | 62 ++++++++-------- .../weight_compression/scale_estimation.py | 74 +++++++++++-------- .../weight_compression/torch_backend.py | 2 - .../weight_compression/weight_lowering.py | 37 +++++----- nncf/torch/quantization/layers.py | 3 +- nncf/torch/quantization/quantize_functions.py | 8 +- 8 files changed, 109 insertions(+), 94 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index 282e2f6910a..ee2940a86aa 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -263,7 +263,7 @@ def _quantize_weights( quantized_col = decompress_nf4_weight(compressed_weights, scales[-1]) else: compressed_weights = calculate_quantized_weight( - fns.unsqueeze(weight_col, 1), scales[-1], zero_points[-1], block_compression_config + fns.unsqueeze(weight_col, 1), block_compression_config, scales[-1], zero_points[-1] ) quantized_col = do_dequantization(compressed_weights, scales[-1], zero_points[-1]) quantized_col = fns.flatten(quantized_col) @@ -287,13 +287,11 @@ def _quantize_weights( ) scales = fns.stack(scales, axis=1) - if wc_params.compression_config.mode == CompressWeightsMode.NF4: - zero_points = None - elif wc_params.compression_config.mode in [ - CompressWeightsMode.INT8_SYM, - CompressWeightsMode.INT4_SYM, + if wc_params.compression_config.mode in [ + CompressWeightsMode.INT8_ASYM, + CompressWeightsMode.INT4_ASYM, ]: - zero_points = fns.squeeze(zero_points[0]) - else: zero_points = fns.stack(zero_points, axis=1) + else: + zero_points = None return scales, zero_points diff --git a/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/nncf/quantization/algorithms/weight_compression/mixed_precision.py index 29a5fc4bb3a..c4e4c73fa92 100644 --- a/nncf/quantization/algorithms/weight_compression/mixed_precision.py +++ b/nncf/quantization/algorithms/weight_compression/mixed_precision.py @@ -22,6 +22,7 @@ from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error @@ -176,7 +177,7 @@ def _calc_weight_sensitivity(self, weight_param: WeightCompressionParameters) -> weight = weight.astype(TensorDataType.float32) compressed_weights, scale, zero_point = do_integer_quantization(weight, reduction_axes, backup_config) - decompressed_weight = (compressed_weights - zero_point).astype(weight.dtype) * scale + decompressed_weight = do_dequantization(compressed_weights, scale, zero_point) decompressed_weight = decompressed_weight.reshape(orig_shape) return fns.linalg.norm(decompressed_weight - weight, ord="fro").item() diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 859112a395e..33ecb91a56f 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -21,7 +21,6 @@ from nncf.common.graph.utils import get_reduction_axes from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.tensor.definitions import TensorDataType -from nncf.experimental.tensor.functions import count_nonzero from nncf.experimental.tensor.tensor import Tensor from nncf.openvino.graph.metatypes import openvino_metatypes as om from nncf.openvino.graph.model_transformer import OVModelTransformer @@ -174,19 +173,16 @@ def transform_model( compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name ) converted_const = opset.convert(compressed_const, ov.Type.f16) - if compressed_weight.zero_point is not None: - if compressed_weight.tensor.dtype == TensorDataType.int8: - assert count_nonzero(compressed_weight.zero_point.data) == 0 - else: - zero_point_const = opset.constant( - compressed_weight.zero_point.data, - dtype=compression_dtype, - name=f"{const_node_name}/zero_point", - ) - converted_zero_point = opset.convert(zero_point_const, ov.Type.f16) - converted_const = opset.subtract( - converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract" - ) + if compressed_weight.zero_point is not None and compressed_weight.tensor.dtype == TensorDataType.uint8: + zero_point_const = opset.constant( + compressed_weight.zero_point.data, + dtype=compression_dtype, + name=f"{const_node_name}/zero_point", + ) + converted_zero_point = opset.convert(zero_point_const, ov.Type.f16) + converted_const = opset.subtract( + converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract" + ) scale_const = opset.constant( compressed_weight.scale.data, dtype=ov.Type.f16, name=f"{const_node_name}/scale" @@ -222,27 +218,28 @@ def dump_parameters( @staticmethod def get_compress_decompress_pipeline( - weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape + weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None ): - ( - w, - s, - zp, - clamp, - ) = OVWeightCompressionAlgoBackend.get_compress_pipeline( + parameters, clamp = OVWeightCompressionAlgoBackend.get_compress_pipeline( weight_compression_parameter, w_shape, s_shape, z_p_shape, True ) - result = (clamp - zp) * s - model = ov.Model([result], [w, s, zp]) + if len(parameters) == 3: + _, s, zp = parameters + result = (clamp - zp) * s + else: + s = parameters[1] + result = clamp * s + + model = ov.Model([result], parameters) compiled_model = ov.compile_model(model) - return lambda w, s, zp: compiled_model([w, s, zp])[0] + return lambda parameters: compiled_model(parameters)[0] @staticmethod def get_compress_pipeline( - weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape, return_nodes=False + weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None, return_nodes=False ): config = weight_compression_parameter.compression_config mode = config.mode @@ -254,18 +251,23 @@ def get_compress_pipeline( w = opset.parameter(w_shape, name="w") s = opset.parameter(s_shape, name="s") - zp = opset.parameter(z_p_shape, name="zp") + parameters = [w, s] + compressed_w = w / s + if z_p_shape is not None: + zp = opset.parameter(z_p_shape, name="zp") + parameters.append(zp) + compressed_w += zp - result = opset.clamp(opset.round(w / s + zp), level_low, level_high, name="compressed_weights") + result = opset.clamp(opset.round(compressed_w), level_low, level_high, name="compressed_weights") if return_nodes: - return w, s, zp, result + return parameters, result - model = ov.Model([result], [w, s, zp]) + model = ov.Model([result], parameters) compiled_model = ov.compile_model(model) - return lambda w, s, zp: compiled_model([w, s, zp])[0] + return lambda parameters: compiled_model(parameters)[0] class OVAWQAlgoAlgoBackend(OVWeightCompressionAlgoBackend): diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 3cc9d7f4180..ea414d8fe9c 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -117,7 +117,7 @@ def apply( :return: Dict with pairs (weight name, estimated scale). """ - compress_decompress_cashe = {} + compress_decompress_cache = {} res = dict() for wp in track(self._all_weight_params, description="Applying Scale Estimation"): @@ -165,8 +165,6 @@ def apply( original_weight = fns.zeros_like(weight) + weight compressed_weights, scale, zp = do_integer_quantization(original_weight, reduction_axis, config) - zp = zp.astype(scale.dtype) - q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis) s = fns.unsqueeze(s, 0) @@ -180,9 +178,7 @@ def apply( importance = fns.ones_like(original_weight) importance = importance * s - target = compressed_weights.astype(dtype=zp.dtype) - zp - zero_mask = compressed_weights == zp - + target, zero_mask = get_target_zero_mask(compressed_weights, zp) importance = fns.where(zero_mask, 0.0, importance) # normalize importances for every group of weights to make sum of them equal to 1.0 @@ -203,18 +199,20 @@ def apply( if self._weight_penalty > 0.0: min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1) - key = ( - (wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape - ) - if key in compress_decompress_cashe: - compress_decompress_model = compress_decompress_cashe[key]["compress_decompress_model"] - compress_model = compress_decompress_cashe[key]["compress_model"] + zp_shape = zp.shape if zp is not None else None + key = [(wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape] + if zp is not None: + key += zp_shape + key = tuple(key) + if key in compress_decompress_cache: + compress_decompress_model = compress_decompress_cache[key]["compress_decompress_model"] + compress_model = compress_decompress_cache[key]["compress_model"] else: compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline( - wp, q_weights.shape, scale.shape, zp.shape + wp, q_weights.shape, scale.shape, zp_shape ) - compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp.shape) - compress_decompress_cashe[key] = { + compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp_shape) + compress_decompress_cache[key] = { "compress_decompress_model": compress_decompress_model, "compress_model": compress_model, } @@ -222,14 +220,15 @@ def apply( zero_scale = 0.001 zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) + input_tensors = [original_weight.data, None] + if zp is not None: + input_tensors.append(zp.data) # iterative rectification of initial scale for i in range(self._initial_steps): - ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask) - weighted_scale = ideal_scale * importance + near_to_ideal_scale = get_near_to_ideal_scale(original_weight, target, zero_mask, importance) + input_tensors[1] = near_to_ideal_scale.data - near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True) - - out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data) + out = compress_decompress_model(input_tensors) q_weights_ = fns.zeros_like(original_weight) + out q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) @@ -252,12 +251,12 @@ def apply( else: near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale result_scale = near_to_ideal_scale + input_tensors[1] = near_to_ideal_scale.data if i < self._initial_steps - 1: - out = compress_model(original_weight.data, near_to_ideal_scale.data, zp.data) + out = compress_model(input_tensors) compressed_weights = fns.zeros_like(original_weight) + out - target = compressed_weights - zp - zero_mask = compressed_weights == zp + target, zero_mask = get_target_zero_mask(compressed_weights, zp) zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) # iterative rectification of scale based on grid search @@ -265,18 +264,16 @@ def apply( factor = 1.0 - 0.05 * scale_steps scaled_scale = factor * scale - out = compress_model(original_weight.data, scaled_scale.data, zp.data) + input_tensors[1] = scaled_scale.data + out = compress_model(input_tensors) compressed_weights = fns.zeros_like(original_weight) + out - target = compressed_weights - zp - zero_mask = compressed_weights == zp + target, zero_mask = get_target_zero_mask(compressed_weights, zp) zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) + near_to_ideal_scale = get_near_to_ideal_scale(original_weight, target, zero_mask, importance) - ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask) - weighted_scale = ideal_scale * importance - near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True) - - out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data) + input_tensors[1] = near_to_ideal_scale.data + out = compress_decompress_model(input_tensors) q_weights_ = fns.zeros_like(original_weight) + out q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) @@ -300,3 +297,18 @@ def apply( res[weight_name] = result_scale return res + + +def get_target_zero_mask(compressed_weights, zp=None): + target = compressed_weights + if zp is not None: + target = target.astype(dtype=zp.dtype) - zp + zero_mask = fns.isclose(target, 0) + return target, zero_mask + + +def get_near_to_ideal_scale(weight, target, zero_mask, importance): + ideal_scale = fns.abs(weight) / (fns.abs(target) + zero_mask) + weighted_scale = ideal_scale * importance + near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True) + return near_to_ideal_scale diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index f80fd3a7864..6e1b4a7a6c1 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -23,7 +23,6 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.tensor.definitions import TensorDataType -from nncf.experimental.tensor.functions import count_nonzero from nncf.experimental.tensor.tensor import Tensor from nncf.parameters import CompressWeightsMode from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend @@ -233,7 +232,6 @@ def transform_model( # creates weight decompressor if compression_config.mode == CompressWeightsMode.INT8_SYM: - assert count_nonzero(compressed_weight.zero_point) == 0 decompressor = SymmetricWeightsDecompressor(compressed_weight.scale.data, result_dtype=weight.dtype) else: packed_zero_point = compressed_weight.zero_point.astype(dtype) diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index e519a2727cc..54a2ef90754 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -233,28 +233,27 @@ def calculate_integer_quantization_params( scale, zero_point = calculate_scale_zero_point( min_values, max_values, level_low, level_high, narrow_range=False ) - else: - level_high = 2 ** (num_bits - 1) - 1 - scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2] - scale = scale / level_high - zero_point = fns.as_tensor_like(scale, [0]).astype(TensorDataType.int32) - eps = fns.finfo(scale).eps - # NOTE: adding machine epsilon to avoid division by zero - scale = fns.where(fns.abs(scale) < eps, eps, scale) + return scale, zero_point - return scale, zero_point + level_high = 2 ** (num_bits - 1) - 1 + scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2] + scale /= level_high + eps = fns.finfo(scale).eps + # NOTE: adding machine epsilon to avoid division by zero + scale = fns.where(fns.abs(scale) < eps, eps, scale) + return scale, None def calculate_quantized_weight( - weight: Tensor, scale: Tensor, zero_point: Tensor, config: WeightCompressionConfig + weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None ) -> Tensor: """ Quantizes the weight tensor using the provided scale and zero point. :param weight: Weight tensor to quantize. + :param config: Weight compression configuration. :param scale: Scale tensor used for quantization. :param zero_point: Zero point tensor used for quantization. - :param config: Weight compression configuration. :return: Quantized weight tensor of uint8 or int8 type. """ if weight.dtype != TensorDataType.float32: @@ -268,7 +267,10 @@ def calculate_quantized_weight( level_low = 0 if asym_quant else -(2 ** (num_bits - 1)) level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1 - compressed_weights = fns.round(weight / scale + zero_point.astype(weight.dtype)) + compressed_weights = weight / scale + if zero_point is not None: + compressed_weights += zero_point.astype(weight.dtype) + compressed_weights = fns.round(compressed_weights) compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype) return compressed_weights @@ -322,7 +324,7 @@ def do_integer_quantization( if precomputed_zero_point is not None: zero_point = precomputed_zero_point - compressed_weights = calculate_quantized_weight(weight, scale, zero_point, config) + compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point) return compressed_weights, scale, zero_point @@ -344,8 +346,7 @@ def get_integer_quantization_error( weight = weight.astype(TensorDataType.float32) compressed_weights, scale, zero_point = do_integer_quantization(weight, reduction_axes, config) - - decompressed_weight = (compressed_weights - zero_point).astype(weight.dtype) * scale + decompressed_weight = do_dequantization(compressed_weights, scale, zero_point) decompressed_weight = decompressed_weight.reshape(orig_shape) diff = (decompressed_weight - weight) ** 2 @@ -384,7 +385,7 @@ def compress_weight( def do_dequantization( - compressed_weights: Tensor, scale: Tensor, zero_point: Tensor, reduction_axis: int = -1 + compressed_weights: Tensor, scale: Tensor, zero_point: Optional[Tensor] = None, reduction_axis: int = -1 ) -> Tensor: """ The method dequantizes the given weights to float point data type in accordance with the scale and @@ -397,7 +398,9 @@ def do_dequantization( :return: dequantized/decompressed weights. """ decompressed_weight = compressed_weights.astype(dtype=scale.dtype) - decompressed_weight = (decompressed_weight - zero_point) * scale + if zero_point is not None: + decompressed_weight -= zero_point + decompressed_weight *= scale if reduction_axis > -1: shape = list(decompressed_weight.shape) # [a1, r, a2] - "r" refers to number of channels along reduction axis diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index ab16a48cb3d..213d226926e 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -1081,7 +1081,6 @@ def __init__(self, scale: torch.Tensor, result_dtype: torch.dtype = None): self.result_dtype = result_dtype def forward(self, x): - zero_point = torch.zeros_like(self._scale) - result = decompress(x, self._scale, zero_point) + result = decompress(x, self._scale) result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result return result diff --git a/nncf/torch/quantization/quantize_functions.py b/nncf/torch/quantization/quantize_functions.py index 6771d6f38fa..ad7e8c03ca6 100644 --- a/nncf/torch/quantization/quantize_functions.py +++ b/nncf/torch/quantization/quantize_functions.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import torch @@ -249,7 +249,7 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any: @register_operator() -def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor: +def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor] = None) -> torch.Tensor: """ Decompress the input tensor. @@ -259,5 +259,7 @@ def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tenso :return: The decompressed tensor """ input = input.type(dtype=scale.dtype) - decompressed_input = (input - zero_point) * scale + if zero_point is not None: + input -= zero_point + decompressed_input = input * scale return decompressed_input From d15e460ef01e13dc56c2f6764ab742bfc06e408d Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Fri, 7 Jun 2024 15:01:12 +0100 Subject: [PATCH 07/10] update wc_reference_data.yaml --- nncf/torch/quantization/layers.py | 7 +++--- nncf/torch/quantization/quantize_functions.py | 22 ++++++++++++++----- .../post_training/data/wc_reference_data.yaml | 2 +- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index 213d226926e..b84d325526f 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -46,7 +46,8 @@ from nncf.torch.quantization.quantize_functions import ExportQuantizeToONNXQuantDequant from nncf.torch.quantization.quantize_functions import TuneRange from nncf.torch.quantization.quantize_functions import asymmetric_quantize -from nncf.torch.quantization.quantize_functions import decompress +from nncf.torch.quantization.quantize_functions import decompress_asymmetric +from nncf.torch.quantization.quantize_functions import decompress_symmetric from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high from nncf.torch.quantization.quantize_functions import symmetric_quantize from nncf.torch.return_types import maybe_get_values_from_torch_return_type @@ -1061,7 +1062,7 @@ def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: self.result_dtype = result_dtype def forward(self, x): - result = decompress(x, self._scale, self._zero_point) + result = decompress_asymmetric(x, self._scale, self._zero_point) result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result return result @@ -1081,6 +1082,6 @@ def __init__(self, scale: torch.Tensor, result_dtype: torch.dtype = None): self.result_dtype = result_dtype def forward(self, x): - result = decompress(x, self._scale) + result = decompress_symmetric(x, self._scale) result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result return result diff --git a/nncf/torch/quantization/quantize_functions.py b/nncf/torch/quantization/quantize_functions.py index ad7e8c03ca6..9b4055c4586 100644 --- a/nncf/torch/quantization/quantize_functions.py +++ b/nncf/torch/quantization/quantize_functions.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any import torch @@ -249,9 +249,9 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any: @register_operator() -def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor] = None) -> torch.Tensor: +def decompress_asymmetric(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor: """ - Decompress the input tensor. + Decompress the asymmetrically quantized input tensor. :param input: An input tensor :param scale: A scale tensor @@ -259,7 +259,19 @@ def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: Optional[to :return: The decompressed tensor """ input = input.type(dtype=scale.dtype) - if zero_point is not None: - input -= zero_point + decompressed_input = (input - zero_point) * scale + return decompressed_input + + +@register_operator() +def decompress_symmetric(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Decompress the symmetrically quantized input tensor. + + :param input: An input tensor + :param scale: A scale tensor + :return: The decompressed tensor + """ + input = input.type(dtype=scale.dtype) decompressed_input = input * scale return decompressed_input diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml index 9f1e74e62b7..e9675646e74 100644 --- a/tests/post_training/data/wc_reference_data.yaml +++ b/tests/post_training/data/wc_reference_data.yaml @@ -24,5 +24,5 @@ tinyllama_int8_data_free_backend_TORCH: num_int8: 312 tinyllama_data_aware_gptq_backend_OV: metric_value: 0.83387 - num_int4: 188 + num_int4: 94 num_int8: 124 From ddaf50d56ac4276fb5ceea65172d24f7ade28945 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Mon, 10 Jun 2024 16:43:36 +0100 Subject: [PATCH 08/10] rename function --- .../algorithms/weight_compression/scale_estimation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index ea414d8fe9c..8b6bc058140 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -225,7 +225,7 @@ def apply( input_tensors.append(zp.data) # iterative rectification of initial scale for i in range(self._initial_steps): - near_to_ideal_scale = get_near_to_ideal_scale(original_weight, target, zero_mask, importance) + near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance) input_tensors[1] = near_to_ideal_scale.data out = compress_decompress_model(input_tensors) @@ -270,7 +270,7 @@ def apply( target, zero_mask = get_target_zero_mask(compressed_weights, zp) zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) - near_to_ideal_scale = get_near_to_ideal_scale(original_weight, target, zero_mask, importance) + near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance) input_tensors[1] = near_to_ideal_scale.data out = compress_decompress_model(input_tensors) @@ -307,7 +307,7 @@ def get_target_zero_mask(compressed_weights, zp=None): return target, zero_mask -def get_near_to_ideal_scale(weight, target, zero_mask, importance): +def estimate_scales(weight, target, zero_mask, importance): ideal_scale = fns.abs(weight) / (fns.abs(target) + zero_mask) weighted_scale = ideal_scale * importance near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True) From f81cc850de6202b517728174a15a0813c526e26d Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Tue, 11 Jun 2024 13:18:28 +0100 Subject: [PATCH 09/10] Reduce computation time for ASYM mode --- .../algorithms/weight_compression/scale_estimation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 8b6bc058140..b7d3f263b52 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -165,6 +165,8 @@ def apply( original_weight = fns.zeros_like(weight) + weight compressed_weights, scale, zp = do_integer_quantization(original_weight, reduction_axis, config) + if zp is not None: + zp = zp.astype(scale.dtype) q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis) s = fns.unsqueeze(s, 0) From 3e6c649a4e5d303c7d2f2daa71738ed1520e787f Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Thu, 13 Jun 2024 10:11:39 +0100 Subject: [PATCH 10/10] docstring --- .../weight_compression/scale_estimation.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index b7d3f263b52..156335e43b5 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -10,7 +10,7 @@ # limitations under the License. from copy import deepcopy -from typing import Any, Dict, List, Optional, TypeVar +from typing import Any, Dict, List, Optional, Tuple, TypeVar from nncf import Dataset from nncf.common.graph.graph import NNCFGraph @@ -301,7 +301,15 @@ def apply( return res -def get_target_zero_mask(compressed_weights, zp=None): +def get_target_zero_mask(compressed_weights: TTensor, zp: Optional[TTensor] = None) -> Tuple[TTensor, TTensor]: + """ + Computes the target values and a mask indicating zero values in the target. + + :param compressed_weights: The compressed weights tensor. + :param zp: The zero point tensor. + :return: The compressed weights optionally adjusted by the zero point and + a boolean mask indicating positions in the target that are close to zero. + """ target = compressed_weights if zp is not None: target = target.astype(dtype=zp.dtype) - zp @@ -309,7 +317,16 @@ def get_target_zero_mask(compressed_weights, zp=None): return target, zero_mask -def estimate_scales(weight, target, zero_mask, importance): +def estimate_scales(weight: TTensor, target: TTensor, zero_mask: TTensor, importance: TTensor) -> TTensor: + """ + Estimates scales for the given weight, target, zero mask, and importance. + + :param weight: The weights tensor. + :param target: The target values tensor. + :param zero_mask: A boolean mask indicating positions in the target that are close to zero. + :param importance: The importance values tensor. + :return: The estimated scales + """ ideal_scale = fns.abs(weight) / (fns.abs(target) + zero_mask) weighted_scale = ideal_scale * importance near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)