Skip to content

Commit

Permalink
Adapter refactoring/ comments
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jan 8, 2025
1 parent b923176 commit e9bc7a8
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.experimental.quantization.algorithms.post_training.pipeline import experimental_create_ptq_pipeline
from nncf.experimental.quantization.algorithms.quantizer.base_quantizer import Quantizer as NNCFQuantizer
from nncf.experimental.quantization.quantizer.quantizer import Quantizer as NNCFQuantizer
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import RangeEstimatorParameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from typing import Optional, TypeVar

from nncf.experimental.quantization.algorithms.quantizer.base_quantizer import Quantizer as NNCFQuantizer
from nncf.experimental.quantization.algorithms.range_estimator.algorithm import MinMaxRangeEstimator
from nncf.experimental.quantization.quantizer.quantizer import Quantizer as NNCFQuantizer
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import RangeEstimatorParameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nncf.common.graph.graph import NNCFGraph
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.experimental.quantization.algorithms.quantizer.base_quantizer import Quantizer as NNCFQuantizer
from nncf.experimental.quantization.quantizer.quantizer import Quantizer as NNCFQuantizer
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.quantization.range_estimator import RangeEstimatorParameters
Expand All @@ -34,6 +34,8 @@ def __init__(
weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
):
"""
:param quantizer: Instance of NNCFQuantizer to retrieve a quantization config
for the given model.
:param subset_size: Size of a subset to calculate activations statistics used
for quantization, defaults to 300.
:param inplace_statistics: Defines whether to calculate quantizers statistics
Expand Down Expand Up @@ -68,7 +70,7 @@ def apply(
) -> TModel:
if self._min_max_algo._quantization_target_points_to_qconfig is None:
raise RuntimeError(
"Static points are not available."
"Statistic points are not available."
" Please call `get_statistic_points` before calling the `apply` method."
)
return self._min_max_algo.apply(model=model, graph=graph, statistic_points=statistic_points)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from typing import TypeVar

Expand All @@ -18,7 +19,12 @@
TModel = TypeVar("TModel")


class Quantizer:
class Quantizer(ABC):
"""
Quantizer is an interface for the RangeEstimator algorithm
which specifies all the required methods to retrieve quantization setup from the given model.
"""

@abstractmethod
def get_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from collections import defaultdict
from copy import deepcopy
from typing import Dict, Tuple, Union
from typing import Dict, List, Tuple, Union

import torch
import torch.fx
Expand All @@ -23,18 +23,25 @@

import nncf
from nncf.common.graph.graph import NNCFGraph
from nncf.common.logging import nncf_logger
from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.experimental.quantization.algorithms.quantizer.base_quantizer import Quantizer as NNCFQuantizer
from nncf.experimental.quantization.quantizer.quantizer import Quantizer as NNCFQuantizer
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter

EdgeOrNode = Union[Tuple[torch.fx.Node, torch.fx.Node]]


class TorchAOQuantizerAdapter(NNCFQuantizer):
"""
Implementation of the NNCF Quantizer interface for any given torch.ao quantizer.
"""

def __init__(self, quantizer: Quantizer):
    """
    :param quantizer: torch.ao quantizer whose annotations will be adapted
        into the NNCF quantization setup format.
    """
    self._quantizer = quantizer

Expand All @@ -47,6 +54,40 @@ def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGr
self._quantizer.validate(anotated_model)
return self.get_quantizer_config_from_anotated_model(anotated_model)

@staticmethod
def _get_quantization_points(
    from_node: torch.fx.Node,
    to_nodes: List[torch.fx.Node],
    anotated_model: torch.fx.GraphModule,
    qconfig: QuantizerConfig,
) -> List[QuantizationPointBase]:
    """
    Builds NNCF quantization points for the annotated edge(s) going from
    `from_node` to each node in `to_nodes`, all sharing the same quantizer config.

    :param from_node: Source node of the annotated edge(s).
    :param to_nodes: Consumer nodes of `from_node` that share `qconfig`.
    :param anotated_model: Model whose graph contains the given nodes.
    :param qconfig: Quantizer configuration to apply at the created points.
    :return: List of quantization points covering the given edges.
    """
    to_n = to_nodes[0]
    if from_node.op == "get_attr":
        _, metatype = GraphConverter.get_node_type_and_metatype(to_n, anotated_model)
        # Check that the constant is placed on the actual weight port, as it is possible for
        # activations to be a constant as well.
        if TorchAOQuantizerAdapter._get_node_args(to_n).index(from_node) in metatype.weight_port_ids:
            qip = WeightQuantizationInsertionPoint(to_n.name)
            return [SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes])]

    # If every user of `from_node` is listed in `to_nodes`, one quantizer on the
    # producer's output covers all edges at once.
    if len(from_node.users) == len(to_nodes):
        qip = ActivationQuantizationInsertionPoint(from_node.name)
        return [SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes])]

    # Otherwise only a subset of the consumers is annotated: insert a quantizer
    # on each consumer's matching input port instead of on the producer's output.
    qps = []
    for to_n_ in to_nodes:
        input_port_id = to_n_.args.index(from_node)
        qip = ActivationQuantizationInsertionPoint(to_n_.name, input_port_id)
        qp = SingleConfigQuantizationPoint(qip, qconfig, [to_n_.name])
        qps.append(qp)
    return qps

@staticmethod
def _get_node_args(node: torch.fx.Node):
    """
    Returns the arguments of the given node, with a special case for
    aten.cat: there the list of concatenated inputs (the first argument)
    is returned instead, so callers can look up input positions uniformly.

    :param node: Node to retrieve the arguments from.
    :return: The node's arguments, or the cat operand list for a cat node.
    """
    is_cat_node = node.target == torch.ops.aten.cat.default
    return node.args[0] if is_cat_node else node.args

@staticmethod
def get_quantizer_config_from_anotated_model(anotated_model: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
edge_or_node_to_qspec = _get_edge_or_node_to_qspec(anotated_model)
Expand All @@ -71,36 +112,23 @@ def get_quantizer_config_from_anotated_model(anotated_model: torch.fx.GraphModul
per_channel = False
else:
raise nncf.InternalError(f"Unknown qscheme: {qspec.qscheme}")
signed = qspec.dtype is torch.uint8
signed = qspec.dtype is torch.int8
mode = (
QuantizationMode.SYMMETRIC
if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
else QuantizationMode.ASYMMETRIC
)
qconfig = QuantizerConfig(mode=mode, signedness_to_force=signed, per_channel=per_channel)
qps = []
# If input node is a constant and placed not at activations port (0)
if from_n.op == "get_attr" and to_n.args.index(from_n) != 0:
qip = WeightQuantizationInsertionPoint(to_n.name)
qp = SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes])
qps.append(qp)
else:
if len(from_n.users) == len(to_nodes):
qip = ActivationQuantizationInsertionPoint(from_n.name)
qp = SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes])
qps.append(qp)
else:
for to_n_ in to_nodes:
input_port_id = to_n_.args.index(from_n)
qip = ActivationQuantizationInsertionPoint(to_n_.name, input_port_id)
qp = SingleConfigQuantizationPoint(qip, qconfig, [to_n_.name])
qps.append(qp)

qps = TorchAOQuantizerAdapter._get_quantization_points(from_n, to_nodes, anotated_model, qconfig)
for qp in qps:
q_setup.add_independent_quantization_point(qp)

elif isinstance(qspec, SharedQuantizationSpec):
pass
# TODO(dlyakhov): Support SharedQuantizationSpec
nncf_logger.warning(
"SharedQuantizationSpec is not supported yet;" f" edges {from_n} -> {to_nodes} won't be quantized."
)
else:
raise nncf.InternalError(f"Unknown torch.ao quantization spec: {qspec}")

Expand Down
6 changes: 2 additions & 4 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype)
return metatype

@staticmethod
def _get_node_type_and_metatype(
node: torch.fx.Node, model: torch.fx.GraphModule
) -> Tuple[str, om.OperatorMetatype]:
def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule) -> Tuple[str, om.OperatorMetatype]:
"""
Retrieves node's type and metatype.
Expand Down Expand Up @@ -136,7 +134,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:

const_targets_counter = Counter([node.target for node in model.graph.nodes if node.op == "get_attr"])
for source_node in model.graph.nodes:
node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node, model)
node_type, node_metatype = GraphConverter.get_node_type_and_metatype(source_node, model)
node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)
is_shared_node = source_node.op in ("get_attr",) and (
const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
Expand Down
2 changes: 1 addition & 1 deletion nncf/experimental/torch/fx/quantization/quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from nncf.common.logging import nncf_logger
from nncf.data import Dataset
from nncf.experimental.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization
from nncf.experimental.quantization.algorithms.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
from nncf.experimental.quantization.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
Expand Down

0 comments on commit e9bc7a8

Please sign in to comment.