[PT] Support custom modules in PTQ (#2461)
### Changes

- Wrap the model with parameter tracing for PTQ.
- Renamed `PTModelExtractionWithFusedBiasCommand` to `PTModelExtractionCommand` (a usage sketch follows this list).
- Removed `model_analyzer.py`; its helpers moved to `model_graph_manager.py`.
- Renamed the `PTDepthwiseConv3dSubtype` metatype to `PTModuleDepthwiseConv3dSubtype`.
- Added a new `PTDepthwiseConv3dSubtype` metatype.
- Added an `is_subtype` argument to `OperatorMetatypeRegistry.register`.
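
A minimal usage sketch of the renamed extraction command (mirroring the FastBiasCorrection PT backend hunk below); the node names are hypothetical and the `TransformationLayout` import path is assumed, so treat this as an illustration rather than part of the diff:

```python
from nncf.common.graph.transformations.layout import TransformationLayout  # assumed import path
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand

# The old PTModelExtractionWithFusedBiasCommand took a single node name;
# PTModelExtractionCommand takes explicit lists of input and output node names,
# which allows extracting a submodel around custom modules.
input_node_names = ["MyModel/NNCFConv2d[conv1]/conv2d_0"]  # hypothetical node name
output_node_names = ["MyModel/NNCFBatchNorm2d[bn1]/batch_norm_0"]  # hypothetical node name
command = PTModelExtractionCommand(input_node_names, output_node_names)

layout = TransformationLayout()
layout.register(command)
# extracted_model = model_transformer.transform(layout)  # backend-specific ModelTransformer
```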

### Reason for changes

Support models with custom modules.

### Related tickets

129581
AlexanderDokuchaev authored Apr 24, 2024
1 parent fa1a4ce commit 590bc6d
Showing 55 changed files with 12,137 additions and 8,043 deletions.

@@ -48,9 +48,9 @@
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elasticity_dim import ElasticityDim
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.filter_reorder import FilterReorderingAlgorithm
from nncf.torch.graph.graph import PTNNCFGraph
-from nncf.torch.graph.operator_metatypes import PTDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleBatchNormMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
+from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleLayerNormMetatype
from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
from nncf.torch.graph.transformations.commands import PTInsertionCommand
@@ -1027,7 +1027,7 @@ def build(self, target_model: NNCFNetwork) -> ElasticWidthHandler:

metatype_vs_elastic_op_creator = {
PTModuleConv2dMetatype: self._create_elastic_conv_width_op,
-PTDepthwiseConv2dSubtype: self._create_elastic_conv_width_op,
+PTModuleDepthwiseConv2dSubtype: self._create_elastic_conv_width_op,
PTModuleLinearMetatype: self._create_elastic_linear_width_op,
}

@@ -1078,7 +1078,7 @@ def build(self, target_model: NNCFNetwork) -> ElasticWidthHandler:

metatype_vs_dynamic_input_op_creator = {
PTModuleConv2dMetatype: self._create_dynamic_conv_input_op,
-PTDepthwiseConv2dSubtype: self._create_dynamic_dw_conv_input_op,
+PTModuleDepthwiseConv2dSubtype: self._create_dynamic_dw_conv_input_op,
PTModuleBatchNormMetatype: self._create_dynamic_bn_input_op,
PTModuleLayerNormMetatype: self._create_dynamic_ln_input_op,
PTModuleLinearMetatype: self._create_dynamic_linear_input_op,

@@ -26,14 +26,14 @@
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elastic_width import ElasticWidthHandler
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elastic_width import ElasticWidthSearchSpace
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elasticity_dim import ElasticityDim
-from nncf.torch.graph.operator_metatypes import PTDepthwiseConv1dSubtype
-from nncf.torch.graph.operator_metatypes import PTDepthwiseConv2dSubtype
-from nncf.torch.graph.operator_metatypes import PTDepthwiseConv3dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleConv1dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv3dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConvTranspose2dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConvTranspose3dMetatype
+from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv1dSubtype
+from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv2dSubtype
+from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv3dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.pruning.utils import collect_output_shapes
@@ -60,11 +60,11 @@ class MultiElasticityHandler(ElasticityHandler):
def __init__(self, handlers: OrderedDictType[ElasticityDim, SingleElasticityHandler], target_model: NNCFNetwork):
GENERAL_CONV_LAYER_METATYPES = [
PTModuleConv1dMetatype,
-PTDepthwiseConv1dSubtype,
+PTModuleDepthwiseConv1dSubtype,
PTModuleConv2dMetatype,
-PTDepthwiseConv2dSubtype,
+PTModuleDepthwiseConv2dSubtype,
PTModuleConv3dMetatype,
-PTDepthwiseConv3dSubtype,
+PTModuleDepthwiseConv3dSubtype,
PTModuleConvTranspose2dMetatype,
PTModuleConvTranspose3dMetatype,
]

@@ -19,8 +19,8 @@
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elastic_width import ElasticWidthHandler
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.multi_elasticity_handler import MultiElasticityHandler
from nncf.torch.graph.graph import PTNNCFGraph
-from nncf.torch.graph.operator_metatypes import PTDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
+from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv2dSubtype


class SubnetGraph:
@@ -38,7 +38,7 @@ def __init__(self, compression_graph: PTNNCFGraph, multi_elasticity_handler: Mul
color = None
if metatype == PTModuleConv2dMetatype:
color = "lightblue"
-if metatype == PTDepthwiseConv2dSubtype:
+if metatype == PTModuleDepthwiseConv2dSubtype:
operator_name = f"DW_{operator_name}"
color = "purple"

14 changes: 10 additions & 4 deletions nncf/quantization/algorithms/fast_bias_correction/algorithm.py
@@ -154,7 +154,10 @@ def apply(
input_fp, input_shape = self._get_fp_inputs(statistic_points, in_node_name)
output_fp = self._get_fp_outputs(statistic_points, out_node_name)

-extracted_model = self._extract_submodel(model_transformer, node_name)
+extracted_model = self._extract_submodel(model_transformer, in_node_name, out_node_name)
+if extracted_model is None:
+nncf_logger.debug(f"Skipping node {node_name} because cant extract submodel")
+continue

sub_input_name, sub_output_name = self._backend_entity.get_sub_input_output_names(extracted_model)

@@ -267,15 +270,18 @@ def output_filter_func(point):
output_fp.extend(Tensor(tensor_collector.get_statistics().mean_values))
return output_fp

-def _extract_submodel(self, model_transformer: ModelTransformer, node_name: str) -> TModel:
+def _extract_submodel(self, model_transformer: ModelTransformer, in_node_name: str, out_node_name: str) -> TModel:
"""
Extracts sub-model using backend-specific ModelTransformer.
:param model_transformer: Backend-specific ModelTransformer.
-:param node_name: Name of the node that should be a center of the sub-model.
+:param in_node_name: Name of the start node.
+:param out_node_name: Name of the output node.
:return: Backend-specific sub-model.
"""
-model_extraction_command = self._backend_entity.model_extraction_command([(node_name, 0)], [(node_name, 0)])
+model_extraction_command = self._backend_entity.model_extraction_command(
+[(in_node_name, 0)], [(out_node_name, 0)]
+)
me_transformation_layout = TransformationLayout()
me_transformation_layout.register(model_extraction_command)
extracted_model = model_transformer.transform(me_transformation_layout)

@@ -23,12 +23,12 @@
from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
from nncf.torch.graph.transformations.command_creation import create_bias_correction_command
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
-from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand
+from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
-from nncf.torch.model_analyzer import get_fused_bias_value
-from nncf.torch.model_analyzer import get_potential_fused_node
-from nncf.torch.model_analyzer import is_node_with_fused_bias
-from nncf.torch.model_analyzer import is_quantized_weights
+from nncf.torch.model_graph_manager import get_fused_bias_value
+from nncf.torch.model_graph_manager import get_potential_fused_node
+from nncf.torch.model_graph_manager import is_node_with_fused_bias
+from nncf.torch.model_graph_manager import is_quantized_weights
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector

@@ -56,8 +56,8 @@ def create_bias_correction_command(
@staticmethod
def model_extraction_command(
input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]]
-) -> PTModelExtractionWithFusedBiasCommand:
-return PTModelExtractionWithFusedBiasCommand(input_ids[0][0])
+) -> PTModelExtractionCommand:
+return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]])

@staticmethod
def mean_statistic_collector(

6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/min_max/algorithm.py
@@ -12,7 +12,7 @@
import collections
import dataclasses
from copy import deepcopy
-from typing import Any, Dict, List, Optional, OrderedDict, Set, TypeVar, Union
+from typing import Any, Dict, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union

import numpy as np

@@ -681,7 +681,7 @@ def _get_activation_quantization_target_point(

def _get_quantization_target_points(
self, model: TModel, nncf_graph: NNCFGraph
-) -> OrderedDict[TargetPoint, QuantizerConfig]:
+) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]:
"""
Returns Quantization Target Points.
In the Compression Pipeline logic NNCF assumes that the compression pipeline works only on the single model.
@@ -1053,7 +1053,7 @@ def _is_node_after_producers(node):
quantizer_setup.discard(fq_2_q_key, True)
continue

-# In the case of the two quantizers without the brancking after them,
+# In the case of the two quantizers without the branching after them,
# it needs to check that all quantizers follows after producer nodes.
if _is_node_after_producers(fq_1_producer) and _is_node_after_producers(fq_2_producer):
fq_1_prod_shape = np.prod(nncf_graph.get_output_edges(fq_1_producer)[0].tensor_shape)

20 changes: 10 additions & 10 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -62,7 +62,7 @@ class PTMinMaxAlgoBackend(MinMaxAlgoBackend):

@property
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
-return [om.PTModuleLinearMetatype, om.PTLinearMetatype, om.PTMatMulMetatype]
+return [om.PTLinearMetatype, om.PTMatMulMetatype]

@property
def post_processing_metatypes(self) -> List[OperatorMetatype]:
@@ -82,18 +82,18 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
-return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
+return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype]

@property
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [
-om.PTModuleConv1dMetatype,
-om.PTModuleConv2dMetatype,
-om.PTModuleConv3dMetatype,
-om.PTModuleLinearMetatype,
-om.PTModuleConvTranspose1dMetatype,
-om.PTModuleConvTranspose2dMetatype,
-om.PTModuleConvTranspose3dMetatype,
+om.PTConv1dMetatype,
+om.PTConv2dMetatype,
+om.PTConv3dMetatype,
+om.PTLinearMetatype,
+om.PTConvTranspose1dMetatype,
+om.PTConvTranspose2dMetatype,
+om.PTConvTranspose3dMetatype,
]

@property
@@ -210,7 +210,7 @@ def get_statistic_collector(

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
-return [None]
+return node.metatype.weight_port_ids

@staticmethod
def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str:

23 changes: 14 additions & 9 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -30,6 +30,8 @@
from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.model_graph_manager import get_const_data
+from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
@@ -52,14 +54,14 @@ class PTSmoothQuantAlgoBackend(SmoothQuantAlgoBackend):
@property
def convolution_metatypes(self) -> List[OperatorMetatype]:
return [
-om.PTModuleConv1dMetatype,
-om.PTModuleConv2dMetatype,
-om.PTModuleConv3dMetatype,
+om.PTConv1dMetatype,
+om.PTConv2dMetatype,
+om.PTConv3dMetatype,
]

@property
def matmul_metatypes(self) -> List[OperatorMetatype]:
-return [om.PTModuleLinearMetatype]
+return [om.PTLinearMetatype]

@property
def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:
@@ -98,10 +100,13 @@ def get_abs_max_channel_collector(

@staticmethod
def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork) -> Tensor:
-node_module = model.nncf.get_containing_module(node_with_weight.node_name)
-if node_module.weight is None:
-raise RuntimeError(f"{node_module} module has no .weight attribute.")
-return Tensor(node_module.weight.data)
+weight_node = get_const_node(
+node_with_weight, node_with_weight.metatype.weight_port_ids[0], model.nncf.get_graph()
+)
+if weight_node is None:
+raise RuntimeError(f"{node_with_weight} node has no weight node.")
+weight_data = get_const_data(weight_node, model)
+return Tensor(weight_data)

@staticmethod
def get_weight_tensor_port_id(node: NNCFNode) -> int:
@@ -131,7 +136,7 @@ def scale_insertion_command(

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
-if node.metatype == om.PTModuleLinearMetatype:
+if node.metatype == om.PTLinearMetatype:
return -1
# TODO: Add activation axis calculation when MatMul will be supported
return 1

9 changes: 6 additions & 3 deletions nncf/torch/graph/graph_builder.py
@@ -15,6 +15,7 @@

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
+from nncf.common.graph.definitions import MODEL_CONST_OP_NAME
from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES
from nncf.torch.dynamic_graph.context import TracingContext
@@ -79,9 +80,11 @@ class GraphConverter:
def convert(dynamic_graph: DynamicGraph, traced_parameters) -> PTNNCFGraph:
module_id_vs_known_op_addrs_map: Dict[int, Set[Scope]] = defaultdict(set)
for dynamic_graph_node in dynamic_graph.get_all_nodes():
-module_id_vs_known_op_addrs_map[dynamic_graph_node.calling_module_id].add(
-dynamic_graph_node.op_exec_context.op_address
-)
+# Skip const nodes to detect shared nodes
+if dynamic_graph_node.op_exec_context.operator_name != MODEL_CONST_OP_NAME:
+module_id_vs_known_op_addrs_map[dynamic_graph_node.calling_module_id].add(
+dynamic_graph_node.op_exec_context.op_address
+)

module_id_vs_sorted_scopes_map = {
k: list(sorted([s.scope_in_model for s in v], key=str)) for k, v in module_id_vs_known_op_addrs_map.items()