Skip to content

Commit

Permalink
Init Torch.fx BiasCorrection
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jun 4, 2024
1 parent 3b1e7f0 commit a40c281
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 23 deletions.
44 changes: 39 additions & 5 deletions nncf/experimental/torch_fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

import torch
import torch.fx

# from torch import Tensor
# from torch import nn
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
Expand All @@ -28,6 +25,7 @@
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_manager import PassManager
from torch.fx.passes.split_utils import split_by_tags

from nncf.common.graph.model_transformer import ModelTransformer

Expand All @@ -37,6 +35,10 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType

# from torch import Tensor
# from torch import nn
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint

# from nncf.torch.graph.transformations.commands import PTTargetPoint
Expand Down Expand Up @@ -80,13 +82,16 @@ class FXModelTransformer(ModelTransformer):
Applies transformations upon Torch FX model.
"""

# TODO: manage priorities of transformations

def __init__(self, model: torch.fx.GraphModule):
super().__init__(model)

self._command_transformation_ordered_pairs = [
# TODO: Move the module insertion command to a transformation
(FXApplyTransformationCommand, self._apply_transformation),
(FXModuleInsertionCommand, self._apply_module_insertion),
(PTModelExtractionCommand, self._apply_model_extraction),
]

def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
Expand All @@ -107,6 +112,34 @@ def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.G
model.recompile()
return model

@staticmethod
def _apply_model_extraction(
    model: torch.fx.GraphModule,
    transformations: List[PTModelExtractionCommand],
) -> torch.fx.GraphModule:
    """
    Extracts a single node, together with the producers of its second input, into a separate module.

    Only the last command of the given list is applied. Every node of the model graph gets a `tag`
    attribute ("before", "extracted" or "after"), the model is split by these tags with
    `split_by_tags` and the "extracted" part is returned.

    :param model: Model to extract the sub-model from. The `tag` attribute of its graph nodes is overwritten.
    :param transformations: Model extraction commands. The last one is used and it must reference
        exactly one node name which is both the input and the output of the extracted model.
    :return: The `extracted` submodule of the split model.
    :raises RuntimeError: If the model graph does not contain a node with the requested name.
    """
    transformation = transformations[-1]
    assert len(transformation.input_node_names) == 1
    assert transformation.input_node_names == transformation.output_node_names
    node_name = transformation.input_node_names[0]

    tags = ["before", "extracted", "after"]
    target_found = False
    for node in model.graph.nodes:
        if node.name != node_name:
            # Nodes are visited in graph order: everything met ahead of the target
            # goes to the "before" part, everything met after it goes to "after".
            node.tag = tags[2] if target_found else tags[0]
            continue

        node.tag = tags[1]
        # Pull the whole producer chain of the second input of the target node
        # (presumably its weight branch) into the extracted part as well.
        weights = [node.all_input_nodes[1]]
        while weights:
            w_node = weights.pop()
            # Producers were already visited, so they must be tagged "before" or "extracted".
            assert w_node.tag in tags[0:2]
            w_node.tag = tags[1]
            weights.extend(w_node.all_input_nodes)
        target_found = True

    if not target_found:
        # Without this check an empty "extracted" module would be returned silently.
        raise RuntimeError(f"Node with name {node_name} is not found")

    splitted_gm = split_by_tags(model, tags)
    return splitted_gm.extracted

@staticmethod
def _apply_module_insertion(
model: torch.fx.GraphModule,
Expand Down Expand Up @@ -141,15 +174,16 @@ def _apply_module_insertion(
return model

@staticmethod
def _get_grah_node_by_name(graph, name):
def get_graph_node_by_name(graph, name):
    """
    Returns the first node of the given graph whose name equals the given name.

    :param graph: Graph to search in; its `nodes` are iterated in order.
    :param name: Name of the node to find.
    :return: The first node with the matching `name` attribute.
    :raises RuntimeError: If no node of the graph has the given name.
    """
    found = next((candidate for candidate in graph.nodes if candidate.name == name), None)
    if found is None:
        raise RuntimeError(f"Node with name {name} is not found")
    return found

@staticmethod
def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint):
target_type = target_point.target_type
target_node = FXModelTransformer._get_grah_node_by_name(graph, target_point.target_node_name)
target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name)
if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
target_node = target_node.all_input_nodes[target_point.input_port_id]
elif target_type == TargetType.OPERATOR_POST_HOOK:
Expand Down
51 changes: 34 additions & 17 deletions nncf/experimental/torch_fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.quantization.fake_quantize import FakeQuantize

from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
from nncf.torch.graph.transformations.commands import PTTargetPoint
Expand Down Expand Up @@ -46,23 +47,20 @@ def fake_quantize_insertion_transformation(model: torch.fx.GraphModule):
return fake_quantize_insertion_transformation


def _set_module_to_the_graph_module(
    model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
) -> str:
    """
    Sets given module to the given torch.fx.GraphModule with unique name.

    The name consists of one part per target point (target node name, input port id and
    target type value joined by "_"), the parts joined by ";", followed by "_" and
    `id(module_to_insert)`.

    :param model: Graph module which receives the new attribute.
    :param module_to_insert: Module to set as an attribute of the model.
    :param target_points: Target points used to build the attribute name.
    :return: Name of the attribute the module was set to.
    """
    target_points_part = ";".join(
        "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points
    )
    module_name_in_model = target_points_part + "_" + str(id(module_to_insert))
    # The generated name must not clash with an attribute the model already has.
    assert not hasattr(model, module_name_in_model)
    setattr(model, module_name_in_model, module_to_insert)
    return module_name_in_model
def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor):
    """
    Builds a transformation which replaces the second argument of the first user of the given node.

    The first user of the graph node named `node.node_name` is presumably the bias addition;
    its argument with index 1 is replaced by a new constant holding the given value.

    :param node: Node whose name identifies the graph node to look up.
    :param value: Value of the new constant.
    :return: Transformation which updates the given torch.fx.GraphModule in place.
    """

    def bias_update_transformation(model: torch.fx.GraphModule):
        fx_graph = model.graph
        node_name = node.node_name
        target = FXModelTransformer.get_graph_node_by_name(fx_graph, node_name)
        bias_node = next(iter(target.users))

        # The new constant has to be created ahead of the node which consumes it.
        with fx_graph.inserting_before(bias_node):
            shifted_bias = create_getattr_from_value(model, fx_graph, f"{node_name}_shifted_bias", value)

        updated_args = [*bias_node.args]
        updated_args[1] = shifted_bias
        bias_node.args = tuple(updated_args)
        # Drops nodes which are left without users after the argument was replaced.
        fx_graph.eliminate_dead_code()

    return bias_update_transformation


def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]):
Expand Down Expand Up @@ -150,3 +148,22 @@ def insert_one_qdq(

for user, dq_node in user_dq_nodes:
user.replace_input_with(target_node, dq_node)


def _set_module_to_the_graph_module(
    model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
) -> str:
    """
    Sets given module to the given torch.fx.GraphModule with unique name.

    The name consists of one part per target point (target node name, input port id and
    target type value joined by "_"), the parts joined by ";", followed by "_" and
    `id(module_to_insert)`.

    :param model: Graph module which receives the new attribute.
    :param module_to_insert: Module to set as an attribute of the model.
    :param target_points: Target points used to build the attribute name.
    :return: Name of the attribute the module was set to.
    """
    target_points_part = ";".join(
        "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points
    )
    module_name_in_model = target_points_part + "_" + str(id(module_to_insert))
    # The generated name must not clash with an attribute the model already has.
    assert not hasattr(model, module_name_in_model)
    setattr(model, module_name_in_model, module_to_insert)
    return module_name_in_model
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(

@property
def available_backends(self) -> List[BackendType]:
return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH]
return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX]

def _set_backend_entity(self, model: TModel) -> None:
"""
Expand All @@ -116,6 +116,12 @@ def _set_backend_entity(self, model: TModel) -> None:
from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend

