Skip to content

Commit

Permalink
Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jan 17, 2025
1 parent 9511c7c commit cde3a30
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation as InductorQAnotation
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as InductorQuantizationSpec
from torch.ao.quantization.quantizer.quantizer import Quantizer
from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer

from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationRule
Expand All @@ -26,7 +26,7 @@
from nncf.common.quantization.structs import QuantizationPreset
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.quantization.structs import QuantizerConfig as NNCFQuantizerConfig
from nncf.experimental.quantization.algorithms.quantizer.base_quantizer import Quantizer as NNCFQuantizer
from nncf.experimental.quantization.quantizers.quantizer import Quantizer
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.transformations import fold_constant_except_qdq
Expand All @@ -42,7 +42,12 @@
QUANT_ANNOTATION_KEY = "quantization_annotation"


class OpenVINOQuantizer(Quantizer):
class OpenVINOQuantizer(TorchAOQuantizer):
"""
Implementation of the Torch AO quantizer which annotates models with quantization annotations
optimally for the inference via OpenVINO.
"""

def __init__(
self,
mode: Optional[QuantizationMode] = None,
Expand Down Expand Up @@ -171,9 +176,12 @@ def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.Grap
return model


class OpenVINOQuantizerAdapter(NNCFQuantizer):
class OpenVINOQuantizerAdapter(Quantizer):
def __init__(self, quantizer: OpenVINOQuantizer):
self._quantizer = quantizer

def transform_prior_quantization(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
return self._quantizer.transform_for_annotation(model)

def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
return self._quantizer.get_quantization_setup(model, nncf_graph)
9 changes: 7 additions & 2 deletions nncf/experimental/torch/fx/quantization/quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
from nncf.common.factory import NNCFGraphFactory
from nncf.common.logging import nncf_logger
from nncf.data import Dataset
from nncf.experimental.common.quantization.algorithms.quantizer.openvino_quantizer import OpenVINOQuantizerAdapter
from nncf.experimental.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization
from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizerAdapter
from nncf.experimental.quantization.quantizers.torch_ao_adapter import TorchAOQuantizerAdapter
from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
Expand Down Expand Up @@ -92,7 +93,11 @@ def quantize_pt2e(
model = deepcopy(model)

_fuse_conv_bn_(model)
quantizer = TorchAOQuantizerAdapter(quantizer)
if isinstance(quantizer, OpenVINOQuantizer):
quantizer = OpenVINOQuantizerAdapter(quantizer)
else:
quantizer = TorchAOQuantizerAdapter(quantizer)

# Call transform_prior_quantization before the NNCFGraph creation
transformed_model = quantizer.transform_prior_quantization(model)

Expand Down
2 changes: 1 addition & 1 deletion tests/torch/fx/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config

import nncf
from nncf.experimental.common.quantization.algorithms.quantizer.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e
from tests.torch import test_models
Expand Down

0 comments on commit cde3a30

Please sign in to comment.