Skip to content

Commit

Permalink
Comments/ SharedQuantizationSpec test
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jan 23, 2025
1 parent 9721fa8 commit 9187faa
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
22 changes: 17 additions & 5 deletions nncf/experimental/quantization/quantizers/openvino_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec

import nncf
from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationRule
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
Expand Down Expand Up @@ -51,6 +52,7 @@ class OpenVINOQuantizer(TorchAOQuantizer):

def __init__(
self,
*,
mode: Optional[QuantizationMode] = None,
preset: Optional[QuantizationPreset] = None,
target_device: TargetDevice = TargetDevice.ANY,
Expand Down Expand Up @@ -117,9 +119,19 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
root_quantizer_id = self._get_unified_scales_root_quantizer_id(
nncf_graph, quantizer_ids, quantization_setup
)
qp = quantization_setup.quantization_points[root_quantizer_id]
root_edge_or_node, annotation = self._get_edge_or_node_and_annotation(graph, qp, node_vs_torch_annotation)
qspec = self._get_inductor_qspec_from_qp(qp)
root_qp = quantization_setup.quantization_points[root_quantizer_id]

if any(root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig for q_id in quantizer_ids):
qps = [quantization_setup.quantization_points[q_id] for q_id in quantizer_ids]
raise nncf.InternalError(
"Different quantization configs are set to one unified scale group:"
f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}"
)

root_edge_or_node, annotation = self._get_edge_or_node_and_annotation(
graph, root_qp, node_vs_torch_annotation
)
qspec = self._get_torch_ao_qspec_from_qp(root_qp)
self._fill_torch_ao_annotation(root_edge_or_node, qspec, annotation)

while quantizer_ids:
Expand All @@ -137,7 +149,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

for qp in non_unified_quantization_points.values():
edge_or_node, annotation = self._get_edge_or_node_and_annotation(graph, qp, node_vs_torch_annotation)
qspec = self._get_inductor_qspec_from_qp(qp)
qspec = self._get_torch_ao_qspec_from_qp(qp)
self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)

for node, annotation in node_vs_torch_annotation.items():
Expand Down Expand Up @@ -228,7 +240,7 @@ def _fill_torch_ao_annotation(
annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec

@staticmethod
def _get_inductor_qspec_from_qp(qp: QuantizationPointBase) -> TorchAOQuantizationSpec:
def _get_torch_ao_qspec_from_qp(qp: QuantizationPointBase) -> TorchAOQuantizationSpec:
"""
Retrieves the quantization configuration from the given quantization point and
converts it into a TorchAOQuantizationSpec.
Expand Down
33 changes: 33 additions & 0 deletions tests/torch/fx/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,21 @@
import torchvision.models as models
from torch.ao.quantization.quantize_pt2e import convert_pt2e
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
from torch.ao.quantization.quantizer.quantizer import Quantizer
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config

import nncf
from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.quantization.quantizers.torch_ao_adapter import _get_edge_or_node_to_qspec
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e
from tests.torch import test_models
from tests.torch.fx.helpers import get_torch_fx_model
from tests.torch.test_compressed_graph import check_graph
from tests.torch.test_models.synthetic import ModelForGraphBuildingTest
from tests.torch.test_models.synthetic import ShortTransformer
from tests.torch.test_models.synthetic import YOLO11N_SDPABlock

Expand Down Expand Up @@ -190,3 +194,32 @@ def test_openvino_quantizer_with_torch_ao_convert_pt2e(model_case: ModelCase, qu
FX_QUANTIZED_DIR_NAME / "ao_export_quantization_OpenVINOQuantizer",
extended=True,
)


TorchAOSharedQuantizationSpecTestCases = (
(
ModelCase(ModelForGraphBuildingTest, "unified_scales_test_model", ModelForGraphBuildingTest.INPUT_SHAPES[0]),
("relu", "conv_transpose2d"),
),
)


@pytest.mark.parametrize(
"model_case,unified_scale_node_names",
TorchAOSharedQuantizationSpecTestCases,
ids=[m[0].model_id for m in TorchAOSharedQuantizationSpecTestCases],
)
def test_OVQuantizer_TorchAOSharedQuantizationSpec_handling(model_case, unified_scale_node_names):
fx_model, _ = _build_torch_fx_model(model_case)
quantizer = OpenVINOQuantizer()
fx_model = quantizer.transform_for_annotation(fx_model)
quantizer.annotate(fx_model)

actual_annotation = _get_edge_or_node_to_qspec(fx_model)
for edge_or_node, annotation in actual_annotation.items():
if isinstance(edge_or_node, torch.fx.Node) and edge_or_node.name == unified_scale_node_names[1]:
assert isinstance(annotation, TorchAOSharedQuantizationSpec)
assert annotation.edge_or_node.name == unified_scale_node_names[0]
assert isinstance(actual_annotation[annotation.edge_or_node], TorchAOQuantizationSpec)
return
raise RuntimeError(f"Node {unified_scale_node_names[1]} should be annotated as quantizable")

0 comments on commit 9187faa

Please sign in to comment.