Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT] Support custom modules in PTQ #2461

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elasticity_dim import ElasticityDim
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.filter_reorder import FilterReorderingAlgorithm
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import PTDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleBatchNormMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleLayerNormMetatype
from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
from nncf.torch.graph.transformations.commands import PTInsertionCommand
Expand Down Expand Up @@ -1027,7 +1027,7 @@ def build(self, target_model: NNCFNetwork) -> ElasticWidthHandler:

metatype_vs_elastic_op_creator = {
PTModuleConv2dMetatype: self._create_elastic_conv_width_op,
PTDepthwiseConv2dSubtype: self._create_elastic_conv_width_op,
PTModuleDepthwiseConv2dSubtype: self._create_elastic_conv_width_op,
PTModuleLinearMetatype: self._create_elastic_linear_width_op,
}

Expand Down Expand Up @@ -1078,7 +1078,7 @@ def build(self, target_model: NNCFNetwork) -> ElasticWidthHandler:

metatype_vs_dynamic_input_op_creator = {
PTModuleConv2dMetatype: self._create_dynamic_conv_input_op,
PTDepthwiseConv2dSubtype: self._create_dynamic_dw_conv_input_op,
PTModuleDepthwiseConv2dSubtype: self._create_dynamic_dw_conv_input_op,
PTModuleBatchNormMetatype: self._create_dynamic_bn_input_op,
PTModuleLayerNormMetatype: self._create_dynamic_ln_input_op,
PTModuleLinearMetatype: self._create_dynamic_linear_input_op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elastic_width import ElasticWidthHandler
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elastic_width import ElasticWidthSearchSpace
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elasticity_dim import ElasticityDim
from nncf.torch.graph.operator_metatypes import PTDepthwiseConv1dSubtype
from nncf.torch.graph.operator_metatypes import PTDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTDepthwiseConv3dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleConv1dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv3dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConvTranspose2dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConvTranspose3dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv1dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv3dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.pruning.utils import collect_output_shapes
Expand All @@ -60,11 +60,11 @@ class MultiElasticityHandler(ElasticityHandler):
def __init__(self, handlers: OrderedDictType[ElasticityDim, SingleElasticityHandler], target_model: NNCFNetwork):
GENERAL_CONV_LAYER_METATYPES = [
PTModuleConv1dMetatype,
PTDepthwiseConv1dSubtype,
PTModuleDepthwiseConv1dSubtype,
PTModuleConv2dMetatype,
PTDepthwiseConv2dSubtype,
PTModuleDepthwiseConv2dSubtype,
PTModuleConv3dMetatype,
PTDepthwiseConv3dSubtype,
PTModuleDepthwiseConv3dSubtype,
PTModuleConvTranspose2dMetatype,
PTModuleConvTranspose3dMetatype,
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elastic_width import ElasticWidthHandler
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.multi_elasticity_handler import MultiElasticityHandler
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import PTDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv2dSubtype


class SubnetGraph:
Expand All @@ -38,7 +38,7 @@ def __init__(self, compression_graph: PTNNCFGraph, multi_elasticity_handler: Mul
color = None
if metatype == PTModuleConv2dMetatype:
color = "lightblue"
if metatype == PTDepthwiseConv2dSubtype:
if metatype == PTModuleDepthwiseConv2dSubtype:
operator_name = f"DW_{operator_name}"
color = "purple"

Expand Down
14 changes: 10 additions & 4 deletions nncf/quantization/algorithms/fast_bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,10 @@ def apply(
input_fp, input_shape = self._get_fp_inputs(statistic_points, in_node_name)
output_fp = self._get_fp_outputs(statistic_points, out_node_name)

extracted_model = self._extract_submodel(model_transformer, node_name)
extracted_model = self._extract_submodel(model_transformer, in_node_name, out_node_name)
if extracted_model is None:
nncf_logger.debug(f"Skipping node {node_name} because cant extract submodel")
continue

sub_input_name, sub_output_name = self._backend_entity.get_sub_input_output_names(extracted_model)

Expand Down Expand Up @@ -267,15 +270,18 @@ def output_filter_func(point):
output_fp.extend(Tensor(tensor_collector.get_statistics().mean_values))
return output_fp

def _extract_submodel(self, model_transformer: ModelTransformer, node_name: str) -> TModel:
def _extract_submodel(self, model_transformer: ModelTransformer, in_node_name: str, out_node_name: str) -> TModel:
"""
Extracts sub-model using backend-specific ModelTransformer.

:param model_transformer: Backend-specific ModelTransformer.
:param node_name: Name of the node that should be a center of the sub-model.
:param in_node_name: Name of the node that should be the input of the sub-model.
:param out_node_name: Name of the node that should be the output of the sub-model.
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved
:return: Backend-specific sub-model.
"""
model_extraction_command = self._backend_entity.model_extraction_command([(node_name, 0)], [(node_name, 0)])
model_extraction_command = self._backend_entity.model_extraction_command(
[(in_node_name, 0)], [(out_node_name, 0)]
)
me_transformation_layout = TransformationLayout()
me_transformation_layout.register(model_extraction_command)
extracted_model = model_transformer.transform(me_transformation_layout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
from nncf.torch.graph.transformations.command_creation import create_bias_correction_command
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_analyzer import get_fused_bias_value
from nncf.torch.model_analyzer import get_potential_fused_node
from nncf.torch.model_analyzer import is_node_with_fused_bias
from nncf.torch.model_analyzer import is_quantized_weights
from nncf.torch.model_graph_manager import get_fused_bias_value
from nncf.torch.model_graph_manager import get_potential_fused_node
from nncf.torch.model_graph_manager import is_node_with_fused_bias
from nncf.torch.model_graph_manager import is_quantized_weights
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector

Expand Down Expand Up @@ -56,8 +56,8 @@ def create_bias_correction_command(
@staticmethod
def model_extraction_command(
input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]]
) -> PTModelExtractionWithFusedBiasCommand:
return PTModelExtractionWithFusedBiasCommand(input_ids[0][0])
) -> PTModelExtractionCommand:
return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]])

