diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index f58299a5d10..276defe2b0f 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -42,7 +42,7 @@ from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend from nncf.quantization.fake_quantize import FakeConvertParameters from nncf.quantization.fake_quantize import FakeQuantizeParameters -from nncf.quantization.range_estimator import RangeEstimatorParameters +from nncf.quantization.range_estimator import RangeEstimatorParameters, AggregatorType class ONNXMinMaxAlgoBackend(MinMaxAlgoBackend): @@ -211,11 +211,14 @@ def get_statistic_collector( statistic_type = StatisticsType.ABS_MAX reducer = ONNX_REDUCERS_MAP[statistic_type](**kwargs) - aggregator = AGGREGATORS_MAP[params.aggregator_type]( - num_samples=num_samples, - aggregation_axes=aggregation_axes, - tensor_processor=ONNXNNCFCollectorTensorProcessor, - ) + kwargs = { + "num_samples": num_samples, + "aggregation_axes": aggregation_axes, + "tensor_processor": ONNXNNCFCollectorTensorProcessor + } + if params.aggregator_type == AggregatorType.MEAN_NO_OUTLIERS: + kwargs.update({"quantile": params.quantile_outlier_prob}) + aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) collector.register_statistic_branch(container_key, reducer, aggregator) return collector diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index 417f9c7cbec..dfed19632a6 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -42,6 +42,7 @@ from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend from nncf.quantization.fake_quantize import FakeConvertParameters from nncf.quantization.fake_quantize import FakeQuantizeParameters +from nncf.quantization.range_estimator import AggregatorType class OVMinMaxAlgoBackend(MinMaxAlgoBackend): @@ -195,11 +196,14 @@ def get_statistic_collector( statistic_type = StatisticsType.ABS_MAX reducer = OV_REDUCERS_MAP[statistic_type](**kwargs) - aggregator = AGGREGATORS_MAP[params.aggregator_type]( - num_samples=num_samples, - aggregation_axes=aggregation_axes, - tensor_processor=OVNNCFCollectorTensorProcessor, - ) + kwargs = { + "num_samples": num_samples, + "aggregation_axes": aggregation_axes, + "tensor_processor": OVNNCFCollectorTensorProcessor + } + if params.aggregator_type == AggregatorType.MEAN_NO_OUTLIERS: + kwargs.update({"quantile": params.quantile_outlier_prob}) + aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) collector.register_statistic_branch(container_key, reducer, aggregator) return collector diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index 541792eca78..60dd7ee311b 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -33,7 +33,7 @@ from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend from nncf.quantization.fake_quantize import FakeConvertParameters from nncf.quantization.fake_quantize import FakeQuantizeParameters -from nncf.quantization.range_estimator import RangeEstimatorParameters +from nncf.quantization.range_estimator import RangeEstimatorParameters, AggregatorType from nncf.torch.graph.graph import PTNNCFGraph from nncf.torch.graph.graph import PTTargetPoint from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command @@ -195,11 +195,14 @@ def get_statistic_collector( statistic_type = StatisticsType.ABS_MAX reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes) - aggregator = AGGREGATORS_MAP[params.aggregator_type]( - aggregation_axes=aggregation_axes, - num_samples=num_samples, - tensor_processor=PTNNCFCollectorTensorProcessor, - ) + kwargs = { + "num_samples": num_samples, + "aggregation_axes": aggregation_axes, + "tensor_processor": PTNNCFCollectorTensorProcessor + } + if params.aggregator_type == AggregatorType.MEAN_NO_OUTLIERS: + kwargs.update({"quantile": params.quantile_outlier_prob}) + aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) collector.register_statistic_branch(container_key, reducer, aggregator) return collector