diff --git a/nncf/experimental/quantization/algorithms/quantizer/fx_quantizer.py b/nncf/experimental/quantization/algorithms/quantizer/torch_ao_adapter.py
similarity index 99%
rename from nncf/experimental/quantization/algorithms/quantizer/fx_quantizer.py
rename to nncf/experimental/quantization/algorithms/quantizer/torch_ao_adapter.py
index 7af8c022d17..a95ddff9dcc 100644
--- a/nncf/experimental/quantization/algorithms/quantizer/fx_quantizer.py
+++ b/nncf/experimental/quantization/algorithms/quantizer/torch_ao_adapter.py
@@ -34,7 +34,7 @@
 EdgeOrNode = Union[Tuple[torch.fx.Node, torch.fx.Node]]
 
 
-class NNCFFXQuantizer(NNCFQuantizer):
+class TorchAOQuantizerAdapter(NNCFQuantizer):
     def __init__(self, quantizer: Quantizer):
         self._quantizer = quantizer
 
diff --git a/nncf/experimental/torch/fx/quantization/quantize_pt2e.py b/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
index ac095395175..e92a630ee93 100644
--- a/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
+++ b/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
@@ -17,6 +17,7 @@
 from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
 from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
 from torch.ao.quantization.pt2e.utils import _disallow_eval_train
+from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_manager import PassManager
@@ -26,11 +27,9 @@
 from nncf.common.logging import nncf_logger
 from nncf.data import Dataset
 from nncf.experimental.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization
-from nncf.experimental.quantization.algorithms.quantizer.base_quantizer import Quantizer as NNCFQuantizer
-from nncf.experimental.quantization.algorithms.quantizer.fx_quantizer import NNCFFXQuantizer
+from nncf.experimental.quantization.algorithms.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
 from nncf.experimental.torch.fx.constant_folding import constant_fold
 from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
-from nncf.experimental.torch.fx.transformations import fuse_conv_bn
 from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
 from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
 from nncf.quantization.advanced_parameters import RangeEstimatorParameters
@@ -48,7 +47,7 @@ def quantize_pt2e(
     activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
     weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
     batchwise_statistics: Optional[bool] = None,
-    fold_quantize: bool = False,
+    fold_quantize: bool = True,
     do_copy: bool = False,
 ) -> torch.fx.GraphModule:
     """
@@ -72,7 +71,7 @@
     :param batchwise_statistics: Determines whether quantizer statistics should be calculated for each item
         of the batch or for the entire batch, default is None, which means it set True if batch_size > 1
         otherwise False.
-    :param fold_quantize: Boolean flag for whether fold the quantize op or not.
+    :param fold_quantize: Boolean flag for whether to fold the quantize op or not. The value is True by default.
     :param do_copy: The copy of the given model is being quantized if do_copy == True, otherwise the model
         is quantized inplace. Default value is False.
""" @@ -90,15 +89,12 @@ def quantize_pt2e( if do_copy: model = deepcopy(model) - # To make it easier for bias correction algorithms, - # biases are being separated by the followng calls. - fuse_conv_bn(model) + _fuse_conv_bn_(model) # Call ao quantizer transform_for_annotation # before the NNCFGraph creation quantizer.transform_for_annotation(model) - if not isinstance(quantizer, NNCFQuantizer): - quantizer = NNCFFXQuantizer(quantizer) + quantizer = TorchAOQuantizerAdapter(quantizer) quantization_algorithm = ExperimentalPostTrainingQuantization( quantizer=quantizer, diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index 735dd4eeca7..0121c535187 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -15,7 +15,7 @@ import torch import torch.fx from torch.ao.quantization.fx.utils import create_getattr_from_value -from torch.ao.quantization.pt2e.utils import fold_bn_weights_into_conv_node +from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ from torch.quantization.fake_quantize import FakeQuantize import nncf @@ -512,41 +512,6 @@ def _is_supported_batch_norm_for_training(node: torch.fx.Node): return node.target in supported_ops -def _is_bn_node(node: torch.fx.Node): - return ( - _is_supported_batch_norm_for_training(node) - or node.target == torch.ops.aten._native_batch_norm_legit_no_training.default - ) - - -def fuse_conv_bn(model: torch.fx.GraphModule) -> None: - """ - BatchNorm operations have 3 output ports, to make it easier for algorithms to work with - the target graph BatchNorm operations are being fused - - :param model: Model to apply transformations to. - """ - has_bn = any(_is_bn_node(node) for node in model.graph.nodes) - if not has_bn: - return - - for node in model.graph.nodes: - if node.op != "call_function" or not _is_bn_node(node): - continue - bn_node = node - - node = bn_node.args[0] - if not _is_conv(node): - continue - conv_node = node - conv_weight_node = conv_node.args[1] - conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None - fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, model) - - model.graph.eliminate_dead_code() - model.recompile() - - def _get_pattern_replacement_per_channel() -> ( Tuple[Callable[[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, torch.dtype], torch.Tensor]] ): @@ -768,7 +733,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None: # with the target graph BatchNorm operations # are being fused fold_constant_except_qdq(model) - fuse_conv_bn(model) + _fuse_conv_bn_(model) def fold_constant_except_qdq(model: torch.fx.GraphModule): @@ -782,14 +747,3 @@ def constraint_fn(node: torch.fx.Node): return node.op != "call_function" or node.target not in QUANTIZE_NODE_TARGETS + DEQUANTIZE_NODE_TARGETS constant_fold(model, constraint_fn=constraint_fn) - - -def _is_conv(n: torch.fx.Node): - """ - Return whether the node refers to an aten conv op. 
- """ - return n.op == "call_function" and n.target in ( - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv_transpose2d.input, - ) diff --git a/tests/torch/fx/test_quantizer.py b/tests/torch/fx/test_quantizer.py index 66186070fae..86f0573f635 100644 --- a/tests/torch/fx/test_quantizer.py +++ b/tests/torch/fx/test_quantizer.py @@ -156,8 +156,10 @@ def test_quantized_model( # Uncomment to visualize reference graphs # from torch.ao.quantization.quantize_pt2e import convert_pt2e # from torch.ao.quantization.quantize_pt2e import prepare_pt2e + # from tests.torch.fx.helpers import visualize_fx_model # prepared_model = prepare_pt2e(fx_model, quantizer) # prepared_model(example_input) # ao_quantized_model = convert_pt2e(prepared_model) + # visualize_fx_model(ao_quantized_model, f"{model_case.model_id}ao_int8.svg") # ao_nncf_graph = GraphConverter.create_nncf_graph(ao_quantized_model) # ao_nncf_graph.visualize_graph("ao_" + get_dot_filename(model_case.model_id))