@staticmethod
def mean_statistic_collector(
Expand Down
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import collections
import dataclasses
from copy import deepcopy
from typing import Any, Dict, List, Optional, OrderedDict, Set, TypeVar, Union
from typing import Any, Dict, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union

import numpy as np

Expand Down Expand Up @@ -681,7 +681,7 @@ def _get_activation_quantization_target_point(

def _get_quantization_target_points(
self, model: TModel, nncf_graph: NNCFGraph
) -> OrderedDict[TargetPoint, QuantizerConfig]:
) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]:
"""
Returns Quantization Target Points.
In the Compression Pipeline logic NNCF assumes that the compression pipeline works only on the single model.
Expand Down Expand Up @@ -1049,7 +1049,7 @@ def _is_node_after_producers(node):
quantizer_setup.discard(fq_2_q_key, True)
continue

# In the case of the two quantizers without the brancking after them,
# In the case of the two quantizers without the branching after them,
# it needs to check that all quantizers follows after producer nodes.
if _is_node_after_producers(fq_1_producer) and _is_node_after_producers(fq_2_producer):
fq_1_prod_shape = np.prod(nncf_graph.get_output_edges(fq_1_producer)[0].tensor_shape)
Expand Down
20 changes: 10 additions & 10 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class PTMinMaxAlgoBackend(MinMaxAlgoBackend):

@property
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleLinearMetatype, om.PTLinearMetatype, om.PTMatMulMetatype]
return [om.PTLinearMetatype, om.PTLinearMetatype, om.PTMatMulMetatype]
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved

@property
def post_processing_metatypes(self) -> List[OperatorMetatype]:
Expand All @@ -80,18 +80,18 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype]

@property
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [
om.PTModuleConv1dMetatype,
om.PTModuleConv2dMetatype,
om.PTModuleConv3dMetatype,
om.PTModuleLinearMetatype,
om.PTModuleConvTranspose1dMetatype,
om.PTModuleConvTranspose2dMetatype,
om.PTModuleConvTranspose3dMetatype,
om.PTConv1dMetatype,
om.PTConv2dMetatype,
om.PTConv3dMetatype,
om.PTLinearMetatype,
om.PTConvTranspose1dMetatype,
om.PTConvTranspose2dMetatype,
om.PTConvTranspose3dMetatype,
]

@property
Expand Down Expand Up @@ -205,7 +205,7 @@ def get_statistic_collector(

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
return [None]
return node.metatype.weight_port_ids
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand, not all metatypes have a weight_port_ids attribute; in that case [None] should be returned, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but weighted nodes do have this attribute; the nodes are filtered before get_weight_tensor_port_ids is used:

def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
return [
node for node in nncf_graph.get_all_nodes() if isinstance(node.layer_attributes, WeightedLayerAttributes)
]


@staticmethod
def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str:
Expand Down
23 changes: 14 additions & 9 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import get_const_data
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
Expand All @@ -52,14 +54,14 @@ class PTSmoothQuantAlgoBackend(SmoothQuantAlgoBackend):
@property
def convolution_metatypes(self) -> List[OperatorMetatype]:
return [
om.PTModuleConv1dMetatype,
om.PTModuleConv2dMetatype,
om.PTModuleConv3dMetatype,
om.PTConv1dMetatype,
om.PTConv2dMetatype,
om.PTConv3dMetatype,
]

@property
def matmul_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleLinearMetatype]
return [om.PTLinearMetatype]

@property
def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:
Expand Down Expand Up @@ -98,10 +100,13 @@ def get_abs_max_channel_collector(

@staticmethod
def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork) -> Tensor:
node_module = model.nncf.get_containing_module(node_with_weight.node_name)
if node_module.weight is None:
raise RuntimeError(f"{node_module} module has no .weight attribute.")
return Tensor(node_module.weight.data)
weight_node = get_const_node(
node_with_weight, node_with_weight.metatype.weight_port_ids[0], model.nncf.get_graph()
)
if weight_node is None:
raise RuntimeError(f"{node_with_weight} node has no weight node.")
weight_data = get_const_data(weight_node, model)
return Tensor(weight_data)

@staticmethod
def get_weight_tensor_port_id(node: NNCFNode) -> int:
Expand Down Expand Up @@ -131,7 +136,7 @@ def scale_insertion_command(

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
if node.metatype == om.PTModuleLinearMetatype:
if node.metatype == om.PTLinearMetatype:
return -1
# TODO: Add activation axis calculation when MatMul will be supported
return 1
Expand Down
9 changes: 6 additions & 3 deletions nncf/torch/graph/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.definitions import MODEL_CONST_OP_NAME
from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES
from nncf.torch.dynamic_graph.context import TracingContext
Expand Down Expand Up @@ -79,9 +80,11 @@ class GraphConverter:
def convert(dynamic_graph: DynamicGraph, traced_parameters) -> PTNNCFGraph:
module_id_vs_known_op_addrs_map: Dict[int, Set[Scope]] = defaultdict(set)
for dynamic_graph_node in dynamic_graph.get_all_nodes():
module_id_vs_known_op_addrs_map[dynamic_graph_node.calling_module_id].add(
dynamic_graph_node.op_exec_context.op_address
)
# Skip const nodes to detect shared nodes
if dynamic_graph_node.op_exec_context.operator_name != MODEL_CONST_OP_NAME:
module_id_vs_known_op_addrs_map[dynamic_graph_node.calling_module_id].add(
dynamic_graph_node.op_exec_context.op_address
)

module_id_vs_sorted_scopes_map = {
k: list(sorted([s.scope_in_model for s in v], key=str)) for k, v in module_id_vs_known_op_addrs_map.items()
Expand Down
Loading
Loading