Code migrated to adapters/ comments
daniil-lyakhov committed Jan 8, 2025
1 parent 4891184 commit b923176
Showing 4 changed files with 11 additions and 59 deletions.
nncf/experimental/quantization/algorithms/quantizer/torch_ao_adapter.py

@@ -34,7 +34,7 @@
 EdgeOrNode = Union[Tuple[torch.fx.Node, torch.fx.Node]]


-class NNCFFXQuantizer(NNCFQuantizer):
+class TorchAOQuantizerAdapter(NNCFQuantizer):
     def __init__(self, quantizer: Quantizer):
         self._quantizer = quantizer
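
The rename makes the class an explicit adapter: NNCF-side algorithms keep talking to the NNCFQuantizer interface while the actual annotation work is delegated to the wrapped torch.ao Quantizer. Below is a minimal sketch of that pattern. Only __init__ appears in this diff, so the delegating method shown here is a hypothetical illustration, not NNCF's real interface; transform_for_annotation and annotate, however, are real methods of torch.ao's Quantizer base class.

# Hypothetical sketch of the adapter pattern behind TorchAOQuantizerAdapter.
import torch.fx
from torch.ao.quantization.quantizer import Quantizer


class QuantizerAdapterSketch:
    def __init__(self, quantizer: Quantizer):
        # Keep a reference to the wrapped torch.ao quantizer.
        self._quantizer = quantizer

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Delegate to the torch.ao quantizer; both calls below exist on
        # torch.ao's Quantizer base class. The adapter method name itself
        # is assumed for illustration.
        model = self._quantizer.transform_for_annotation(model)
        return self._quantizer.annotate(model)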
16 changes: 6 additions & 10 deletions nncf/experimental/torch/fx/quantization/quantize_pt2e.py
@@ -17,6 +17,7 @@
 from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
 from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
 from torch.ao.quantization.pt2e.utils import _disallow_eval_train
+from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_manager import PassManager
@@ -26,11 +27,9 @@
 from nncf.common.logging import nncf_logger
 from nncf.data import Dataset
 from nncf.experimental.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization
-from nncf.experimental.quantization.algorithms.quantizer.base_quantizer import Quantizer as NNCFQuantizer
-from nncf.experimental.quantization.algorithms.quantizer.fx_quantizer import NNCFFXQuantizer
+from nncf.experimental.quantization.algorithms.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
 from nncf.experimental.torch.fx.constant_folding import constant_fold
 from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
-from nncf.experimental.torch.fx.transformations import fuse_conv_bn
 from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
 from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
 from nncf.quantization.advanced_parameters import RangeEstimatorParameters
@@ -48,7 +47,7 @@ def quantize_pt2e(
     activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
     weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
     batchwise_statistics: Optional[bool] = None,
-    fold_quantize: bool = False,
+    fold_quantize: bool = True,
     do_copy: bool = False,
 ) -> torch.fx.GraphModule:
     """
@@ -72,7 +71,7 @@
     :param batchwise_statistics: Determines whether quantizer statistics should be calculated
         for each item of the batch or for the entire batch, default is None, which means
         it set True if batch_size > 1 otherwise False.
-    :param fold_quantize: Boolean flag for whether fold the quantize op or not.
+    :param fold_quantize: Boolean flag for whether fold the quantize op or not. The value is True by default.
     :param do_copy: The copy of the given model is being quantized if do_copy == True,
         otherwise the model is quantized inplace. Default value is False.
     """
@@ -90,15 +89,12 @@
     if do_copy:
         model = deepcopy(model)

-    # To make it easier for bias correction algorithms,
-    # biases are being separated by the followng calls.
-    fuse_conv_bn(model)
+    _fuse_conv_bn_(model)
     # Call ao quantizer transform_for_annotation
     # before the NNCFGraph creation
     quantizer.transform_for_annotation(model)

-    if not isinstance(quantizer, NNCFQuantizer):
-        quantizer = NNCFFXQuantizer(quantizer)
+    quantizer = TorchAOQuantizerAdapter(quantizer)

     quantization_algorithm = ExperimentalPostTrainingQuantization(
         quantizer=quantizer,
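
Taken together, quantize_pt2e now accepts a plain torch.ao Quantizer and wraps it in TorchAOQuantizerAdapter itself. A minimal usage sketch follows; it is an assumed workflow, not part of this commit: the capture call, the XNNPACKQuantizer choice, and the positional argument order are taken from the public PyTorch/NNCF examples and may differ across versions.

import torch
import nncf
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import get_symmetric_quantization_config
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)  # folded away by _fuse_conv_bn_ inside quantize_pt2e

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))


example_input = (torch.ones(1, 3, 32, 32),)
# Capture the model to torch.fx via the PT2 export path (assumed capture API).
fx_model = torch.export.export_for_training(TinyModel().eval(), example_input).module()

# Any torch.ao quantizer works; quantize_pt2e wraps it internally.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

calibration_dataset = nncf.Dataset([example_input[0]])

# fold_quantize now defaults to True, so quantize constants are folded.
quantized_model = quantize_pt2e(fx_model, quantizer, calibration_dataset)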
50 changes: 2 additions & 48 deletions nncf/experimental/torch/fx/transformations.py
@@ -15,7 +15,7 @@
 import torch
 import torch.fx
 from torch.ao.quantization.fx.utils import create_getattr_from_value
-from torch.ao.quantization.pt2e.utils import fold_bn_weights_into_conv_node
+from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.quantization.fake_quantize import FakeQuantize

 import nncf
@@ -512,41 +512,6 @@ def _is_supported_batch_norm_for_training(node: torch.fx.Node):
     return node.target in supported_ops


-def _is_bn_node(node: torch.fx.Node):
-    return (
-        _is_supported_batch_norm_for_training(node)
-        or node.target == torch.ops.aten._native_batch_norm_legit_no_training.default
-    )
-
-
-def fuse_conv_bn(model: torch.fx.GraphModule) -> None:
-    """
-    BatchNorm operations have 3 output ports, to make it easier for algorithms to work with
-    the target graph BatchNorm operations are being fused
-
-    :param model: Model to apply transformations to.
-    """
-    has_bn = any(_is_bn_node(node) for node in model.graph.nodes)
-    if not has_bn:
-        return
-
-    for node in model.graph.nodes:
-        if node.op != "call_function" or not _is_bn_node(node):
-            continue
-        bn_node = node
-
-        node = bn_node.args[0]
-        if not _is_conv(node):
-            continue
-        conv_node = node
-        conv_weight_node = conv_node.args[1]
-        conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
-        fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, model)
-
-    model.graph.eliminate_dead_code()
-    model.recompile()
-
-
 def _get_pattern_replacement_per_channel() -> (
     Tuple[Callable[[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, torch.dtype], torch.Tensor]]
 ):
@@ -768,7 +733,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
     # with the target graph BatchNorm operations
     # are being fused
     fold_constant_except_qdq(model)
-    fuse_conv_bn(model)
+    _fuse_conv_bn_(model)


 def fold_constant_except_qdq(model: torch.fx.GraphModule):
@@ -782,14 +747,3 @@ def constraint_fn(node: torch.fx.Node):
         return node.op != "call_function" or node.target not in QUANTIZE_NODE_TARGETS + DEQUANTIZE_NODE_TARGETS

     constant_fold(model, constraint_fn=constraint_fn)
-
-
-def _is_conv(n: torch.fx.Node):
-    """
-    Return whether the node refers to an aten conv op.
-    """
-    return n.op == "call_function" and n.target in (
-        torch.ops.aten.conv1d.default,
-        torch.ops.aten.conv2d.default,
-        torch.ops.aten.conv_transpose2d.input,
-    )
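
For background on what conv-BN fusion computes (this commit only swaps NNCF's local implementation for torch.ao's _fuse_conv_bn_): folding rescales the convolution weights and bias by the BatchNorm statistics so the BN node can be removed from the graph. The sketch below shows the standard formula in standalone form; it is not code from either library, and it assumes output channels sit on dim 0 of the weight tensor (plain conv; transposed conv lays channels out differently).

import torch

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    # Standard folding: W' = W * gamma / sqrt(var + eps)
    #                   b' = (b - mean) * gamma / sqrt(var + eps) + beta
    scale = gamma / torch.sqrt(var + eps)
    # Broadcast the per-output-channel scale across the remaining weight dims.
    w_folded = w * scale.reshape(-1, *([1] * (w.dim() - 1)))
    if b is None:
        b = torch.zeros_like(mean)
    b_folded = (b - mean) * scale + beta
    return w_folded, b_folded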
2 changes: 2 additions & 0 deletions tests/torch/fx/test_quantizer.py
@@ -156,8 +156,10 @@ def test_quantized_model(
     # Uncomment to visualize reference graphs
     # from torch.ao.quantization.quantize_pt2e import convert_pt2e
     # from torch.ao.quantization.quantize_pt2e import prepare_pt2e
+    # from tests.torch.fx.helpers import visualize_fx_model
     # prepared_model = prepare_pt2e(fx_model, quantizer)
     # prepared_model(example_input)
     # ao_quantized_model = convert_pt2e(prepared_model)
+    # visualize_fx_model(ao_quantized_model, f"{model_case.model_id}ao_int8.svg")
     # ao_nncf_graph = GraphConverter.create_nncf_graph(ao_quantized_model)
     # ao_nncf_graph.visualize_graph("ao_" + get_dot_filename(model_case.model_id))
