Commit 4b2e602: Apply comments
l-bat committed Jun 6, 2024 (1 parent: 44f24dd)
Showing 8 changed files with 109 additions and 94 deletions.
14 changes: 6 additions & 8 deletions nncf/quantization/algorithms/weight_compression/gptq.py
@@ -263,7 +263,7 @@ def _quantize_weights(
quantized_col = decompress_nf4_weight(compressed_weights, scales[-1])
else:
compressed_weights = calculate_quantized_weight(
fns.unsqueeze(weight_col, 1), scales[-1], zero_points[-1], block_compression_config
fns.unsqueeze(weight_col, 1), block_compression_config, scales[-1], zero_points[-1]
)
quantized_col = do_dequantization(compressed_weights, scales[-1], zero_points[-1])
quantized_col = fns.flatten(quantized_col)
@@ -287,13 +287,11 @@ def _quantize_weights(
)

scales = fns.stack(scales, axis=1)
if wc_params.compression_config.mode == CompressWeightsMode.NF4:
zero_points = None
elif wc_params.compression_config.mode in [
CompressWeightsMode.INT8_SYM,
CompressWeightsMode.INT4_SYM,
if wc_params.compression_config.mode in [
CompressWeightsMode.INT8_ASYM,
CompressWeightsMode.INT4_ASYM,
]:
zero_points = fns.squeeze(zero_points[0])
else:
zero_points = fns.stack(zero_points, axis=1)
else:
zero_points = None
return scales, zero_points
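
The gptq.py hunks above reorder the calculate_quantized_weight arguments and only stack zero points for the asymmetric modes; NF4 and the symmetric integer modes now return zero_points = None. A minimal sketch of that convention, assuming the CompressWeightsMode enum and the fns tensor functions imported elsewhere in these files:

```python
# Illustrative sketch (not the NNCF source): how per-group scales and zero points
# are finalized after this change. The import paths are assumed from this diff.
from nncf import CompressWeightsMode
from nncf.experimental.tensor import functions as fns

def finalize_group_params(mode, scales, zero_points):
    scales = fns.stack(scales, axis=1)  # scales are always stacked along the group axis
    if mode in (CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM):
        zero_points = fns.stack(zero_points, axis=1)  # only asymmetric modes keep zero points
    else:
        zero_points = None  # NF4 and symmetric modes carry no zero point
    return scales, zero_points
```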
@@ -22,6 +22,7 @@
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error

@@ -176,7 +177,7 @@ def _calc_weight_sensitivity(self, weight_param: WeightCompressionParameters) ->
weight = weight.astype(TensorDataType.float32)

compressed_weights, scale, zero_point = do_integer_quantization(weight, reduction_axes, backup_config)
decompressed_weight = (compressed_weights - zero_point).astype(weight.dtype) * scale
decompressed_weight = do_dequantization(compressed_weights, scale, zero_point)
decompressed_weight = decompressed_weight.reshape(orig_shape)
return fns.linalg.norm(decompressed_weight - weight, ord="fro").item()

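For context, the weight-sensitivity metric in the hunk above is the Frobenius norm of the dequantization error, with the manual (q - zp) * scale expression replaced by do_dequantization. A self-contained numpy analogue (illustrative bit-width and data; the real code uses the backup integer config and NNCF tensors):

```python
# Numpy analogue of the sensitivity metric, shown with a symmetric INT4
# round-trip purely for illustration.
import numpy as np

w = np.random.rand(8, 16).astype(np.float32)
scale = np.abs(w).max(axis=-1, keepdims=True) / 7      # level_high for 4 bits
q = np.clip(np.round(w / scale), -8, 7)
w_hat = q.astype(np.float32) * scale                    # do_dequantization with zero_point=None
sensitivity = np.linalg.norm(w_hat - w, ord="fro")      # value returned by _calc_weight_sensitivity
```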
62 changes: 32 additions & 30 deletions nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -21,7 +21,6 @@
from nncf.common.graph.utils import get_reduction_axes
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.tensor.definitions import TensorDataType
from nncf.experimental.tensor.functions import count_nonzero
from nncf.experimental.tensor.tensor import Tensor
from nncf.openvino.graph.metatypes import openvino_metatypes as om
from nncf.openvino.graph.model_transformer import OVModelTransformer
@@ -174,19 +173,16 @@ def transform_model(
compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name
)
converted_const = opset.convert(compressed_const, ov.Type.f16)
if compressed_weight.zero_point is not None:
if compressed_weight.tensor.dtype == TensorDataType.int8:
assert count_nonzero(compressed_weight.zero_point.data) == 0
else:
zero_point_const = opset.constant(
compressed_weight.zero_point.data,
dtype=compression_dtype,
name=f"{const_node_name}/zero_point",
)
converted_zero_point = opset.convert(zero_point_const, ov.Type.f16)
converted_const = opset.subtract(
converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract"
)
if compressed_weight.zero_point is not None and compressed_weight.tensor.dtype == TensorDataType.uint8:
zero_point_const = opset.constant(
compressed_weight.zero_point.data,
dtype=compression_dtype,
name=f"{const_node_name}/zero_point",
)
converted_zero_point = opset.convert(zero_point_const, ov.Type.f16)
converted_const = opset.subtract(
converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract"
)

scale_const = opset.constant(
compressed_weight.scale.data, dtype=ov.Type.f16, name=f"{const_node_name}/scale"
@@ -222,27 +218,28 @@ def dump_parameters(

@staticmethod
def get_compress_decompress_pipeline(
weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape
weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None
):
(
w,
s,
zp,
clamp,
) = OVWeightCompressionAlgoBackend.get_compress_pipeline(
parameters, clamp = OVWeightCompressionAlgoBackend.get_compress_pipeline(
weight_compression_parameter, w_shape, s_shape, z_p_shape, True
)

result = (clamp - zp) * s
model = ov.Model([result], [w, s, zp])
if len(parameters) == 3:
_, s, zp = parameters
result = (clamp - zp) * s
else:
s = parameters[1]
result = clamp * s

model = ov.Model([result], parameters)

compiled_model = ov.compile_model(model)

return lambda w, s, zp: compiled_model([w, s, zp])[0]
return lambda parameters: compiled_model(parameters)[0]

@staticmethod
def get_compress_pipeline(
weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape, return_nodes=False
weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None, return_nodes=False
):
config = weight_compression_parameter.compression_config
mode = config.mode
@@ -254,18 +251,23 @@ def get_compress_pipeline(

w = opset.parameter(w_shape, name="w")
s = opset.parameter(s_shape, name="s")
zp = opset.parameter(z_p_shape, name="zp")
parameters = [w, s]
compressed_w = w / s
if z_p_shape is not None:
zp = opset.parameter(z_p_shape, name="zp")
parameters.append(zp)
compressed_w += zp

result = opset.clamp(opset.round(w / s + zp), level_low, level_high, name="compressed_weights")
result = opset.clamp(opset.round(compressed_w), level_low, level_high, name="compressed_weights")

if return_nodes:
return w, s, zp, result
return parameters, result

model = ov.Model([result], [w, s, zp])
model = ov.Model([result], parameters)

compiled_model = ov.compile_model(model)

return lambda w, s, zp: compiled_model([w, s, zp])[0]
return lambda parameters: compiled_model(parameters)[0]


class OVAWQAlgoAlgoBackend(OVWeightCompressionAlgoBackend):
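The openvino_backend.py changes make the zero point optional in both helper pipelines: the OV model is built with two inputs (weight, scale) for symmetric modes and three (weight, scale, zero point) for asymmetric ones, and the returned callable now takes a single list of inputs. A stripped-down, self-contained sketch of that construction (not the backend itself; the opset13 import path, shapes, and the INT4 levels are assumptions):

```python
# Hypothetical standalone sketch of the optional-zero-point compress pipeline.
import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset

def build_compress_model(w_shape, s_shape, z_p_shape=None, level_low=-8, level_high=7):
    w = opset.parameter(w_shape, name="w")
    s = opset.parameter(s_shape, name="s")
    parameters = [w, s]
    compressed_w = w / s
    if z_p_shape is not None:                      # asymmetric: add the zero-point input
        zp = opset.parameter(z_p_shape, name="zp")
        parameters.append(zp)
        compressed_w += zp
    result = opset.clamp(opset.round(compressed_w), level_low, level_high)
    compiled = ov.compile_model(ov.Model([result], parameters))
    return lambda inputs: compiled(inputs)[0]      # callers pass [w, s] or [w, s, zp]

compress = build_compress_model((4, 1, 8), (4, 1, 1))   # symmetric: two inputs
q = compress([np.random.rand(4, 1, 8).astype(np.float32), np.ones((4, 1, 1), np.float32)])
```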
74 changes: 43 additions & 31 deletions nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -117,7 +117,7 @@ def apply(
:return: Dict with pairs (weight name, estimated scale).
"""

compress_decompress_cashe = {}
compress_decompress_cache = {}
res = dict()

for wp in track(self._all_weight_params, description="Applying Scale Estimation"):
@@ -165,8 +165,6 @@ def apply(
original_weight = fns.zeros_like(weight) + weight

compressed_weights, scale, zp = do_integer_quantization(original_weight, reduction_axis, config)
zp = zp.astype(scale.dtype)

q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis)

s = fns.unsqueeze(s, 0)
@@ -180,9 +178,7 @@ def apply(
importance = fns.ones_like(original_weight)
importance = importance * s

target = compressed_weights.astype(dtype=zp.dtype) - zp
zero_mask = compressed_weights == zp

target, zero_mask = get_target_zero_mask(compressed_weights, zp)
importance = fns.where(zero_mask, 0.0, importance)

# normalize importances for every group of weights to make sum of them equal to 1.0
@@ -203,33 +199,36 @@ def apply(
if self._weight_penalty > 0.0:
min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1)

key = (
(wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape
)
if key in compress_decompress_cashe:
compress_decompress_model = compress_decompress_cashe[key]["compress_decompress_model"]
compress_model = compress_decompress_cashe[key]["compress_model"]
zp_shape = zp.shape if zp is not None else None
key = [(wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape]
if zp is not None:
key += zp_shape
key = tuple(key)
if key in compress_decompress_cache:
compress_decompress_model = compress_decompress_cache[key]["compress_decompress_model"]
compress_model = compress_decompress_cache[key]["compress_model"]
else:
compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
wp, q_weights.shape, scale.shape, zp.shape
wp, q_weights.shape, scale.shape, zp_shape
)
compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp.shape)
compress_decompress_cashe[key] = {
compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp_shape)
compress_decompress_cache[key] = {
"compress_decompress_model": compress_decompress_model,
"compress_model": compress_model,
}

zero_scale = 0.001
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)

input_tensors = [original_weight.data, None]
if zp is not None:
input_tensors.append(zp.data)
# iterative rectification of initial scale
for i in range(self._initial_steps):
ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
weighted_scale = ideal_scale * importance
near_to_ideal_scale = get_near_to_ideal_scale(original_weight, target, zero_mask, importance)
input_tensors[1] = near_to_ideal_scale.data

near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)

out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
out = compress_decompress_model(input_tensors)
q_weights_ = fns.zeros_like(original_weight) + out
q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)

@@ -252,31 +251,29 @@ def apply(
else:
near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
result_scale = near_to_ideal_scale
input_tensors[1] = near_to_ideal_scale.data

if i < self._initial_steps - 1:
out = compress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
out = compress_model(input_tensors)
compressed_weights = fns.zeros_like(original_weight) + out
target = compressed_weights - zp
zero_mask = compressed_weights == zp
target, zero_mask = get_target_zero_mask(compressed_weights, zp)
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)

# iterative rectification of scale based on grid search
for scale_steps in range(self._scale_steps):
factor = 1.0 - 0.05 * scale_steps
scaled_scale = factor * scale

out = compress_model(original_weight.data, scaled_scale.data, zp.data)
input_tensors[1] = scaled_scale.data
out = compress_model(input_tensors)
compressed_weights = fns.zeros_like(original_weight) + out

target = compressed_weights - zp
zero_mask = compressed_weights == zp
target, zero_mask = get_target_zero_mask(compressed_weights, zp)
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
near_to_ideal_scale = get_near_to_ideal_scale(original_weight, target, zero_mask, importance)

ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
weighted_scale = ideal_scale * importance
near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)

out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
input_tensors[1] = near_to_ideal_scale.data
out = compress_decompress_model(input_tensors)
q_weights_ = fns.zeros_like(original_weight) + out

q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
@@ -300,3 +297,18 @@ def apply(
res[weight_name] = result_scale

return res


def get_target_zero_mask(compressed_weights, zp=None):
target = compressed_weights
if zp is not None:
target = target.astype(dtype=zp.dtype) - zp
zero_mask = fns.isclose(target, 0)
return target, zero_mask


def get_near_to_ideal_scale(weight, target, zero_mask, importance):
ideal_scale = fns.abs(weight) / (fns.abs(target) + zero_mask)
weighted_scale = ideal_scale * importance
near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
return near_to_ideal_scale
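
The two helpers above factor out logic that previously assumed a zero point was always present; with zp=None the target is the raw quantized weight and the near-zero mask is taken directly from it. A small numpy analogue showing the symmetric path (illustrative shapes and values only, not the NNCF tensor API):

```python
# Numpy analogue of get_target_zero_mask / get_near_to_ideal_scale (illustration only).
import numpy as np

def target_zero_mask(compressed_weights, zp=None):
    target = compressed_weights if zp is None else compressed_weights.astype(zp.dtype) - zp
    return target, np.isclose(target, 0)

def near_to_ideal_scale(weight, target, zero_mask, importance):
    ideal_scale = np.abs(weight) / (np.abs(target) + zero_mask)   # zero_mask keeps the division finite
    return np.sum(ideal_scale * importance, axis=2, keepdims=True)

q = np.array([[[0, 2, 4]]], dtype=np.uint8)
w = np.array([[[0.0, 0.5, 1.0]]], dtype=np.float32)
importance = np.ones_like(w) / w.shape[-1]
t, mask = target_zero_mask(q)                                     # symmetric path: zp omitted
scale = near_to_ideal_scale(w, t.astype(np.float32), 0.001 * mask, importance)
```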
@@ -23,7 +23,6 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.tensor.definitions import TensorDataType
from nncf.experimental.tensor.functions import count_nonzero
from nncf.experimental.tensor.tensor import Tensor
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
@@ -233,7 +232,6 @@ def transform_model(

# creates weight decompressor
if compression_config.mode == CompressWeightsMode.INT8_SYM:
assert count_nonzero(compressed_weight.zero_point) == 0
decompressor = SymmetricWeightsDecompressor(compressed_weight.scale.data, result_dtype=weight.dtype)
else:
packed_zero_point = compressed_weight.zero_point.astype(dtype)
37 changes: 20 additions & 17 deletions nncf/quantization/algorithms/weight_compression/weight_lowering.py
@@ -233,28 +233,27 @@ def calculate_integer_quantization_params(
scale, zero_point = calculate_scale_zero_point(
min_values, max_values, level_low, level_high, narrow_range=False
)
else:
level_high = 2 ** (num_bits - 1) - 1
scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2]
scale = scale / level_high
zero_point = fns.as_tensor_like(scale, [0]).astype(TensorDataType.int32)
eps = fns.finfo(scale).eps
# NOTE: adding machine epsilon to avoid division by zero
scale = fns.where(fns.abs(scale) < eps, eps, scale)
return scale, zero_point

return scale, zero_point
level_high = 2 ** (num_bits - 1) - 1
scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2]
scale /= level_high
eps = fns.finfo(scale).eps
# NOTE: adding machine epsilon to avoid division by zero
scale = fns.where(fns.abs(scale) < eps, eps, scale)
return scale, None


def calculate_quantized_weight(
weight: Tensor, scale: Tensor, zero_point: Tensor, config: WeightCompressionConfig
weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None
) -> Tensor:
"""
Quantizes the weight tensor using the provided scale and zero point.
:param weight: Weight tensor to quantize.
:param config: Weight compression configuration.
:param scale: Scale tensor used for quantization.
:param zero_point: Zero point tensor used for quantization.
:param config: Weight compression configuration.
:return: Quantized weight tensor of uint8 or int8 type.
"""
if weight.dtype != TensorDataType.float32:
@@ -268,7 +267,10 @@ def calculate_quantized_weight(
level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1

compressed_weights = fns.round(weight / scale + zero_point.astype(weight.dtype))
compressed_weights = weight / scale
if zero_point is not None:
compressed_weights += zero_point.astype(weight.dtype)
compressed_weights = fns.round(compressed_weights)
compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype)
return compressed_weights

@@ -322,7 +324,7 @@ def do_integer_quantization(
if precomputed_zero_point is not None:
zero_point = precomputed_zero_point

compressed_weights = calculate_quantized_weight(weight, scale, zero_point, config)
compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point)
return compressed_weights, scale, zero_point


@@ -344,8 +346,7 @@ def get_integer_quantization_error(
weight = weight.astype(TensorDataType.float32)

compressed_weights, scale, zero_point = do_integer_quantization(weight, reduction_axes, config)

decompressed_weight = (compressed_weights - zero_point).astype(weight.dtype) * scale
decompressed_weight = do_dequantization(compressed_weights, scale, zero_point)

decompressed_weight = decompressed_weight.reshape(orig_shape)
diff = (decompressed_weight - weight) ** 2
@@ -384,7 +385,7 @@ def compress_weight(


def do_dequantization(
compressed_weights: Tensor, scale: Tensor, zero_point: Tensor, reduction_axis: int = -1
compressed_weights: Tensor, scale: Tensor, zero_point: Optional[Tensor] = None, reduction_axis: int = -1
) -> Tensor:
"""
The method dequantizes the given weights to float point data type in accordance with the scale and
@@ -397,7 +398,9 @@
:return: dequantized/decompressed weights.
"""
decompressed_weight = compressed_weights.astype(dtype=scale.dtype)
decompressed_weight = (decompressed_weight - zero_point) * scale
if zero_point is not None:
decompressed_weight -= zero_point
decompressed_weight *= scale

if reduction_axis > -1:
shape = list(decompressed_weight.shape) # [a1, r, a2] - "r" refers to number of channels along reduction axis
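Taken together, the weight_lowering.py changes make the zero point optional throughout the quantize/dequantize round trip: symmetric modes divide by the scale, round and clamp, while asymmetric modes additionally add (and later subtract) the zero point. A minimal self-contained numpy sketch of that behaviour (illustrative bit-width and data; not the NNCF functions themselves):

```python
# Minimal numpy sketch of the optional-zero-point round trip.
import numpy as np

def quantize(weight, scale, zero_point=None, num_bits=4):
    asym = zero_point is not None
    level_low = 0 if asym else -(2 ** (num_bits - 1))
    level_high = 2**num_bits - 1 if asym else 2 ** (num_bits - 1) - 1
    q = weight / scale
    if zero_point is not None:
        q = q + zero_point                 # asymmetric shift
    return np.clip(np.round(q), level_low, level_high)

def dequantize(q, scale, zero_point=None):
    deq = q.astype(scale.dtype)
    if zero_point is not None:
        deq = deq - zero_point
    return deq * scale

w = np.array([[0.5, -1.0, 0.25, 2.0]], dtype=np.float32)
scale = np.abs(w).max(axis=-1, keepdims=True) / 7      # symmetric INT4 level_high
w_hat = dequantize(quantize(w, scale), scale)          # zero_point omitted -> symmetric path
assert np.allclose(w, w_hat, atol=scale.max())         # reconstruction within one quantization step
```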