Skip to content

Commit

Permalink
Fix quantile parameter not being used for MEAN_NO_OUTLIERS aggragator
Browse files Browse the repository at this point in the history
  • Loading branch information
nikita-savelyevv committed Apr 18, 2024
1 parent 573b0c3 commit 92e491c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 17 deletions.
15 changes: 9 additions & 6 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.fake_quantize import FakeConvertParameters
from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.quantization.range_estimator import RangeEstimatorParameters, AggregatorType


class ONNXMinMaxAlgoBackend(MinMaxAlgoBackend):
Expand Down Expand Up @@ -211,11 +211,14 @@ def get_statistic_collector(
statistic_type = StatisticsType.ABS_MAX
reducer = ONNX_REDUCERS_MAP[statistic_type](**kwargs)

aggregator = AGGREGATORS_MAP[params.aggregator_type](
num_samples=num_samples,
aggregation_axes=aggregation_axes,
tensor_processor=ONNXNNCFCollectorTensorProcessor,
)
kwargs = {
"num_samples": num_samples,
"aggregation_axes": aggregation_axes,
"tensor_processor": ONNXNNCFCollectorTensorProcessor
}
if params.aggregator_type == AggregatorType.MEAN_NO_OUTLIERS:
kwargs.update({"quantile": params.quantile_outlier_prob})
aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs)

collector.register_statistic_branch(container_key, reducer, aggregator)
return collector
Expand Down
14 changes: 9 additions & 5 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.fake_quantize import FakeConvertParameters
from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.quantization.range_estimator import AggregatorType


class OVMinMaxAlgoBackend(MinMaxAlgoBackend):
Expand Down Expand Up @@ -195,11 +196,14 @@ def get_statistic_collector(
statistic_type = StatisticsType.ABS_MAX
reducer = OV_REDUCERS_MAP[statistic_type](**kwargs)

aggregator = AGGREGATORS_MAP[params.aggregator_type](
num_samples=num_samples,
aggregation_axes=aggregation_axes,
tensor_processor=OVNNCFCollectorTensorProcessor,
)
kwargs = {
"num_samples": num_samples,
"aggregation_axes": aggregation_axes,
"tensor_processor": OVNNCFCollectorTensorProcessor
}
if params.aggregator_type == AggregatorType.MEAN_NO_OUTLIERS:
kwargs.update({"quantile": params.quantile_outlier_prob})
aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs)

collector.register_statistic_branch(container_key, reducer, aggregator)
return collector
Expand Down
15 changes: 9 additions & 6 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.fake_quantize import FakeConvertParameters
from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.quantization.range_estimator import RangeEstimatorParameters, AggregatorType
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph import PTTargetPoint
from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
Expand Down Expand Up @@ -195,11 +195,14 @@ def get_statistic_collector(
statistic_type = StatisticsType.ABS_MAX
reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes)

aggregator = AGGREGATORS_MAP[params.aggregator_type](
aggregation_axes=aggregation_axes,
num_samples=num_samples,
tensor_processor=PTNNCFCollectorTensorProcessor,
)
kwargs = {
"num_samples": num_samples,
"aggregation_axes": aggregation_axes,
"tensor_processor": PTNNCFCollectorTensorProcessor
}
if params.aggregator_type == AggregatorType.MEAN_NO_OUTLIERS:
kwargs.update({"quantile": params.quantile_outlier_prob})
aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs)

collector.register_statistic_branch(container_key, reducer, aggregator)
return collector
Expand Down

0 comments on commit 92e491c

Please sign in to comment.