Skip to content

Commit

Permalink
use const graph for nncf.quantize
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Feb 12, 2024
1 parent d11528f commit 1ca7bee
Show file tree
Hide file tree
Showing 43 changed files with 12,115 additions and 7,758 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ def apply(
output_fp = self._get_fp_outputs(statistic_points, out_node_name)

extracted_model = self._extract_submodel(model_transformer, node_name)

if extracted_model is None:
nncf_logger.debug(f"Skipping node {node_name} because cant extract submodel")
continue
sub_input_name, sub_output_name = self._backend_entity.get_sub_input_output_names(extracted_model)

channel_axis = node.metatype.output_channel_axis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def _prepare_pipeline_step(
:param step_index: Zero-based index of pipeline step that should be prepared.
:param step_model: A model.
:param step_graph: A graph assosiated with a model.
:param step_graph: A graph associated with a model.
:param step_combinations: Combinations that change parameters only for the step_index-th pipeline step.
"""
# Create a separate pipeline for each combination
Expand Down
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import collections
import dataclasses
from copy import deepcopy
from typing import Any, Dict, List, Optional, OrderedDict, Set, TypeVar, Union
from typing import Any, Dict, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union

import numpy as np

Expand Down Expand Up @@ -618,7 +618,7 @@ def _get_activation_quantization_target_point(

def _get_quantization_target_points(
self, model: TModel, nncf_graph: NNCFGraph
) -> OrderedDict[TargetPoint, QuantizerConfig]:
) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]:
"""
Returns Quantization Target Points.
In the Compression Pipeline logic NNCF assumes that the compression pipeline works only on the single model.
Expand Down Expand Up @@ -988,7 +988,7 @@ def _is_node_after_producers(node):
quantizer_setup.discard(fq_2_q_key, True)
continue

# In the case of the two quantizers without the brancking after them,
# In the case of the two quantizers without the branching after them,
# it is necessary to check that all quantizers come after the producer nodes.
if _is_node_after_producers(fq_1_producer) and _is_node_after_producers(fq_2_producer):
fq_1_prod_shape = np.prod(nncf_graph.get_output_edges(fq_1_producer)[0].tensor_shape)
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def mat_mul_metatypes(self) -> List[OperatorMetatype]:
@abstractmethod
def post_processing_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific post-processing metatypes (NonMaximumSupression, TopK, etc.).
Property for the backend-specific post-processing metatypes (NonMaximumSuppression, TopK, etc.).
"""

@property
Expand Down
18 changes: 9 additions & 9 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,18 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype]

@property
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [
om.PTModuleConv1dMetatype,
om.PTModuleConv2dMetatype,
om.PTModuleConv3dMetatype,
om.PTModuleLinearMetatype,
om.PTModuleConvTranspose1dMetatype,
om.PTModuleConvTranspose2dMetatype,
om.PTModuleConvTranspose3dMetatype,
om.PTConv1dMetatype,
om.PTConv2dMetatype,
om.PTConv3dMetatype,
om.PTLinearMetatype,
om.PTConvTranspose1dMetatype,
om.PTConvTranspose2dMetatype,
om.PTConvTranspose3dMetatype,
]

@property
Expand Down Expand Up @@ -210,7 +210,7 @@ def get_statistic_collector(

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
return [None]
return node.metatype.weight_port_ids

@staticmethod
def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str:
Expand Down
8 changes: 4 additions & 4 deletions nncf/quantization/algorithms/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def collect_statistics(
:param statistic_points: Statistic points that need to be collected.
:param model: A model.
:param graph: A graph assosiated with a model.
:param graph: A graph associated with a model.
:param dataset: A dataset.
:return: Collected statistics.
"""
Expand Down Expand Up @@ -105,7 +105,7 @@ def run_step(
:param step_index: Zero-based index of the pipeline step that should be executed
:param step_statistics: Statistics required to execute a pipeline step.
:param model: A model to which a pipeline step will be applied.
:param graph: A graph assosiated with a model.
:param graph: A graph associated with a model.
:return: The updated model after executing the pipeline step.
"""
current_model = model
Expand Down Expand Up @@ -134,7 +134,7 @@ def run_from_step(
:param model: This is the model after the (start_step_index - 1)-th pipeline
step, or the initial model if start_step_index is 0.
:param dataset: A dataset that holds the data items for pipeline steps.
:param graph: A graph assosiated with a model.
:param graph: A graph associated with a model.
:param start_step_index: Zero-based pipeline step index from which the pipeline
should be executed.
:param step_index_to_statistics: A mapping from pipeline step index to statistics
Expand Down Expand Up @@ -175,7 +175,7 @@ def get_statistic_points_for_step(
:param step_index: Zero-based index of the pipeline step.
:param model: A model.
:param graph: A graph assosiated with a model.
:param graph: A graph associated with a model.
:return: Statistics that should be collected to execute `step_index`-th pipeline step.
"""
container = StatisticPointsContainer()
Expand Down
19 changes: 10 additions & 9 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_analyzer import get_const
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
Expand All @@ -53,14 +54,14 @@ class PTSmoothQuantAlgoBackend(SmoothQuantAlgoBackend):
@property
def convolution_metatypes(self) -> List[OperatorMetatype]:
return [
om.PTModuleConv1dMetatype,
om.PTModuleConv2dMetatype,
om.PTModuleConv3dMetatype,
om.PTConv1dMetatype,
om.PTConv2dMetatype,
om.PTConv3dMetatype,
]

@property
def matmul_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleLinearMetatype]
return [om.PTLinearMetatype]

@property
def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:
Expand Down Expand Up @@ -103,10 +104,10 @@ def get_abs_max_channel_collector(

@staticmethod
def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork) -> Tensor:
node_module = model.nncf.get_containing_module(node_with_weight.node_name)
if node_module.weight is None:
raise RuntimeError(f"{node_module} module has no .weight attribute.")
return Tensor(node_module.weight.data)
data = get_const(node_with_weight, node_with_weight.metatype.weight_port_ids[0], model)
if data is None:
raise RuntimeError(f"{node_with_weight.node_name} node has no weights.")
return Tensor(data)

@staticmethod
def get_weight_tensor_port_id(node: NNCFNode) -> int:
Expand Down Expand Up @@ -136,7 +137,7 @@ def scale_insertion_command(

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
if node.metatype == om.PTModuleLinearMetatype:
if node.metatype == om.PTLinearMetatype:
return -1
# TODO: Add activation axis calculation once MatMul is supported
return 1
Expand Down
7 changes: 7 additions & 0 deletions nncf/torch/dynamic_graph/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,15 @@ def wrap_parameters(model: torch.nn.Module):
:param model: A model.
"""
from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_PREFIX
from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_PREFIX

ignored_prefixes = [EXTERNAL_QUANTIZERS_STORAGE_PREFIX, EXTERNAL_OP_STORAGE_PREFIX]

ctx = get_current_context()
for name, param in model.named_parameters():
if any(name.startswith(ignore_prefix) for ignore_prefix in ignored_prefixes):
continue
is_reused = name in ctx.reused_parameters
tt = TracedParameter.from_torch_parameter(param, name, is_reused)
ctx.register_traced_tensor(tt)
1 change: 1 addition & 0 deletions nncf/torch/external_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from nncf.torch.dynamic_graph.context import TracingContext

EXTERNAL_OP_STORAGE_NAME = "external_op"
EXTERNAL_OP_STORAGE_PREFIX = "_nncf." + EXTERNAL_OP_STORAGE_NAME


class ExternalOpCallHook:
Expand Down
9 changes: 6 additions & 3 deletions nncf/torch/graph/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import torch

from nncf.common.graph.definitions import MODEL_CONST_OP_NAME
from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES
from nncf.torch.dynamic_graph.context import TracingContext
Expand Down Expand Up @@ -76,9 +77,11 @@ class GraphConverter:
def convert(dynamic_graph: DynamicGraph) -> PTNNCFGraph:
module_id_vs_known_op_addrs_map: Dict[int, Set[Scope]] = defaultdict(set)
for dynamic_graph_node in dynamic_graph.get_all_nodes():
module_id_vs_known_op_addrs_map[dynamic_graph_node.calling_module_id].add(
dynamic_graph_node.op_exec_context.op_address
)
# Skip const nodes to detect shared nodes
if dynamic_graph_node.op_exec_context.operator_name != MODEL_CONST_OP_NAME:
module_id_vs_known_op_addrs_map[dynamic_graph_node.calling_module_id].add(
dynamic_graph_node.op_exec_context.op_address
)

module_id_vs_sorted_scopes_map = {
k: list(sorted([s.scope_in_model for s in v], key=str)) for k, v in module_id_vs_known_op_addrs_map.items()
Expand Down
25 changes: 18 additions & 7 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class PTDepthwiseConv1dSubtype(PTDepthwiseConvOperatorSubtype):
output_channel_axis = 1
num_expected_input_edges = 2
weight_port_ids = [1]
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
Expand All @@ -186,6 +187,7 @@ class PTModuleConv1dMetatype(PTModuleOperatorSubtype):
output_channel_axis = 1
num_expected_input_edges = 2
weight_port_ids = [1]
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
Expand All @@ -197,6 +199,7 @@ class PTConv1dMetatype(PTOperatorMetatype):
output_channel_axis = 1
num_expected_input_edges = 2
weight_port_ids = [1]
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
Expand All @@ -207,6 +210,7 @@ class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype):
output_channel_axis = 1
num_expected_input_edges = 2
weight_port_ids = [1]
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
Expand All @@ -218,6 +222,7 @@ class PTModuleConv2dMetatype(PTModuleOperatorSubtype):
output_channel_axis = 1
num_expected_input_edges = 2
weight_port_ids = [1]
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
Expand All @@ -229,6 +234,7 @@ class PTConv2dMetatype(PTOperatorMetatype):
output_channel_axis = 1
num_expected_input_edges = 2
weight_port_ids = [1]
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
Expand All @@ -239,6 +245,7 @@ class PTDepthwiseConv3dSubtype(PTDepthwiseConvOperatorSubtype):
output_channel_axis = 1
num_expected_input_edges = 2
weight_port_ids = [1]
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
Expand All @@ -250,6 +257,7 @@ class PTModuleConv3dMetatype(PTModuleOperatorSubtype):
output_channel_axis = 1
num_expected_input_edges = 2
weight_port_ids = [1]
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
Expand All @@ -261,6 +269,7 @@ class PTConv3dMetatype(PTOperatorMetatype):
output_channel_axis = 1
num_expected_input_edges = 2
weight_port_ids = [1]
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
Expand Down Expand Up @@ -626,6 +635,8 @@ class PTBatchNormMetatype(PTOperatorMetatype):
name = "BatchNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]}
subtypes = [PTModuleBatchNormMetatype]
weight_port_ids = [3]
bias_port_id = 4


@PT_OPERATOR_METATYPES.register()
Expand Down Expand Up @@ -1035,19 +1046,19 @@ def get_operator_metatypes() -> List[Type[OperatorMetatype]]:

# Contains the operation metatypes for which bias can be applied.
OPERATORS_WITH_BIAS_METATYPES = [
PTModuleConv1dMetatype,
PTModuleConv2dMetatype,
PTModuleConv3dMetatype,
PTConv1dMetatype,
PTConv2dMetatype,
PTConv3dMetatype,
PTDepthwiseConv1dSubtype,
PTDepthwiseConv2dSubtype,
PTDepthwiseConv3dSubtype,
PTModuleConvTranspose1dMetatype,
PTModuleConvTranspose2dMetatype,
PTModuleConvTranspose3dMetatype,
PTConvTranspose1dMetatype,
PTConvTranspose2dMetatype,
PTConvTranspose3dMetatype,
]

OPERATORS_FUSED_METATYPES = [
PTModuleBatchNormMetatype,
PTBatchNormMetatype,
]

OP_NAMES_QUANTIZE_NODE = ["symmetric_quantize", "asymmetric_quantize"]
Loading

0 comments on commit 1ca7bee

Please sign in to comment.