Skip to content

Commit

Permalink
[PTQ] Extend num_samples=1 for weights statistics collection for all…
Browse files Browse the repository at this point in the history
… backends (#1999)

### Changes

Use `num_samples=1` when collecting weight statistics for all backends (previously only the OpenVINO backend did this).

### Reason for changes

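Speed up statistics collection for the ONNX and Torch backends: weights are constant, so a single sample is sufficient.
Generalize the logic by moving it out of the OpenVINO backend and into the common MinMax algorithm.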

### Related tickets

N/A

### Tests

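Updated `test_get_stat_collector`: it is now parametrized by `num_samples`, passes the value to `_get_stat_collector`, and asserts that the returned collector uses it.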
  • Loading branch information
kshpv authored Jul 28, 2023
1 parent feb4ca3 commit 1848211
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 17 deletions.
24 changes: 18 additions & 6 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,21 @@ def _get_range_estimator_parameters(
return RangeEstimatorParameters(min_statistic_collector, max_statistic_collector)

def _get_stat_collector(
self, nncf_graph: NNCFGraph, target_point: TargetPoint, quantizer_config: QuantizerConfig
self,
nncf_graph: NNCFGraph,
target_point: TargetPoint,
quantizer_config: QuantizerConfig,
num_samples: int,
) -> TensorStatisticCollectorBase:
"""
Creates and returns statistic collector instance based on the quantizer's configuration.
Creates and returns a statistic collector based on the quantizer's configuration.
:param quantizer_config: QuantizerConfig instance for the current layer.
:return: One of the TensorStatisticCollectorBase instances
:param nncf_graph: NNCFGraph instance.
:param target_point: Target point indicates where statistics should be collected.
:param quantizer_config: Configuration of a quantizer layer,
defining the configuration of created statistic collector.
:param num_samples: Number of samples to collect from the 'target_point'.
:return: Statistic Collector.
"""
range_estimator_params = self._get_range_estimator_parameters(target_point, quantizer_config)

Expand All @@ -277,7 +285,7 @@ def _get_stat_collector(
target_point,
quantizer_config,
inplace=self._inplace_statistics,
num_samples=self._subset_size,
num_samples=num_samples,
)

def _get_default_qconfig(self, constraints: QuantizationConstraints = None) -> QuantizerConfig:
Expand Down Expand Up @@ -687,7 +695,11 @@ def get_statistic_points(self, model: TModel) -> StatisticPointsContainer:
f"Adding target point {quantization_target_point.target_node_name}"
f" with type {quantization_target_point.type} for statistics collection"
)
stat_collector = self._get_stat_collector(nncf_graph, quantization_target_point, qconfig)
num_samples = self._subset_size
if quantization_target_point.is_weight_target_point():
# Weight statistics is constant, so only one collection is enough.
num_samples = 1
stat_collector = self._get_stat_collector(nncf_graph, quantization_target_point, qconfig, num_samples)
output.add_statistic_point(
StatisticPoint(
target_point=quantization_target_point,
Expand Down
3 changes: 2 additions & 1 deletion nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.common.hardware.config import HWConfig
from nncf.common.quantization.structs import QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.tensor_statistics.collectors import ReductionShape
from nncf.common.utils.backend import BackendType
from nncf.onnx.graph.metatypes import onnx_metatypes as om
from nncf.onnx.graph.node_utils import get_input_edges_mapping
Expand Down Expand Up @@ -166,7 +167,7 @@ def _get_axis(
@staticmethod
def _get_reduction_shape_and_use_abs_max(
nncf_graph: NNCFGraph, target_point: ONNXTargetPoint, quantizer_config: QuantizerConfig
) -> Tuple[Optional[Tuple[int, ...]], bool]:
) -> Tuple[ReductionShape, bool]:
use_abs_max = quantizer_config.mode == QuantizationMode.SYMMETRIC
if not quantizer_config.per_channel:
return None, use_abs_max
Expand Down
10 changes: 1 addition & 9 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,8 @@ def _get_reduction_shape_and_use_abs_max(
axes = tuple(i for i in range(len(const_shape)) if i not in channel_axes)
else:
axes = tuple(range(len(const_shape)))

return axes, use_abs_max

@staticmethod
def _get_num_samples(num_samples, target_point: OVTargetPoint):
if target_point.is_weight_target_point():
return 1
return num_samples

@staticmethod
def get_statistic_collector(
range_estimator_params: RangeEstimatorParameters,
Expand All @@ -181,7 +174,6 @@ def get_statistic_collector(
reduction_shape, use_abs_max = OVMinMaxAlgoBackend._get_reduction_shape_and_use_abs_max(
nncf_graph, target_point, quantizer_config
)
_num_samples = OVMinMaxAlgoBackend._get_num_samples(num_samples, target_point)

collector = TensorCollector(OVMinMaxTensorStatistic)
for params, container_key in zip(
Expand Down Expand Up @@ -211,7 +203,7 @@ def get_statistic_collector(
statistic_type = StatisticsType.ABS_MAX
reducer = OV_REDUCERS_MAP[statistic_type](**kwargs)

kwargs = {"num_samples": _num_samples, "tensor_processor": OVNNCFCollectorTensorProcessor}
kwargs = {"num_samples": num_samples, "tensor_processor": OVNNCFCollectorTensorProcessor}
aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs)

collector.register_statistic_branch(container_key, reducer, aggregator)
Expand Down
6 changes: 5 additions & 1 deletion tests/post_training/test_templates/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,13 @@ def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph
)
@pytest.mark.parametrize("q_config_mode", [QuantizationMode.SYMMETRIC, QuantizationMode.ASYMMETRIC])
@pytest.mark.parametrize("q_config_per_channel", [True, False])
@pytest.mark.parametrize("num_samples", [5, 12])
def test_get_stat_collector(
self,
range_estimator_params,
q_config_mode,
q_config_per_channel,
num_samples,
conv_sum_aggregation_nncf_graph,
statistic_collector_parameters: TestGetStatisticsCollectorParameters,
):
Expand All @@ -234,7 +236,7 @@ def test_get_stat_collector(

target_point = list(min_max_algo._quantization_target_points_to_qconfig.keys())[0]
tensor_collector = min_max_algo._get_stat_collector(
conv_sum_aggregation_nncf_graph.nncf_graph, target_point, q_config
conv_sum_aggregation_nncf_graph.nncf_graph, target_point, q_config, num_samples
)

is_weight_tp = target_point.is_weight_target_point()
Expand Down Expand Up @@ -271,3 +273,5 @@ def test_get_stat_collector(
assert reducer._reduction_shape == params.ref_per_ch_reduction_shape
else:
assert reducer._reduction_shape == params.ref_per_tensor_reduction_shape

assert tensor_collector.num_samples == num_samples

0 comments on commit 1848211

Please sign in to comment.