diff --git a/docs/compression_algorithms/CompressWeights.md b/docs/compression_algorithms/CompressWeights.md index 3de8bf93277..d7c083740fc 100644 --- a/docs/compression_algorithms/CompressWeights.md +++ b/docs/compression_algorithms/CompressWeights.md @@ -9,7 +9,7 @@ The Weights Compression algorithm is aimed at compressing the weights of the mod #### Supported modes By default, weights are compressed asymmetrically to 8-bit integer data type - "INT8_ASYM" mode. -OpenVINO backend also supports 3 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM and NF4. The primary precision in case of INT4_SYM mode is unsigned 4-bit integer and weights are quantized to it [symmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) with a fixed zero point equals to 8. In case of INT4_ASYM mode - also unsigned 4-bit integer, but weight are quantized to it [asymmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point. +OpenVINO backend also supports 3 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM and NF4. The primary precision in case of INT4_SYM mode is signed 4-bit integer and weights are quantized to it [symmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) without zero point. In case of INT4_ASYM mode - unsigned 4-bit integer, but weights are quantized to it [asymmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point. All 4-bit modes have a grouped quantization support, when small group of weights (e.g. 128) in the channel dimension share quantization parameters (scale). All embeddings and last linear layers are always compressed to 8-bit integer data type. Percent of the rest layers compressed to 4-bit can be configured by "ratio" parameter. E.g. ratio=0.9 means 90% of layers compressed to the corresponding 4-bit data type and the rest to 8-bit asymmetric integer data type. @@ -348,7 +348,7 @@ Here is the word perplexity with data-free and data-aware mixed-precision INT4-I - The algorithm is supported for OpenVINO and PyTorch models. - The compression applies in-place. - The compressed model is not trainable. -- INT8_SYM, INT4_SYM, INT4_ASYM and NF4 modes, grouped quantization and mixed precision selection is available for OpenVINO backend only. +- INT4_SYM, INT4_ASYM and NF4 modes, grouped quantization and mixed precision selection are available for OpenVINO backend only. - NF4 support is experimental - models quantized to nf4 should not be faster than models quantized to 8-bit integer.
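To make the difference between the 4-bit schemes above concrete, here is a minimal NumPy sketch of one quantization group. The group size, the random weights and the scale/zero-point recipe are illustrative assumptions, not NNCF's exact implementation: INT4_SYM maps weights to the signed range [-8, 7] with a scale only, while INT4_ASYM maps them to the unsigned range [0, 15] with a scale and a non-fixed zero point.

```python
import numpy as np

w = np.random.randn(128).astype(np.float32)  # one group of weights sharing quantization parameters

# INT4_SYM: signed 4-bit range [-8, 7], scale only, no zero point
scale_sym = np.abs(w).max() / 7
q_sym = np.clip(np.round(w / scale_sym), -8, 7).astype(np.int8)
w_sym = q_sym * scale_sym  # dequantized approximation

# INT4_ASYM: unsigned 4-bit range [0, 15], scale and a non-fixed zero point
scale_asym = (w.max() - w.min()) / 15
zero_point = np.clip(np.round(-w.min() / scale_asym), 0, 15)
q_asym = np.clip(np.round(w / scale_asym + zero_point), 0, 15).astype(np.uint8)
w_asym = (q_asym.astype(np.float32) - zero_point) * scale_asym  # dequantized approximation
```

In both cases the parameters (scale, and the zero point for the asymmetric case) are stored per group, which is what the group size controls.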
#### Additional resources diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 27d1be422e1..f7ab65d7cd8 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -18,6 +18,8 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.tensor.definitions import TensorDataType +from nncf.experimental.tensor.functions import count_nonzero from nncf.experimental.tensor.tensor import Tensor from nncf.openvino.graph.metatypes.openvino_metatypes import OVEmbeddingMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype @@ -119,17 +121,14 @@ def transform_model( compression_config = wc_params.compression_config if compression_config.mode == CompressWeightsMode.NF4: compression_dtype = ov.Type.nf4 - elif compression_config.mode in [ - CompressWeightsMode.INT8_ASYM, - CompressWeightsMode.INT8_SYM, - CompressWeightsMode.INT8, - CompressWeightsMode.INT4_ASYM, - CompressWeightsMode.INT4_SYM, - ]: - if compression_config.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT4_SYM]: - compression_dtype = ov.Type.u4 - else: - compression_dtype = ov.Type.u8 + elif compression_config.mode == CompressWeightsMode.INT4_SYM: + compression_dtype = ov.Type.i4 + elif compression_config.mode == CompressWeightsMode.INT4_ASYM: + compression_dtype = ov.Type.u4 + elif compression_config.mode == CompressWeightsMode.INT8_SYM: + compression_dtype = ov.Type.i8 + elif compression_config.mode == CompressWeightsMode.INT8_ASYM: + compression_dtype = ov.Type.u8 else: raise ValueError(f"{compression_config.mode.value} is not supported.") @@ -147,13 +146,16 @@ def transform_model( ) converted_const = opset.convert(compressed_const, const_dtype) if compressed_weight.zero_point is not None: - zero_point_const = opset.constant( - compressed_weight.zero_point.data, - dtype=compression_dtype, - name=f"{const_node_name}/zero_point", - ) - converted_zero_point = opset.convert(zero_point_const, const_dtype) - converted_const = opset.subtract(converted_const, converted_zero_point) + if compressed_weight.tensor.dtype == TensorDataType.int8: + assert count_nonzero(compressed_weight.zero_point.data) == 0 + else: + zero_point_const = opset.constant( + compressed_weight.zero_point.data, + dtype=compression_dtype, + name=f"{const_node_name}/zero_point", + ) + converted_zero_point = opset.convert(zero_point_const, const_dtype) + converted_const = opset.subtract(converted_const, converted_zero_point) scale_data = compressed_weight.scale.data mul = opset.multiply( diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index f5db65e3a05..cf09bf87cda 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -23,6 +23,7 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.tensor.definitions import TensorDataType +from nncf.experimental.tensor.functions import count_nonzero from nncf.experimental.tensor.tensor import Tensor from 
nncf.parameters import CompressWeightsMode from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend @@ -34,7 +35,8 @@ from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.quantization.layers import WeightsDecompressor +from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import SymmetricWeightsDecompressor from nncf.torch.tensor_statistics.collectors import get_raw_stat_collector @@ -215,7 +217,11 @@ def transform_model( compressed_weight = compress_weight(Tensor(weight), wc_params.reduction_axis, compression_config) # pack compressed tensor - packed_tensor = compressed_weight.tensor.astype(TensorDataType.uint8) + if compression_config.mode == CompressWeightsMode.INT8_SYM: + dtype = TensorDataType.int8 + else: + dtype = TensorDataType.uint8 + packed_tensor = compressed_weight.tensor.astype(dtype) # sets compressed tensor compressed_parameter = torch.nn.Parameter(packed_tensor.data, requires_grad=False) @@ -229,11 +235,13 @@ def transform_model( if id(param) == id(weight): setattr(c_module, name, compressed_parameter) - # pack zero point tensor - packed_zero_point = compressed_weight.zero_point.astype(TensorDataType.uint8) - # creates weight decompressor - decompressor = WeightsDecompressor(compressed_weight.scale.data, packed_zero_point.data) + if compression_config.mode == CompressWeightsMode.INT8_SYM: + assert count_nonzero(compressed_weight.zero_point) == 0 + decompressor = SymmetricWeightsDecompressor(compressed_weight.scale.data) + else: + packed_zero_point = compressed_weight.zero_point.astype(dtype) + decompressor = AsymmetricWeightsDecompressor(compressed_weight.scale.data, packed_zero_point.data) # registry weight decompression module in the model decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}" diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index 90bb404f306..968315c93f7 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -97,14 +97,14 @@ def do_integer_quantization( """ The method quantizes the given weights to integer data type in accordance with the compression config. The config defines a quantization mode: - INT8_SYM mode refers to unsigned int8 symmetric weight compression with a fixed zero point equals to 128 - - quantization to [0, 255] range. + INT8_SYM mode refers to signed int8 symmetric weight compression without zero point - + quantization to [-128, 127] range. INT8_ASYM mode refers to unsigned int8 asymmetric weight compression with a typical non-fixed zero-point - quantization to [0, 255] range. INT4_ASYM mode refers to unsigned int4 asymmetric weight compression with a typical non-fixed zero-point - quantization to [0, 15] range. - INT4_SYM mode refers to unsigned int4 symmetric weight compression with a fixed zero point equals to 8 - - quantization to [0, 15] range. + INT4_SYM mode refers to signed int4 symmetric weight compression without zero point - + quantization to [-8, 7] range. NF4 mode requires a dedicated procedure and it is not supported in this method. One of the parameter of compression config is a group size. 
Quantization is per-channel, if group size equals to -1, otherwise it's per-group, i.e. group size number of weights in the channel dimension share quantization parameters @@ -113,17 +113,14 @@ def do_integer_quantization( :param weight: Weight array to compress. :param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max). :param config: Information on how to compress (quantize) a specific weight. - :return: The compressed weights tensor of uint8 type, scale tensor of float32 type and - zero point tensor of int32 type that was used for its quantization. + :return: The compressed weights tensor of uint8 (asymmetric mode) or int8 (symmetric mode) type, + scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization. """ mode = config.mode assert mode != CompressWeightsMode.NF4, "The function supports integer quantization only" group_size = config.group_size num_bits = config.num_bits - level_low = 0 - level_high = 2**num_bits - 1 - if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) @@ -132,23 +129,27 @@ def do_integer_quantization( weight, reduction_axis = reshape_weight_for_grouped_quantization(weight, reduction_axis, group_size) if mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]: + level_low = 0 + level_high = 2**num_bits - 1 min_values = fns.min(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2] max_values = fns.max(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2] scale, zero_point = calculate_scale_zero_point( min_values, max_values, level_low, level_high, narrow_range=False ) + compressed_weights = fns.round(weight / scale + zero_point.astype(weight.dtype)) + compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(TensorDataType.uint8) else: + level_low = -(2 ** (num_bits - 1)) + level_high = 2 ** (num_bits - 1) - 1 scale = fns.max(fns.abs(weight), axis=reduction_axis, keepdims=True) # [a1, r//gs, 1, a2] - level_low_sym = -(2 ** (num_bits - 1)) - level_high_sym = 2 ** (num_bits - 1) - 1 - scale = scale / level_high_sym - zero_point = fns.as_tensor_like(scale, [-level_low_sym]) + scale = scale / level_high + zero_point = fns.zeros_like(scale) eps = fns.finfo(scale).eps # NOTE: adding machine epsilon to avoid division by zero scale = fns.where(fns.abs(scale) < eps, eps, scale) + compressed_weights = fns.round(weight / scale) + compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(TensorDataType.int8) - compressed_weights = fns.round(weight / scale + zero_point.astype(weight.dtype)) - compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(TensorDataType.uint8) return compressed_weights, scale, zero_point diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index 937d156c8fe..ada0ec5f5c7 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -1034,9 +1034,9 @@ def get_scale_shape(input_shape: List[int], is_weights: bool, per_channel: bool, return get_per_channel_scale_shape(input_shape, is_weights, channel_idx) -class WeightsDecompressor(nn.Module): +class AsymmetricWeightsDecompressor(nn.Module): """ - Applies decompression of compressed weights in the forward pass + Applies asymmetric decompression of compressed weights in the forward pass """ def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor): @@ -1050,3 +1050,20 @@ def __init__(self, scale: torch.Tensor, 
zero_point: torch.Tensor): def forward(self, x): return decompress(x, self._scale, self._zero_point) + + +class SymmetricWeightsDecompressor(nn.Module): + """ + Applies symmetric decompression of compressed weights in the forward pass + """ + + def __init__(self, scale: torch.Tensor): + """ + :param scale: A scale in quantization scheme + """ + super().__init__() + self.register_buffer("_scale", scale) + + def forward(self, x): + zero_point = torch.zeros_like(self._scale) + return decompress(x, self._scale, zero_point) diff --git a/tests/openvino/native/data/2023.3/reference_scales/IntegerModel_compressed_weights_int8_sym.json b/tests/openvino/native/data/2023.3/reference_scales/IntegerModel_compressed_weights_int8_sym.json index 41b80d9aa5e..1d62f50135a 100644 --- a/tests/openvino/native/data/2023.3/reference_scales/IntegerModel_compressed_weights_int8_sym.json +++ b/tests/openvino/native/data/2023.3/reference_scales/IntegerModel_compressed_weights_int8_sym.json @@ -2,63 +2,60 @@ "matmul_2_data": { "compressed_weight": [ [ - 182, - 152, - 200, - 255, - 165, - 136, - 193 - ], - [ - 155, - 140, - 206, - 168, - 219, - 155, - 255 - ], - [ - 177, - 142, - 212, - 251, - 187, - 255, - 195 - ], - [ - 182, - 207, - 255, - 249, - 187, - 225, - 191 - ], - [ - 200, - 235, - 184, - 228, - 225, - 255, - 144 - ], - [ - 222, - 248, - 253, - 130, - 240, - 255, - 252 + 54, + 24, + 72, + 127, + 37, + 8, + 65 + ], + [ + 27, + 12, + 78, + 40, + 91, + 27, + 127 + ], + [ + 49, + 14, + 84, + 123, + 59, + 127, + 67 + ], + [ + 54, + 79, + 127, + 121, + 59, + 97, + 63 + ], + [ + 72, + 107, + 56, + 100, + 97, + 127, + 16 + ], + [ + 94, + 120, + 125, + 2, + 112, + 127, + 124 ] ], - "zero_point": [ - 128 - ], "scale": [ [ 0.006270269863307476 @@ -83,57 +80,54 @@ "matmul_1_data": { "compressed_weight": [ [ - 185, - 208, - 133, - 152, - 255, - 251 + 57, + 80, + 5, + 24, + 127, + 123 ], [ - 206, - 177, - 255, - 253, - 215, - 211 + 78, + 49, + 127, + 125, + 87, + 83 ], [ - 249, - 196, - 152, - 255, - 220, - 183 + 121, + 68, + 24, + 127, + 92, + 55 ], [ - 194, - 249, - 255, - 177, - 206, - 172 + 66, + 121, + 127, + 49, + 78, + 44 ], [ - 213, - 176, - 184, - 255, - 160, - 217 + 85, + 48, + 56, + 127, + 32, + 89 ], [ - 140, - 249, - 242, - 163, - 255, - 136 + 12, + 121, + 114, + 35, + 127, + 8 ] ], - "zero_point": [ - 128 - ], "scale": [ [ 0.0052805072627961636 @@ -158,33 +152,30 @@ "gather_2_data": { "compressed_weight": [ [ - 217, - 166, - 134, - 130, - 241, - 255 + 89, + 38, + 6, + 2, + 113, + 127 ], [ - 210, - 227, - 202, - 255, - 239, - 128 + 82, + 99, + 74, + 127, + 111, + 0 ], [ - 254, - 133, - 235, - 154, - 255, - 208 + 126, + 5, + 107, + 26, + 127, + 80 ] ], - "zero_point": [ - 128 - ], "scale": [ [ 0.007187051698565483 diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 1ef73115fb2..96b692d260a 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -67,41 +67,41 @@ def get_next_node(node): def check_int8_node(op: ov.Node, mode: CompressWeightsMode = CompressWeightsMode.INT8_ASYM): - assert op.get_element_type() == ov.Type(np.uint8) + dtype = ov.Type.u8 if mode == CompressWeightsMode.INT8_ASYM else ov.Type.i8 + assert op.get_element_type() == dtype compressed_weight = get_const_value(op) + stats = {"compressed_weight": compressed_weight} convert_node = get_next_node(op) assert convert_node.get_type_name() == "Convert" - sub_node = 
get_next_node(convert_node) - assert sub_node.get_type_name() == "Subtract" + if mode == CompressWeightsMode.INT8_ASYM: + sub_node = get_next_node(convert_node) + assert sub_node.get_type_name() == "Subtract" - convert_node = sub_node.input_value(1).get_node() - assert convert_node.get_type_name() == "Convert" + convert_node = sub_node.input_value(1).get_node() + assert convert_node.get_type_name() == "Convert" - zero_point_node = convert_node.input_value(0).get_node() - zero_point = get_const_value(zero_point_node) - if mode == CompressWeightsMode.INT8_SYM: - assert list(zero_point_node.shape) == [1] - else: + zero_point_node = convert_node.input_value(0).get_node() + zero_point = get_const_value(zero_point_node) + stats["zero_point"] = zero_point reduced_weight_shape = list(op.shape) reduced_weight_shape[-1] = 1 assert list(zero_point_node.shape) == reduced_weight_shape + mul_node = get_next_node(sub_node) + else: + mul_node = get_next_node(convert_node) - mul_node = get_next_node(sub_node) assert mul_node.get_type_name() == "Multiply" scale_node = mul_node.input_value(1).get_node() scale = get_const_value(scale_node) - - return { - "compressed_weight": compressed_weight, - "zero_point": zero_point, - "scale": scale, - } + stats["scale"] = scale + return stats def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int = 7): - assert op.get_element_type() == ov.Type.u4 + dtype = ov.Type.u4 if mode == CompressWeightsMode.INT4_ASYM else ov.Type.i4 + assert op.get_element_type() == dtype weight_shape = op.shape # NOTE: get_const_value doesn't work for 4-bit types assert list(weight_shape)[-1] == group_size @@ -111,20 +111,20 @@ def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int = convert_node = get_next_node(op) assert convert_node.get_type_name() == "Convert" - sub_node = get_next_node(convert_node) - assert sub_node.get_type_name() == "Subtract" + if mode == CompressWeightsMode.INT4_ASYM: + sub_node = get_next_node(convert_node) + assert sub_node.get_type_name() == "Subtract" - convert_node = sub_node.input_value(1).get_node() - assert convert_node.get_type_name() == "Convert" + convert_node = sub_node.input_value(1).get_node() + assert convert_node.get_type_name() == "Convert" - zero_point_node = convert_node.input_value(0).get_node() - assert zero_point_node.get_element_type() == ov.Type.u4 - if mode == CompressWeightsMode.INT4_SYM: - assert list(zero_point_node.shape) == [1] - else: + zero_point_node = convert_node.input_value(0).get_node() + assert zero_point_node.get_element_type() == dtype assert list(zero_point_node.shape) == reduced_weight_shape + mul_node = get_next_node(sub_node) + else: + mul_node = get_next_node(convert_node) - mul_node = get_next_node(sub_node) assert mul_node.get_type_name() == "Multiply" scale_node = mul_node.input_value(1).get_node() assert list(scale_node.shape) == reduced_weight_shape @@ -266,7 +266,7 @@ def test_gather_in_4_bit_if_all_layers_with_data(metric): ) for op in compressed_model.get_ordered_ops(): if op.get_type_name() == "Constant" and "gather" in op.get_friendly_name(): - assert op.get_element_type() == ov.Type.u4 + assert op.get_element_type() == ov.Type.i4 def test_gather_can_be_8_bit_if_all_layers_without_data(): @@ -280,7 +280,7 @@ def test_gather_can_be_8_bit_if_all_layers_without_data(): ) for op in compressed_model.get_ordered_ops(): if op.get_type_name() == "Constant" and "gather" in op.get_friendly_name(): - assert ov.Type(np.uint8) == op.get_element_type() + assert ov.Type.u8 == 
op.get_element_type() def test_gather_can_be_4_bit_if_all_layers_without_data(): @@ -294,7 +294,7 @@ def test_gather_can_be_4_bit_if_all_layers_without_data(): ) for op in compressed_model.get_ordered_ops(): if op.get_type_name() == "Constant" and "gather" in op.get_friendly_name(): - assert ov.Type.u4 == op.get_element_type() + assert ov.Type.i4 == op.get_element_type() @pytest.mark.parametrize("metric", ALL_SENSITIVITY_METRICS) @@ -312,7 +312,7 @@ def test_gather_in_8_bit_if_not_all_layers(metric): ) for op in compressed_model.get_ordered_ops(): if op.get_type_name() == "Constant" and "gather" in op.get_friendly_name(): - assert op.get_element_type() == ov.Type(np.uint8) + assert op.get_element_type() == ov.Type.u8 MAX_BASELINE_SCORE = 1 / np.finfo(np.float32).eps @@ -361,15 +361,15 @@ def test_not_quantize_with_multiple_reduction_axes(mode): compressed_model = compress_weights(model, mode=mode) for op in compressed_model.get_ordered_ops(): if op.get_type_name() == "Constant" and op.get_friendly_name() == "gather_1_data": - assert op.get_element_type() == ov.Type(np.float32) + assert op.get_element_type() == ov.Type.f32 @pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM)) def test_shared_gather(mode): weight_name_vs_type = { - "gather_2_data": ov.Type(np.uint8), - "shared_data": ov.Type(np.uint8), - "matmul_1_data": ov.Type.u4, + "gather_2_data": ov.Type.u8, + "shared_data": ov.Type.u8, + "matmul_1_data": ov.Type.i4 if mode == CompressWeightsMode.INT4_SYM else ov.Type.u4, } model = GatherAndMatmulShareData().ov_model compressed_model = compress_weights(model, mode, group_size=3) @@ -545,7 +545,7 @@ def test_weight_compress_with_ignored_scope(ignored_scope, num_compressed): if ( op.get_type_name() == "Constant" and op.get_friendly_name() in ref_compressed_weights - and op.get_element_type() == ov.Type(np.uint8) + and op.get_element_type() == ov.Type.u8 ): act_num += 1 assert act_num == num_compressed diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index 35cd568a5bc..eafbaf86108 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -26,7 +26,7 @@ ALL_SENSITIVITY_METRICS = DATA_BASED_SENSITIVITY_METRICS + (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR,) -SUPPORTED_MODES = (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM) +SUPPORTED_MODES = (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM) UNSUPPORTED_MODES = ( CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM, @@ -51,12 +51,14 @@ def forward(self, input_ids): return res -def test_compress_weights(): +@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +def test_compress_weights(mode): model = ShortTransformer(5, 10) + dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 input_ids = torch.randint(0, 10, (5,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model) + compressed_model = compress_weights(wrapped_model, mode=mode) n_compressed_weights = 0 n_target_modules = 0 @@ -64,18 +66,20 @@ def test_compress_weights(): for _, module in compressed_model.named_children(): if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): n_target_modules += 1 - if module.weight.dtype in [torch.uint8, torch.int8]: + if module.weight.dtype == dtype: n_compressed_weights += 1 assert n_compressed_weights 
== n_target_modules -def test_compress_shared_weights(mocker): +@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +def test_compress_shared_weights(mocker, mode): model = ShortTransformer(5, 10, share_weights=True) + dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 input_ids = torch.randint(0, 10, (5,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model) + compressed_model = compress_weights(wrapped_model, mode=mode) n_compressed_weights = 0 n_target_modules = 0 @@ -83,7 +87,7 @@ def test_compress_shared_weights(mocker): for _, module in compressed_model.named_children(): if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): n_target_modules += 1 - if module.weight.dtype in [torch.uint8, torch.int8]: + if module.weight.dtype == dtype: n_compressed_weights += 1 assert n_compressed_weights == n_target_modules @@ -126,7 +130,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): @pytest.mark.parametrize("mode", UNSUPPORTED_MODES) -def test_raise_error_with_not_int8_asym(mode): +def test_raise_error_with_not_int8(mode): dummy_torch_model = EmptyModel() dummy_input = torch.Tensor() wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True)
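For reference, the following self-contained PyTorch sketch illustrates what the symmetric INT8 path introduced by this change computes end to end: a per-channel scale equal to max(|w|) / 127, rounding and clipping to the signed range [-128, 127] without a zero point, and decompression with the scale alone, in the spirit of SymmetricWeightsDecompressor. Function names and the toy tensor are illustrative assumptions; this is not the NNCF code itself.

```python
import torch


def compress_int8_sym(weight: torch.Tensor):
    # Per-output-channel scale from the absolute maximum, divided by level_high = 127.
    scale = weight.abs().amax(dim=-1, keepdim=True) / 127
    # Guard against division by zero, similar in intent to the epsilon handling in weight_lowering.py.
    scale = scale.clamp_min(torch.finfo(scale.dtype).eps)
    # Symmetric quantization: no zero point, values clipped to the signed range [-128, 127].
    compressed = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
    return compressed, scale


def decompress_sym(compressed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Dequantize with the scale only; the zero point is implicitly zero in the symmetric scheme.
    return compressed.to(scale.dtype) * scale


w = torch.randn(6, 7)
q, scale = compress_int8_sym(w)
w_hat = decompress_sym(q, scale)
# Coarse reconstruction check: per-element error is bounded by about half a quantization step.
assert torch.allclose(w, w_hat, atol=scale.max().item())
```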