WIP FQ insertion
daniil-lyakhov committed May 27, 2024
1 parent 4b842b3 commit 2c2921a
Showing 4 changed files with 173 additions and 82 deletions.
139 changes: 68 additions & 71 deletions nncf/experimental/torch_fx/model_transformer.py
@@ -37,8 +37,6 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType
from nncf.torch.graph.transformations.commands import PTTargetPoint

# from nncf.torch.graph.transformations.commands import PTTargetPoint
@@ -53,16 +51,29 @@
# from nncf.torch.utils import is_multidevice


class FXModuleInsertionCommand(Command):
    def __init__(
        self,
        target_points: List[PTTargetPoint],
        module_to_insert: torch.nn.Module,
        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
    ):
        super().__init__(TransformationType.INSERT)
        self.target_points = target_points
        self.module_to_insert = module_to_insert
        self.priority = priority


class FXApplyTransformationCommand(Command):
    def __init__(
        self,
        target_point: PTTargetPoint,
        transformation_fn: Callable[[torch.fx.Graph, torch.fx.Node], None],
        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
    ):
        super().__init__(TransformationType.INSERT)
        self.target_point = target_point
        self.transformation_fn = transformation_fn
        self.priority = priority

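For orientation, a minimal sketch of how the two command flavors might be constructed; the target point, node name, and module below are hypothetical placeholders, and PTTargetPoint is assumed to keep its usual NNCF signature:

    # Hypothetical usage; "conv2d" is a placeholder node name.
    tp = PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name="conv2d")
    module_cmd = FXModuleInsertionCommand([tp], torch.nn.Identity())
    fn_cmd = FXApplyTransformationCommand(tp, lambda graph, node: None)  # no-op transformation fn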

@@ -75,9 +86,8 @@ def __init__(self, model: torch.fx.GraphModule):
        super().__init__(model)

        self._command_transformation_ordered_pairs = [
            (FXApplyTransformationCommand, self._apply_fn_insertion),
            (FXModuleInsertionCommand, self._apply_module_insertion),
        ]

    def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
@@ -99,87 +109,74 @@ def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
        return model

    @staticmethod
    def _apply_module_insertion(
        model: torch.fx.GraphModule,
        transformations: List[FXModuleInsertionCommand],
    ) -> torch.fx.GraphModule:
        """
        Applies FXModuleInsertionCommand transformations. For each command, the method
        attaches the command's torch module to the torch.fx.GraphModule and inserts
        a call_module node for every target point of the command.

        :param model: Model to apply transformations to.
        :param transformations: List of module insertion commands.
        :return: A modified torch.fx.GraphModule.
        """
        for transformation in transformations:
            # Attach the module to the model as an attribute.
            module_to_insert = transformation.module_to_insert
            module_name_in_model = (
                ";".join(
                    "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value)))
                    for tp in transformation.target_points
                )
                + "_"
                + str(id(module_to_insert))
            )
            assert not hasattr(model, module_name_in_model)
            setattr(model, module_name_in_model, module_to_insert)
            # Insert call_module nodes into the graph.
            for target_point in transformation.target_points:
                FXModelTransformer._create_call_module_node(model.graph, target_point, module_name_in_model)
        return model

    @staticmethod
    def _get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node:
        for node in graph.nodes:
            if node.name == name:
                return node
        raise RuntimeError(f"Node with name {name} was not found in the graph.")

    @staticmethod
    def _get_target_node_and_ctx(graph: torch.fx.Graph, target_point: PTTargetPoint):
        target_type = target_point.target_type
        target_node = FXModelTransformer._get_graph_node_by_name(graph, target_point.target_node_name)
        if target_type == TargetType.OPERATOR_PRE_HOOK:
            ctx = graph.inserting_before(target_node)
        elif target_type == TargetType.OPERATOR_POST_HOOK:
            ctx = graph.inserting_after(target_node)
        elif target_type == TargetType.OPERATION_WITH_WEIGHTS:
            target_node = target_node.all_input_nodes[target_point.input_port_id]
            ctx = graph.inserting_after(target_node)
        else:
            raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}")
        return target_node, ctx

    @staticmethod
    def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str):
        target_node, ctx = FXModelTransformer._get_target_node_and_ctx(graph, target_point)
        with ctx:
            graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node")
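
As background, a torch.fx insertion context only fixes where new nodes land in the graph's code order; dataflow still has to be rewired through node arguments, as the quantizer transformation in node_creation.py does with replace_input_with. A standalone sketch, independent of this commit:

    import operator

    import torch
    import torch.fx


    def f(x):
        return torch.relu(x) + 1


    gm = torch.fx.symbolic_trace(f)
    graph = gm.graph
    relu_node = next(n for n in graph.nodes if n.target == torch.relu)
    add_node = next(n for n in graph.nodes if n.target == operator.add)
    # Create a sigmoid right after relu in code order...
    with graph.inserting_after(relu_node):
        sig = graph.call_function(torch.sigmoid, (relu_node,), {})
    # ...then rewire the consumer explicitly; the context alone does not do this.
    add_node.replace_input_with(relu_node, sig)
    graph.lint()
    gm.recompile()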

    @staticmethod
    def _apply_fn_insertion(
        model: torch.fx.GraphModule,
        transformations: List[FXApplyTransformationCommand],
    ) -> torch.fx.GraphModule:
        """
        Applies FXApplyTransformationCommand transformations. For each command, the method
        resolves the target node from the command's target point and calls the command's
        transformation function with the model graph and the target node.

        :param model: Model to apply transformations to.
        :param transformations: List of transformation commands.
        :return: A modified torch.fx.GraphModule.
        """
        graph = model.graph
        for transformation in transformations:
            target_node, _ = FXModelTransformer._get_target_node_and_ctx(graph, transformation.target_point)
            transformation.transformation_fn(graph, target_node)
        return model
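
Putting the pieces together, a hedged driver sketch; PTTransformationLayout.register is assumed to follow the usual NNCF layout API, and gm stands for an already-captured torch.fx.GraphModule:

    # Hypothetical end-to-end flow.
    layout = PTTransformationLayout()
    tp = PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name="conv2d")
    layout.register(FXModuleInsertionCommand([tp], torch.nn.Identity()))
    transformed = FXModelTransformer(gm).transform(layout)
    transformed.recompile()  # regenerate forward() from the mutated graph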


90 changes: 90 additions & 0 deletions nncf/experimental/torch_fx/node_creation.py
@@ -0,0 +1,90 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import torch
import torch.fx
from torch.ao.quantization.fx.utils import create_getattr_from_value

from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.torch.quantization.layers import PTQuantizerSpec


def quantizer_insertion_transformation_builder(
    qspec: PTQuantizerSpec, fq_params: FakeQuantizeParameters, axis: int, eps: float = 1e-5
):
    # signed = bool(torch.any(fq_params.input_low.data < 0))
    # Subtract eps from the scale to make the quantizer parameters equal to
    # the original parameters on the forward call.
    scale = (fq_params.input_high.data - eps).reshape(qspec.scale_shape)

    def quantizer_insertion_transformation(model: torch.fx.GraphModule, node: torch.fx.Node):
        # 1. Extract the information needed to insert the q/dq nodes.
        node_type = "call_function"
        quantize_op: Optional[Callable] = None
        # scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if qspec.per_channel:
            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
        else:
            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
        # TODO: map FakeQuantizeParameters to qparams for quantize/dequantize
        qparams = {
            "_scale_": scale,
            "_zero_point_": 0,
            "_axis_": axis,
            "_quant_min_": 0,
            "_quant_max_": 2**qspec.num_bits - 1,
            "_dtype_": torch.int8,
        }
        # 2. Replace the activation_post_process node with quantize and dequantize nodes.
        graph = model.graph
        # TODO: use metatype to get the correct input_port_id
        # Do not quantize already quantized nodes.
        # inserting_before only affects the order in the generated graph code,
        # so quantize-dequantize and all constant nodes are inserted before the usage of the nodes.
        with graph.inserting_before(node):
            quantize_op_inputs = [node]
            for key, value_or_node in qparams.items():
                # TODO: we could record in the qparams dict itself whether a value needs to
                # be registered as an attribute
                if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))):
                    # Scale and zero_point values are registered as buffers in the root module.
                    # Note that when the values are not tensors, as in the case of
                    # per_tensor quantization, they will be treated as literals.
                    # Registering them as nodes seems to cause an issue with dynamo
                    # tracing, where it may pick the tensor overload as opposed to the default one.
                    # The extra check that scale and zero_point are scalars makes
                    # sure that the default overload can be used.
                    # TODO: maybe a more complex attr name is needed here
                    qparam_node = create_getattr_from_value(model, graph, node.name + key, value_or_node)
                    quantize_op_inputs.append(qparam_node)
                else:
                    # qparams that are not scale/zero_point (like axis, dtype) are stored
                    # as literals in the graph.
                    quantize_op_inputs.append(value_or_node)
        with graph.inserting_after(node):
            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
        # Use the same qparams as the quantize op.
        dq_inputs = [quantized_node] + quantize_op_inputs[1:]
        user_dq_nodes = []
        with graph.inserting_after(quantized_node):
            for user in node.users:
                if user is quantized_node:
                    continue
                user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))

        for user, dq_node in user_dq_nodes:
            user.replace_input_with(node, dq_node)

    return quantizer_insertion_transformation
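
A hedged sketch of wiring this builder to FXApplyTransformationCommand from model_transformer.py; qspec, fq_params, and target_point are placeholders assumed to be built elsewhere:

    # Hypothetical wiring.
    transformation_fn = quantizer_insertion_transformation_builder(qspec, fq_params, axis=0)
    command = FXApplyTransformationCommand(target_point, transformation_fn)
    # FXModelTransformer._apply_fn_insertion later resolves the target node and calls:
    #     transformation_fn(model.graph, target_node)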
14 changes: 10 additions & 4 deletions nncf/experimental/torch_fx/statistics/aggregator.py
@@ -21,7 +21,7 @@
from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer
from nncf.common.tensor_statistics.aggregator import StatisticsAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.torch_fx.model_transformer import FXModuleInsertionCommand
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.return_types import maybe_get_values_from_torch_return_type
from nncf.torch.tensor import PTNNCFTensor
@@ -48,6 +48,13 @@ def forward(self, x: torch.Tensor):
        return x


def get_statistic_fn_builder(collector: TensorCollector):
    def fn(*args, **kwargs):
        return TensorCollectorModule(collector)

    return fn
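
The builder just closes over the collector and ignores its call arguments; a quick sketch, with my_collector standing in for a real TensorCollector:

    builder = get_statistic_fn_builder(my_collector)
    module = builder("any", args="ignored")  # returns TensorCollectorModule(my_collector)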


class FXStatisticsAggregator(StatisticsAggregator):
    HOOKS_GROUP_NAME = "statistics_hooks"

@@ -77,9 +84,8 @@ def _get_transformation_layout_extra_outputs(
            for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
                for collector in collectors:
                    transformation_commands.append(
                        FXModuleInsertionCommand(
                            [_statistic_point.target_point],
                            TensorCollectorModule(collector),
                            TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
                        )
12 changes: 5 additions & 7 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Set, Tuple

import torch

@@ -26,6 +26,7 @@
from nncf.common.quantization.structs import QuantizerConfig
from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import StatisticsType
@@ -36,9 +37,6 @@
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.nncf_network import NNCFNetwork
@@ -287,15 +285,15 @@ def create_quantizer_insertion_command(
        target_point: PTTargetPoint,
        quantizer_config: QuantizerConfig,
        parameters: FakeQuantizeParameters,
    ) -> FXApplyTransformationCommand:
        _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape(
            nncf_graph, target_point, quantizer_config.per_channel
        )

        quantizer = FXMinMaxAlgoBackend._create_quantizer(
            quantizer_config, scale_shape, parameters, target_point.target_type
        )
        return FXApplyTransformationCommand(target_point, quantizer)

    @staticmethod
    def create_unified_scales_quantizers_insertion_commands(
@@ -311,7 +309,7 @@ def create_unified_scales_quantizers_insertion_commands(
        quantizer = FXMinMaxAlgoBackend._create_quantizer(
            quantizer_config, scale_shape, parameters, target_points[0].target_type
        )
        return [FXApplyTransformationCommand(tp, quantizer) for tp in target_points]

    @staticmethod
    def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]: