From 4e0e96af8d7fb98688c9bfe433854ed48f7754ac Mon Sep 17 00:00:00 2001
From: Liubov Talamanova
Date: Thu, 6 Jun 2024 12:42:03 +0100
Subject: [PATCH] Apply comments

---
 .../algorithms/weight_compression/gptq.py     | 14 ++--
 .../weight_compression/mixed_precision.py     |  3 +-
 .../weight_compression/openvino_backend.py    | 62 ++++++++--------
 .../weight_compression/scale_estimation.py    | 74 +++++++++++--------
 .../weight_compression/torch_backend.py       |  2 -
 .../weight_compression/weight_lowering.py     | 37 +++++-----
 nncf/torch/quantization/layers.py             |  3 +-
 nncf/torch/quantization/quantize_functions.py |  8 +-
 8 files changed, 109 insertions(+), 94 deletions(-)

diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py
index 282e2f6910a..ee2940a86aa 100644
--- a/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -263,7 +263,7 @@ def _quantize_weights(
                     quantized_col = decompress_nf4_weight(compressed_weights, scales[-1])
                 else:
                     compressed_weights = calculate_quantized_weight(
-                        fns.unsqueeze(weight_col, 1), scales[-1], zero_points[-1], block_compression_config
+                        fns.unsqueeze(weight_col, 1), block_compression_config, scales[-1], zero_points[-1]
                     )
                     quantized_col = do_dequantization(compressed_weights, scales[-1], zero_points[-1])
                 quantized_col = fns.flatten(quantized_col)
@@ -287,13 +287,11 @@ def _quantize_weights(
             )
 
         scales = fns.stack(scales, axis=1)
-        if wc_params.compression_config.mode == CompressWeightsMode.NF4:
-            zero_points = None
-        elif wc_params.compression_config.mode in [
-            CompressWeightsMode.INT8_SYM,
-            CompressWeightsMode.INT4_SYM,
+        if wc_params.compression_config.mode in [
+            CompressWeightsMode.INT8_ASYM,
+            CompressWeightsMode.INT4_ASYM,
         ]:
-            zero_points = fns.squeeze(zero_points[0])
-        else:
             zero_points = fns.stack(zero_points, axis=1)
+        else:
+            zero_points = None
         return scales, zero_points
diff --git a/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/nncf/quantization/algorithms/weight_compression/mixed_precision.py
index 29a5fc4bb3a..c4e4c73fa92 100644
--- a/nncf/quantization/algorithms/weight_compression/mixed_precision.py
+++ b/nncf/quantization/algorithms/weight_compression/mixed_precision.py
@@ -22,6 +22,7 @@
 from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
+from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error
 
@@ -176,7 +177,7 @@ def _calc_weight_sensitivity(self, weight_param: WeightCompressionParameters) ->
             weight = weight.astype(TensorDataType.float32)
 
         compressed_weights, scale, zero_point = do_integer_quantization(weight, reduction_axes, backup_config)
-        decompressed_weight = (compressed_weights - zero_point).astype(weight.dtype) * scale
+        decompressed_weight = do_dequantization(compressed_weights, scale, zero_point)
         decompressed_weight = decompressed_weight.reshape(orig_shape)
         return fns.linalg.norm(decompressed_weight - weight, ord="fro").item()
 
diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
index 859112a395e..33ecb91a56f 100644
--- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -21,7 +21,6 @@
 from nncf.common.graph.utils import get_reduction_axes
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.tensor.definitions import TensorDataType
-from nncf.experimental.tensor.functions import count_nonzero
 from nncf.experimental.tensor.tensor import Tensor
 from nncf.openvino.graph.metatypes import openvino_metatypes as om
 from nncf.openvino.graph.model_transformer import OVModelTransformer
@@ -174,19 +173,16 @@ def transform_model(
                 compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name
             )
             converted_const = opset.convert(compressed_const, ov.Type.f16)
-            if compressed_weight.zero_point is not None:
-                if compressed_weight.tensor.dtype == TensorDataType.int8:
-                    assert count_nonzero(compressed_weight.zero_point.data) == 0
-                else:
-                    zero_point_const = opset.constant(
-                        compressed_weight.zero_point.data,
-                        dtype=compression_dtype,
-                        name=f"{const_node_name}/zero_point",
-                    )
-                    converted_zero_point = opset.convert(zero_point_const, ov.Type.f16)
-                    converted_const = opset.subtract(
-                        converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract"
-                    )
+            if compressed_weight.zero_point is not None and compressed_weight.tensor.dtype == TensorDataType.uint8:
+                zero_point_const = opset.constant(
+                    compressed_weight.zero_point.data,
+                    dtype=compression_dtype,
+                    name=f"{const_node_name}/zero_point",
+                )
+                converted_zero_point = opset.convert(zero_point_const, ov.Type.f16)
+                converted_const = opset.subtract(
+                    converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract"
+                )
             scale_const = opset.constant(
                 compressed_weight.scale.data, dtype=ov.Type.f16, name=f"{const_node_name}/scale"
             )
@@ -222,27 +218,28 @@ def dump_parameters(
 
     @staticmethod
     def get_compress_decompress_pipeline(
-        weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape
+        weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None
     ):
-        (
-            w,
-            s,
-            zp,
-            clamp,
-        ) = OVWeightCompressionAlgoBackend.get_compress_pipeline(
+        parameters, clamp = OVWeightCompressionAlgoBackend.get_compress_pipeline(
             weight_compression_parameter, w_shape, s_shape, z_p_shape, True
         )
 
-        result = (clamp - zp) * s
-        model = ov.Model([result], [w, s, zp])
+        if len(parameters) == 3:
+            _, s, zp = parameters
+            result = (clamp - zp) * s
+        else:
+            s = parameters[1]
+            result = clamp * s
+
+        model = ov.Model([result], parameters)
 
         compiled_model = ov.compile_model(model)
 
-        return lambda w, s, zp: compiled_model([w, s, zp])[0]
+        return lambda parameters: compiled_model(parameters)[0]
 
     @staticmethod
     def get_compress_pipeline(
-        weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape, return_nodes=False
+        weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None, return_nodes=False
     ):
         config = weight_compression_parameter.compression_config
         mode = config.mode
@@ -254,18 +251,23 @@ def get_compress_pipeline(
 
         w = opset.parameter(w_shape, name="w")
         s = opset.parameter(s_shape, name="s")
-        zp = opset.parameter(z_p_shape, name="zp")
+        parameters = [w, s]
+        compressed_w = w / s
+        if z_p_shape is not None:
+            zp = opset.parameter(z_p_shape, name="zp")
+            parameters.append(zp)
+            compressed_w += zp
 
-        result = opset.clamp(opset.round(w / s + zp), level_low, level_high, name="compressed_weights")
+        result = opset.clamp(opset.round(compressed_w), level_low, level_high, name="compressed_weights")
 
         if return_nodes:
-            return w, s, zp, result
+            return parameters, result
 
-        model = ov.Model([result], [w, s, zp])
+        model = ov.Model([result], parameters)
 
         compiled_model = ov.compile_model(model)
 
-        return lambda w, s, zp: compiled_model([w, s, zp])[0]
+        return lambda parameters: compiled_model(parameters)[0]
 
 
 class OVAWQAlgoAlgoBackend(OVWeightCompressionAlgoBackend):
diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py
index 3cc9d7f4180..ea414d8fe9c 100644
--- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py
+++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -117,7 +117,7 @@ def apply(
         :return: Dict with pairs (weight name, estimated scale).
         """
 
-        compress_decompress_cashe = {}
+        compress_decompress_cache = {}
         res = dict()
 
         for wp in track(self._all_weight_params, description="Applying Scale Estimation"):
@@ -165,8 +165,6 @@ def apply(
             original_weight = fns.zeros_like(weight) + weight
 
             compressed_weights, scale, zp = do_integer_quantization(original_weight, reduction_axis, config)
-            zp = zp.astype(scale.dtype)
-
             q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis)
 
             s = fns.unsqueeze(s, 0)
@@ -180,9 +178,7 @@ def apply(
             importance = fns.ones_like(original_weight)
             importance = importance * s
 
-            target = compressed_weights.astype(dtype=zp.dtype) - zp
-            zero_mask = compressed_weights == zp
-
+            target, zero_mask = get_target_zero_mask(compressed_weights, zp)
             importance = fns.where(zero_mask, 0.0, importance)
 
             # normalize importances for every group of weights to make sum of them equal to 1.0
@@ -203,18 +199,20 @@ def apply(
             if self._weight_penalty > 0.0:
                 min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1)
 
-            key = (
-                (wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape
-            )
-            if key in compress_decompress_cashe:
-                compress_decompress_model = compress_decompress_cashe[key]["compress_decompress_model"]
-                compress_model = compress_decompress_cashe[key]["compress_model"]
+            zp_shape = zp.shape if zp is not None else None
+            key = [(wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape]
+            if zp is not None:
+                key += zp_shape
+            key = tuple(key)
+            if key in compress_decompress_cache:
+                compress_decompress_model = compress_decompress_cache[key]["compress_decompress_model"]
+                compress_model = compress_decompress_cache[key]["compress_model"]
             else:
                 compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
-                    wp, q_weights.shape, scale.shape, zp.shape
+                    wp, q_weights.shape, scale.shape, zp_shape
                 )
-                compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp.shape)
-                compress_decompress_cashe[key] = {
+                compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp_shape)
+                compress_decompress_cache[key] = {
                     "compress_decompress_model": compress_decompress_model,
                     "compress_model": compress_model,
                 }
@@ -222,14 +220,15 @@ def apply(
             zero_scale = 0.001
             zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
 
+            input_tensors = [original_weight.data, None]
+            if zp is not None:
+                input_tensors.append(zp.data)
             # iterative rectification of initial scale
             for i in range(self._initial_steps):
-                ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
-                weighted_scale = ideal_scale * importance
+                near_to_ideal_scale = get_near_to_ideal_scale(original_weight, target, zero_mask, importance)
+                input_tensors[1] = near_to_ideal_scale.data
 
-                near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
-
-                out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
+                out = compress_decompress_model(input_tensors)
                 q_weights_ = fns.zeros_like(original_weight) + out
 
                 q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
@@ -252,12 +251,12 @@ def apply(
                 else:
                     near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
                 result_scale = near_to_ideal_scale
+                input_tensors[1] = near_to_ideal_scale.data
 
                 if i < self._initial_steps - 1:
-                    out = compress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
+                    out = compress_model(input_tensors)
                     compressed_weights = fns.zeros_like(original_weight) + out
-                    target = compressed_weights - zp
-                    zero_mask = compressed_weights == zp
+                    target, zero_mask = get_target_zero_mask(compressed_weights, zp)
                     zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
 
             # iterative rectification of scale based on grid search
@@ -265,18 +264,16 @@ def apply(
 
                 factor = 1.0 - 0.05 * scale_steps
                 scaled_scale = factor * scale
 
-                out = compress_model(original_weight.data, scaled_scale.data, zp.data)
+                input_tensors[1] = scaled_scale.data
+                out = compress_model(input_tensors)
                 compressed_weights = fns.zeros_like(original_weight) + out
-                target = compressed_weights - zp
-                zero_mask = compressed_weights == zp
+                target, zero_mask = get_target_zero_mask(compressed_weights, zp)
                 zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
+                near_to_ideal_scale = get_near_to_ideal_scale(original_weight, target, zero_mask, importance)
 
-                ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
-                weighted_scale = ideal_scale * importance
-                near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
-
-                out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
+                input_tensors[1] = near_to_ideal_scale.data
+                out = compress_decompress_model(input_tensors)
                 q_weights_ = fns.zeros_like(original_weight) + out
 
                 q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
@@ -300,3 +297,18 @@ def apply(
             res[weight_name] = result_scale
 
         return res
+
+
+def get_target_zero_mask(compressed_weights, zp=None):
+    target = compressed_weights
+    if zp is not None:
+        target = target.astype(dtype=zp.dtype) - zp
+    zero_mask = fns.isclose(target, 0)
+    return target, zero_mask
+
+
+def get_near_to_ideal_scale(weight, target, zero_mask, importance):
+    ideal_scale = fns.abs(weight) / (fns.abs(target) + zero_mask)
+    weighted_scale = ideal_scale * importance
+    near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
+    return near_to_ideal_scale
diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py
index f80fd3a7864..6e1b4a7a6c1 100644
--- a/nncf/quantization/algorithms/weight_compression/torch_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py
@@ -23,7 +23,6 @@
 from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.tensor.definitions import TensorDataType
-from nncf.experimental.tensor.functions import count_nonzero
 from nncf.experimental.tensor.tensor import Tensor
 from nncf.parameters import CompressWeightsMode
 from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
@@ -233,7 +232,6 @@ def transform_model(
 
             # creates weight decompressor
             if compression_config.mode == CompressWeightsMode.INT8_SYM:
-                assert count_nonzero(compressed_weight.zero_point) == 0
                 decompressor = SymmetricWeightsDecompressor(compressed_weight.scale.data, result_dtype=weight.dtype)
             else:
                 packed_zero_point = compressed_weight.zero_point.astype(dtype)
diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
index e519a2727cc..54a2ef90754 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
@@ -233,28 +233,27 @@ def calculate_integer_quantization_params(
         scale, zero_point = calculate_scale_zero_point(
             min_values, max_values, level_low, level_high, narrow_range=False
         )
-    else:
-        level_high = 2 ** (num_bits - 1) - 1
-        scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True)  # [a1, r//gs, 1, a2]
-        scale = scale / level_high
-        zero_point = fns.as_tensor_like(scale, [0]).astype(TensorDataType.int32)
-        eps = fns.finfo(scale).eps
-        # NOTE: adding machine epsilon to avoid division by zero
-        scale = fns.where(fns.abs(scale) < eps, eps, scale)
+        return scale, zero_point
 
-    return scale, zero_point
+    level_high = 2 ** (num_bits - 1) - 1
+    scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True)  # [a1, r//gs, 1, a2]
+    scale /= level_high
+    eps = fns.finfo(scale).eps
+    # NOTE: adding machine epsilon to avoid division by zero
+    scale = fns.where(fns.abs(scale) < eps, eps, scale)
+    return scale, None
 
 
 def calculate_quantized_weight(
-    weight: Tensor, scale: Tensor, zero_point: Tensor, config: WeightCompressionConfig
+    weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None
 ) -> Tensor:
     """
     Quantizes the weight tensor using the provided scale and zero point.
 
     :param weight: Weight tensor to quantize.
+    :param config: Weight compression configuration.
     :param scale: Scale tensor used for quantization.
     :param zero_point: Zero point tensor used for quantization.
-    :param config: Weight compression configuration.
     :return: Quantized weight tensor of uint8 or int8 type.
     """
     if weight.dtype != TensorDataType.float32:
@@ -268,7 +267,10 @@ def calculate_quantized_weight(
     level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
     level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1
 
-    compressed_weights = fns.round(weight / scale + zero_point.astype(weight.dtype))
+    compressed_weights = weight / scale
+    if zero_point is not None:
+        compressed_weights += zero_point.astype(weight.dtype)
+    compressed_weights = fns.round(compressed_weights)
     compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype)
     return compressed_weights
 
@@ -322,7 +324,7 @@ def do_integer_quantization(
     if precomputed_zero_point is not None:
         zero_point = precomputed_zero_point
 
-    compressed_weights = calculate_quantized_weight(weight, scale, zero_point, config)
+    compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point)
     return compressed_weights, scale, zero_point
 
 
@@ -344,8 +346,7 @@ def get_integer_quantization_error(
         weight = weight.astype(TensorDataType.float32)
 
     compressed_weights, scale, zero_point = do_integer_quantization(weight, reduction_axes, config)
-
-    decompressed_weight = (compressed_weights - zero_point).astype(weight.dtype) * scale
+    decompressed_weight = do_dequantization(compressed_weights, scale, zero_point)
     decompressed_weight = decompressed_weight.reshape(orig_shape)
 
     diff = (decompressed_weight - weight) ** 2
@@ -384,7 +385,7 @@ def compress_weight(
 
 
 def do_dequantization(
-    compressed_weights: Tensor, scale: Tensor, zero_point: Tensor, reduction_axis: int = -1
+    compressed_weights: Tensor, scale: Tensor, zero_point: Optional[Tensor] = None, reduction_axis: int = -1
 ) -> Tensor:
     """
     The method dequantizes the given weights to float point data type in accordance with the scale and
@@ -397,7 +398,9 @@ def do_dequantization(
     :return: dequantized/decompressed weights.
    """
     decompressed_weight = compressed_weights.astype(dtype=scale.dtype)
-    decompressed_weight = (decompressed_weight - zero_point) * scale
+    if zero_point is not None:
+        decompressed_weight -= zero_point
+    decompressed_weight *= scale
 
     if reduction_axis > -1:
         shape = list(decompressed_weight.shape)  # [a1, r, a2] - "r" refers to number of channels along reduction axis
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index ab16a48cb3d..213d226926e 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -1081,7 +1081,6 @@ def __init__(self, scale: torch.Tensor, result_dtype: torch.dtype = None):
         self.result_dtype = result_dtype
 
     def forward(self, x):
-        zero_point = torch.zeros_like(self._scale)
-        result = decompress(x, self._scale, zero_point)
+        result = decompress(x, self._scale)
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result
diff --git a/nncf/torch/quantization/quantize_functions.py b/nncf/torch/quantization/quantize_functions.py
index 6771d6f38fa..ad7e8c03ca6 100644
--- a/nncf/torch/quantization/quantize_functions.py
+++ b/nncf/torch/quantization/quantize_functions.py
@@ -8,7 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional
 
 import torch
 
@@ -249,7 +249,7 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any:
 
 
 @register_operator()
-def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
+def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor] = None) -> torch.Tensor:
     """
     Decompress the input tensor.
 
@@ -259,5 +259,7 @@ def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tenso
     :return: The decompressed tensor
     """
     input = input.type(dtype=scale.dtype)
-    decompressed_input = (input - zero_point) * scale
+    if zero_point is not None:
+        input -= zero_point
+    decompressed_input = input * scale
    return decompressed_input
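
For reference, the contract these changes converge on is: symmetric modes carry no zero point (zero_point is None), asymmetric modes carry one, and calculate_quantized_weight(), do_dequantization() and the Torch decompress() all branch on that instead of materializing all-zero tensors. Below is a minimal standalone NumPy sketch of the same round trip; quantize/dequantize are illustrative names and the per-row scale/zero-point math is simplified, so this is not the NNCF API itself.

    import numpy as np

    def quantize(weight, scale, zero_point=None, num_bits=8):
        # Asymmetric quantization (zero_point given) uses the unsigned range [0, 2**num_bits - 1];
        # symmetric quantization (zero_point=None) uses the signed range centred on zero.
        asym = zero_point is not None
        level_low = 0 if asym else -(2 ** (num_bits - 1))
        level_high = 2**num_bits - 1 if asym else 2 ** (num_bits - 1) - 1
        q = weight / scale
        if zero_point is not None:
            q = q + zero_point
        q = np.clip(np.round(q), level_low, level_high)
        return q.astype(np.uint8 if asym else np.int8)

    def dequantize(q, scale, zero_point=None):
        # Mirror of the quantizer: subtract the zero point only when one exists.
        dq = q.astype(scale.dtype)
        if zero_point is not None:
            dq = dq - zero_point
        return dq * scale

    w = np.random.randn(4, 16).astype(np.float32)

    # Symmetric INT8: per-row scale only, no zero point.
    s_sym = np.abs(w).max(axis=-1, keepdims=True) / 127
    w_sym = dequantize(quantize(w, s_sym), s_sym)

    # Asymmetric INT8: per-row scale and zero point.
    s_asym = (w.max(axis=-1, keepdims=True) - w.min(axis=-1, keepdims=True)) / 255
    zp = np.round(-w.min(axis=-1, keepdims=True) / s_asym)
    w_asym = dequantize(quantize(w, s_asym, zp), s_asym, zp)

    print(np.abs(w - w_sym).max(), np.abs(w - w_asym).max())

Keeping zero_point optional on the symmetric path is what lets the patch drop the count_nonzero asserts and the zero-filled zero-point tensors in the OpenVINO and Torch backends.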