Skip to content

Commit

Permalink
WIP statistic collection and quantizer params calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed May 27, 2024
1 parent fdd87f2 commit 4b842b3
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 51 deletions.
78 changes: 59 additions & 19 deletions nncf/experimental/torch_fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from dataclasses import dataclass

# from functools import partial
# from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union

import torch
import torch.fx
Expand All @@ -31,12 +30,16 @@
from torch.fx.passes.infra.pass_manager import PassManager

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.transformations.commands import TargetType

# from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
# from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint

# from nncf.torch.graph.transformations.commands import PTTargetPoint
# from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
Expand All @@ -50,6 +53,19 @@
# from nncf.torch.utils import is_multidevice


class FXInsertionCommand(Command):
def __init__(
self,
target_points: List[PTTargetPoint],
fn: Callable,
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
):
super().__init__(TransformationType.INSERT)
self.target_points = target_points
self.fn = fn
self.priority = priority


class FXModelTransformer(ModelTransformer):
"""
Applies transformations upon Torch FX model.
Expand All @@ -61,6 +77,7 @@ def __init__(self, model: torch.fx.GraphModule):
self._command_transformation_ordered_pairs = [
(PTInsertionCommand, self._apply_insertion_transformations),
(PTSharedFnInsertionCommand, self._apply_shared_nodes_insertion),
(FXInsertionCommand, self._apply_insertion_transformations),
]

def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
Expand All @@ -75,6 +92,10 @@ def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.G
if transformations:
model = transformation_fn(model, transformations)

# Do not eliminate dead code as
# the dead code is coputing statistics :)
# model.graph.eliminate_dead_code()
model.recompile()
return model

@staticmethod
Expand All @@ -90,9 +111,8 @@ def _apply_insertion_transformations(
functions which are subclassed from torch.nn.Module. Do nothing in case device is None.
:return: A modified torch.fx.GraphModule.
"""
node_type = "output"
node_type = "call_module"
graph = model.graph
outputs = []
for transformation in transformations:
for node in graph.nodes:
if node.name == transformation.target_point.target_node_name:
Expand All @@ -101,25 +121,22 @@ def _apply_insertion_transformations(
target_type = transformation.target_point.target_type
if target_type == TargetType.OPERATOR_PRE_HOOK:
ctx = graph.inserting_before(target_node)
target_nodes = [target_node]
elif target_type == TargetType.OPERATOR_POST_HOOK:
ctx = graph.inserting_after(target_node)
target_nodes = [target_node]
elif target_type == TargetType.OPERATION_WITH_WEIGHTS:
# TODO: make it common
target_node = target_node.all_input_nodes[transformation.target_point.input_port_id]
ctx = graph.inserting_after(target_node)
target_nodes = []
for input in target_node.all_input_nodes:
if input.op == "get_attr":
target_nodes.append(input)
else:
raise RuntimeError(f"Unsupported target type: {target_type} for transformation: {transformation}")

fn = transformation.fn
obs_name_in_model = target_node.name + str(id(fn))
assert not hasattr(model, obs_name_in_model)
setattr(model, obs_name_in_model, fn)
with ctx:
for target_node in target_nodes:
outputs.append(
graph.create_node(node_type, "", (target_node,), {}, name=target_node.name + "_nncf_output")
)
graph.create_node(
node_type, obs_name_in_model, (target_node,), {}, name=obs_name_in_model + "_graph_node"
)
return model

@staticmethod
Expand All @@ -137,10 +154,33 @@ def _apply_shared_nodes_insertion(
functions which are subclassed from torch.nn.Module. Do nothing in case device is None.
:return: A modified torch.fx.GraphModule.
"""
node_type = "call_module"
graph = model.graph
for transformation in transformations:
a = 6
del a
pass
for node in graph.nodes:
if node.name == transformation.target_point.target_node_name:
target_node = node
break
target_type = transformation.target_point.target_type
if target_type == TargetType.OPERATOR_PRE_HOOK:
ctx = graph.inserting_before(target_node)
elif target_type == TargetType.OPERATOR_POST_HOOK:
ctx = graph.inserting_after(target_node)
elif target_type == TargetType.OPERATION_WITH_WEIGHTS:
target_node = target_node.all_input_nodes[transformation.target_point.input_port_id]
ctx = graph.inserting_after(target_node)
else:
raise RuntimeError(f"Unsupported target type: {target_type} for transformation: {transformation}")

fn = transformation.fn
obs_name_in_model = target_node.name + str(id(fn))
assert not hasattr(model, obs_name_in_model)
setattr(model, obs_name_in_model, fn)
with ctx:
graph.create_node(
node_type, obs_name_in_model, (target_node,), {}, name=obs_name_in_model + "_graph_node"
)
return model


@dataclass
Expand Down
32 changes: 23 additions & 9 deletions nncf/experimental/torch_fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
from typing import Tuple

import torch.fx
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ # noqa

import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES

# from nncf.experimental.torch_fx.operator_metatypes import FX_OPERATOR_METATYPES
Expand Down Expand Up @@ -52,7 +54,12 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe
node_type = "get_attr"
node_metatype = om.PTConstNoopMetatype
elif node.op in ("call_function",):
node_type = str(node.target.overloadpacket).split(".")[1]
if hasattr(node.target, "overloadpacket"):
torch.nn.BatchNorm2d
node_type = str(node.target.overloadpacket).split(".")[1]
else:
# TODO: get correct nodes types from this nodes as well
node_type = str(node.target)
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
# TODO: add layer attrs and support subtypes
# if node_metatype.get_subtypes():
Expand All @@ -75,9 +82,12 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
:return: NNCFGraph.
"""

_fuse_conv_bn_(model)
# _fuse_conv_bn_(model)

nncf_graph = NNCFGraph()
# model.graph.eliminate_dead_code()
# model.recompile()

nncf_graph = PTNNCFGraph()

for source_node in model.graph.nodes:

Expand Down Expand Up @@ -112,15 +122,15 @@ def get_module_params_or_buffers():

for source_node in model.graph.nodes:

source_node_id = nncf_graph.get_node_by_name(source_node.name).node_id
source_nncf_node = nncf_graph.get_node_by_name(source_node.name)
for dist_node in source_node.users:
dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id
input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
model, source_node, dist_node
model, source_node, source_nncf_node, dist_node
)

nncf_graph.add_edge_between_nncf_nodes(
source_node_id,
source_nncf_node.node_id,
dist_node_id,
tensor_shape=tensor_shape,
input_port_id=input_port_id,
Expand All @@ -131,13 +141,17 @@ def get_module_params_or_buffers():
return nncf_graph

@staticmethod
def get_edge_params(model, source_node: torch.fx.Node, dist_node: torch.fx.Node):
def get_edge_params(model, source_node: torch.fx.Node, source_nncf_node: NNCFNode, dist_node: torch.fx.Node):
# TODO: support cat
output_port_id = 0
if source_node.op in ("get_attr",):
tensor_shape = tuple(getattr(model, source_node.target).shape)
elif "val" in source_node.meta:
tensor_shape = tuple(source_node.meta["val"].shape)
if source_nncf_node.metatype is om.PTBatchNormMetatype:
tensor = source_node.meta["val"][0]
else:
tensor = source_node.meta["val"]
tensor_shape = tuple(tensor.shape)
else:
print(f"Edge shape between {source_node.name} and {dist_node.name} is unknown. Using [1,1,1,1] instead.")
tensor_shape = [1, 1, 1, 1]
Expand Down
37 changes: 32 additions & 5 deletions nncf/experimental/torch_fx/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,32 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer
from nncf.common.tensor_statistics.aggregator import StatisticsAggregator
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.torch_fx.model_transformer import PTInsertionCommand
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.return_types import maybe_get_values_from_torch_return_type
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.tensor_statistics.algo import create_register_input_hook


class TensorCollectorModule(torch.nn.Module):
"""
torch.nn.Module which calls given collector in forward
"""

def __init__(self, collector: TensorCollector):
super().__init__()
self._collector = collector

def forward(self, x: torch.Tensor):
"""
Register inputs hook function.
:parameter x: tensor to register in hook.
:return: tensor to register in hook.
"""
x_unwrapped = maybe_get_values_from_torch_return_type(x)
self._collector.register_input_for_all_reducers(PTNNCFTensor(x_unwrapped))
return x


class FXStatisticsAggregator(StatisticsAggregator):
Expand All @@ -32,7 +54,12 @@ class FXStatisticsAggregator(StatisticsAggregator):
def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
with torch.no_grad():
super().collect_statistics(model, graph)
model.nncf.remove_hooks_group(self.HOOKS_GROUP_NAME)
# All statistics are collected as a dead code,
# so eliminate dead core removed statistcs collector
# from the target model. No additional code required
# for that, horay!
model.graph.eliminate_dead_code()
model.recompile()

def _register_statistics(
self, outputs: Dict[str, PTNNCFTensor], statistic_points: StatisticPointsContainer
Expand All @@ -50,11 +77,11 @@ def _get_transformation_layout_extra_outputs(
for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
for collector in collectors:
transformation_commands.append(
# FXInsertionCommand(
PTInsertionCommand(
_statistic_point.target_point,
create_register_input_hook(collector=collector),
TensorCollectorModule(collector),
TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
hooks_group_name=self.HOOKS_GROUP_NAME,
)
)

Expand Down
22 changes: 9 additions & 13 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from nncf.common.graph.definitions import NNCFGraphNodeType
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import WeightedLayerAttributes
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
Expand Down Expand Up @@ -51,7 +50,7 @@
from nncf.torch.quantization.layers import get_scale_shape
from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP
from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
from nncf.torch.tensor_statistics.statistics import FXMinMaxTensorStatistic
from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic


class FXMinMaxAlgoBackend(MinMaxAlgoBackend):
Expand Down Expand Up @@ -140,14 +139,14 @@ def create_convert_insertion_command(
raise nncf.InternalError("FakeConvert insertion not implemented in PyTorch backend!")

@staticmethod
def unify_statistics(statistics: List[FXMinMaxTensorStatistic]) -> FXMinMaxTensorStatistic:
def unify_statistics(statistics: List[PTMinMaxTensorStatistic]) -> PTMinMaxTensorStatistic:
max_values, min_values = [], []
for statistic in statistics:
max_values.append(statistic.max_values.flatten())
min_values.append(statistic.min_values.flatten())
max_values = torch.amax(torch.stack(max_values), dim=0)
min_values = torch.amin(torch.stack(min_values), dim=0)
return FXMinMaxTensorStatistic(min_values=min_values, max_values=max_values)
return PTMinMaxTensorStatistic(min_values=min_values, max_values=max_values)

@staticmethod
def get_target_point_shape(nncf_graph: NNCFGraph, node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int, ...]:
Expand All @@ -168,10 +167,10 @@ def get_statistic_collector(
inplace: bool,
num_samples: Optional[int] = None,
) -> TensorCollector:
collector = TensorCollector(FXMinMaxTensorStatistic)
collector = TensorCollector(PTMinMaxTensorStatistic)
for params, container_key in zip(
[range_estimator_params.min, range_estimator_params.max],
[FXMinMaxTensorStatistic.MIN_STAT, FXMinMaxTensorStatistic.MAX_STAT],
[PTMinMaxTensorStatistic.MIN_STAT, PTMinMaxTensorStatistic.MAX_STAT],
):
if params.statistics_type not in PT_REDUCERS_MAP:
raise nncf.InternalError(
Expand All @@ -186,7 +185,7 @@ def get_statistic_collector(
statistic_type = params.statistics_type
if statistic_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]:
# TODO(dlyakhov): merge two quantile aggregators in one
if container_key == FXMinMaxTensorStatistic.MIN_STAT:
if container_key == PTMinMaxTensorStatistic.MIN_STAT:
quantile = params.quantile_outlier_prob
else:
quantile = 1 - params.quantile_outlier_prob
Expand Down Expand Up @@ -231,15 +230,12 @@ def _get_input_scale_shape(
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
is_weights = target_point.is_weight_target_point()
if is_weights:
module_node = nncf_graph.get_node_by_name(target_point.target_node_name)
layer_attributes = module_node.layer_attributes
assert isinstance(layer_attributes, WeightedLayerAttributes)
input_shape = layer_attributes.get_weight_shape()
channel_idx = layer_attributes.get_target_dim_for_compression()
# TODO: support transpose conv/ make channel_idx common
channel_idx = 0
else:
input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
channel_idx = 1 # channel dim for activations

input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
scale_shape = tuple(
get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
)
Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@

from nncf.torch.extensions import force_build_cpu_extensions, force_build_cuda_extensions

patch_torch_operators()
# patch_torch_operators()
10 changes: 8 additions & 2 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,13 +693,19 @@ class PTThresholdMetatype(PTOperatorMetatype):
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleBatchNormMetatype(PTModuleOperatorSubtype):
name = "BatchNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]}
module_to_function_names = {
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"],
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"],
}


@PT_OPERATOR_METATYPES.register()
class PTBatchNormMetatype(PTOperatorMetatype):
name = "BatchNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]}
module_to_function_names = {
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"],
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"],
}
subtypes = [PTModuleBatchNormMetatype]
weight_port_ids = [3]
bias_port_id = 4
Expand Down
Loading

0 comments on commit 4b842b3

Please sign in to comment.