From 489118438c6cb5fa0a2f76d1fa50d8ddc246e961 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Sun, 29 Dec 2024 18:54:05 +0100
Subject: [PATCH] Comments

---
 .../torch/fx/quantization/quantize_pt2e.py    | 17 +++++++++++------
 .../algorithms/min_max/algorithm.py           |  4 ++--
 tests/torch/fx/test_quantizer.py              |  1 +
 3 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/nncf/experimental/torch/fx/quantization/quantize_pt2e.py b/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
index 305416d4671..ac095395175 100644
--- a/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
+++ b/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
@@ -49,6 +49,7 @@ def quantize_pt2e(
     weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
     batchwise_statistics: Optional[bool] = None,
     fold_quantize: bool = False,
+    do_copy: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Applies post-training quantization to the torch.fx.GraphModule provided model
@@ -72,6 +73,8 @@
         for each item of the batch or for the entire batch, default is None, which means
         it set True if batch_size > 1 otherwise False.
     :param fold_quantize: Boolean flag for whether fold the quantize op or not.
+    :param do_copy: The copy of the given model is being quantized if do_copy == True,
+        otherwise the model is quantized in-place. Default value is False.
     """
     nncf_logger.warning("This is an experimental feature and may change in the future without notice.")
 
@@ -79,18 +82,20 @@
         raise nncf.ValidationError("Subset size must be positive.")
 
     batch_size = calibration_dataset.get_batch_size()
-    batchwise_statistics = batchwise_statistics is None and batch_size is not None and batch_size > 1
+    if batchwise_statistics is None:
+        batchwise_statistics = batch_size is not None and batch_size > 1
 
     original_graph_meta = model.meta
 
-    copied_model = deepcopy(model)
+    if do_copy:
+        model = deepcopy(model)
 
     # To make it easier for bias correction algorithms,
     # biases are being separated by the followng calls.
-    fuse_conv_bn(copied_model)
+    fuse_conv_bn(model)
     # Call ao quantizer transform_for_annotation
     # before the NNCFGraph creation
-    quantizer.transform_for_annotation(copied_model)
+    quantizer.transform_for_annotation(model)
 
     if not isinstance(quantizer, NNCFQuantizer):
         quantizer = NNCFFXQuantizer(quantizer)
@@ -107,8 +112,8 @@
         batchwise_statistics=batchwise_statistics,
     )
 
-    nncf_graph = NNCFGraphFactory.create(copied_model)
-    quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)
+    nncf_graph = NNCFGraphFactory.create(model)
+    quantized_model = quantization_algorithm.apply(model, nncf_graph, dataset=calibration_dataset)
 
     # Magic. Without this call compiled model
     # is not preformant
diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py
index 30ea8919e22..4cd5b781a6e 100644
--- a/nncf/quantization/algorithms/min_max/algorithm.py
+++ b/nncf/quantization/algorithms/min_max/algorithm.py
@@ -790,7 +790,7 @@ def _get_activation_quantization_target_point(
 
     def find_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
         """
-        Initializes a cache, finds quantization target points and them puts in the cache.
+        Initializes a cache, finds quantization target points and then puts them in the cache.
 
         :param quantizer_setup: Quantization Target Points in format of SingleConfigQuantizerSetup.
         :param nncf_graph: NNCFGraph instance.
@@ -822,7 +822,7 @@ def fill_quantization_target_points(
         self, quantizer_setup: SingleConfigQuantizerSetup, nncf_graph: NNCFGraph
     ) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]:
         """
-        Initializes a cache, finds quantization target points and them puts in the cache.
+        Initializes a cache and puts the given quantization target points in the cache.
 
         :param model: Backend-specific model, for which Quantization Target Points are being seek.
         :param nncf_graph: NNCFGraph instance.
diff --git a/tests/torch/fx/test_quantizer.py b/tests/torch/fx/test_quantizer.py
index f95f59a4341..66186070fae 100644
--- a/tests/torch/fx/test_quantizer.py
+++ b/tests/torch/fx/test_quantizer.py
@@ -137,6 +137,7 @@ def test_quantized_model(
         calibration_dataset=calibration_dataset,
         fast_bias_correction=None,  # BC is disabled
         fold_quantize=True,
+        do_copy=True,
         **pt2e_params,
     )
 