self._backend_entity = PTFastBiasCorrectionAlgoBackend()
elif model_backend == BackendType.TORCH_FX:
from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import (
FXFastBiasCorrectionAlgoBackend,
)

self._backend_entity = FXFastBiasCorrectionAlgoBackend()
else:
raise nncf.UnsupportedBackendError(
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
Expand Down
116 changes: 116 additions & 0 deletions nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.fx
from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node

import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.definitions import NNCFGraphNodeType
from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.tensor import Tensor
from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
from nncf.experimental.torch_fx.transformations import bias_update_transformation_builder
from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector


class FXFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend):
    """
    Torch FX backend entity of the FastBiasCorrection algorithm.
    """

    # Layer-level target types are translated to operator hook types before a PTTargetPoint is built.
    TARGET_TYPE_TO_PT_INS_TYPE_MAP = {
        TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK,
        TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK,
    }

    @staticmethod
    def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint:
        """
        Builds a PTTargetPoint for the given node.

        :param target_type: Requested target type; it is translated via TARGET_TYPE_TO_PT_INS_TYPE_MAP
            when it is one of the keys of the map.
        :param target_node_name: Name of the target node.
        :param port_id: Input port id. It is replaced by None when the node name contains
            NNCFGraphNodeType.INPUT_NODE or when the target type is POST_LAYER_OPERATION.
        :return: Target point for the node.
        """
        if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION:
            port_id = None
        if target_type in FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP:
            target_type = FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type]
        return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)

    @staticmethod
    def create_bias_correction_command(
        node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
    ) -> FXApplyTransformationCommand:
        """
        Creates a command which applies the bias update transformation for the given node.

        :param node: Node passed to the bias update transformation builder.
        :param bias_value: New bias value; its `data` is passed to the builder.
        :param nncf_graph: Not used by this backend.
        :return: Command wrapping the bias update transformation.
        """
        return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data))

    @staticmethod
    def model_extraction_command(
        input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]]
    ) -> PTModelExtractionCommand:
        """
        Creates a model extraction command.

        Only the node name of the first input id and of the first output id are used.

        :param input_ids: Input ids; the first element of the first tuple is taken as the input node name.
        :param output_ids: Output ids; the first element of the first tuple is taken as the output node name.
        :return: Model extraction command with one input and one output node name.
        """
        return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]])

    @staticmethod
    def mean_statistic_collector(
        channel_axis: int,
        inplace: bool,
        num_samples: Optional[int] = None,
        window_size: Optional[int] = None,
    ) -> TensorCollector:
        """
        Creates a mean statistic collector.

        :param channel_axis: Channel axis passed to the collector factory.
        :param inplace: Not used by this backend.
        :param num_samples: Number of samples passed to the collector factory.
        :param window_size: Window size passed to the collector factory.
        :return: Collector created by `get_mean_statistic_collector`.
        """
        return get_mean_statistic_collector(num_samples, channel_axis, window_size)

    @staticmethod
    def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]:
        """
        Returns a pair of None values regardless of the given subgraph.

        :param subgraph: Not used by this backend.
        :return: (None, None).
        """
        # Pytorch does not have name for extracted node
        return None, None

    @staticmethod
    def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int) -> torch.Tensor:
        """
        Assembles an input tensor of the given shape from per-channel data.

        :param shape: Shape of the tensor to create.
        :param data: Values to write; element `j` is written to position `j` of the channel axis.
            Dtype and device of the result are taken from the first element.
        :param input_name: Not used by this backend.
        :param channel_axis: Axis of the result which is filled position by position.
        :return: Tensor filled with the given data.
        """
        blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device)
        for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
            # Full slices on every axis except the channel axis, which is addressed by idx.
            index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
            blob[index] = data[j].data
        return blob

    @staticmethod
    def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor:
        """
        Returns the constant connected to the second input of the node following the given node.

        :param node: Node whose first next node in the NNCFGraph is treated as the bias node.
        :param nncf_graph: Graph used to find the next node.
        :param model: Model which holds the constant.
        :return: The constant wrapped into Tensor.
        """
        # TODO: make a node_name_vs_node map to speed up the process
        from nncf.experimental.torch_fx.model_transformer import FXModelTransformer

        bias_node = nncf_graph.get_next_nodes(node)[0]
        graph_bias_node = FXModelTransformer.get_graph_node_by_name(model.graph, bias_node.node_name)
        return Tensor(_get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model))

    @staticmethod
    def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]:
        """
        Returns the pair of port ids (0, 0) regardless of the given node.

        :param node: Not used by this backend.
        :return: (0, 0).
        """
        return 0, 0

    @staticmethod
    def process_model_output(raw_data: Dict, output_name: str) -> Tensor:
        """
        Wraps the raw model output into Tensor.

        :param raw_data: Raw output of the model.
        :param output_name: Not used by this backend.
        :return: Tensor built from the raw data.
        """
        return Tensor(raw_data)

    @staticmethod
    def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
        """
        Checks whether the second previous node of the given node is a per-channel dequantize node.

        :param node: Node to check.
        :param nncf_graph: Graph used to get the previous nodes.
        :return: True if the node type of the second previous node is "dequantize_per_channel".
        """
        weight_node = nncf_graph.get_previous_nodes(node)[1]
        return weight_node.node_type == "dequantize_per_channel"

    @staticmethod
    def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
        """
        Checks whether the given node is a convolution or a linear node followed by a single add node.

        :param node: Node to check.
        :param nncf_graph: Graph used to get the next nodes.
        :return: True if the node has a conv1d/conv2d/conv3d/linear metatype and its only next node
            has the add metatype, False otherwise.
        """
        # Assumes that all biases were unfused
        if node.metatype not in (om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype):
            # Explicit False instead of the implicit None the annotated return type does not allow.
            return False
        next_nodes = nncf_graph.get_next_nodes(node)
        if len(next_nodes) != 1:
            return False
        return next_nodes[0].metatype in (om.PTAddMetatype,)

    @staticmethod
    def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]:
        """
        Returns the name of the given node for both input and output statistics.

        :param node: Node whose name is returned.
        :param nncf_graph: Not used by this backend.
        :return: Pair of the node name repeated twice.
        """
        return node.node_name, node.node_name

0 comments on commit a40c281

Please sign in to comment.