Skip to content

Commit

Permalink
BN fused, conv and bias separated
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed May 29, 2024
1 parent 2c2921a commit 857a255
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 25 deletions.
12 changes: 4 additions & 8 deletions nncf/experimental/torch_fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,10 @@ def __init__(
class FXApplyTransformationCommand(Command):
def __init__(
self,
target_point: PTTargetPoint,
transformation_fn: Callable[[torch.fx.Graph, torch.fx.Node], None],
transformation_fn: Callable[[torch.fx.GraphModule], None],
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
):
super().__init__(TransformationType.INSERT)
self.target_point = target_point
self.tranformation_fn = transformation_fn
self.priority = priority

Expand All @@ -86,7 +84,7 @@ def __init__(self, model: torch.fx.GraphModule):
super().__init__(model)

self._command_transformation_ordered_pairs = [
(FXApplyTransformationCommand, self._apply_fn_insertion),
(FXApplyTransformationCommand, self._apply_transformation),
(FXModuleInsertionCommand, self._apply_module_insertion),
]

Expand Down Expand Up @@ -169,14 +167,12 @@ def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint,
graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node")

@staticmethod
def _apply_fn_insertion(
def _apply_transformation(
model: torch.fx.GraphModule,
transformations: List[FXApplyTransformationCommand],
) -> torch.fx.GraphModule:
graph = model.graph
for transformation in transformations:
target_node, _ = FXModelTransformer._get_target_node_and_ctx(graph, transformation.target_point)
transformation.tranformation_fn(graph, target_node)
transformation.tranformation_fn(model)
return model


Expand Down
59 changes: 55 additions & 4 deletions nncf/experimental/torch_fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from typing import Tuple

import torch.fx
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ # noqa
from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node
from torch.ao.quantization.pt2e.utils import _is_conv # noqa

import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph import NNCFGraph
Expand Down Expand Up @@ -57,6 +60,8 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe
if hasattr(node.target, "overloadpacket"):
torch.nn.BatchNorm2d
node_type = str(node.target.overloadpacket).split(".")[1]
elif node.target.__name__ == "getitem":
node_type = "__getitem__"
else:
# TODO: get correct nodes types from this nodes as well
node_type = str(node.target)
Expand All @@ -71,6 +76,52 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe
node_metatype = UnknownMetatype
return node_type, node_metatype

@staticmethod
def _separate_conv_and_bias(model: torch.fx.GraphModule):
"""
Separates one joined conv+bias node to two nodes: conv and bias.
Needed as nncf does not expect joined conv
"""
add_node_target = torch.ops.aten.add_.Tensor
for n in model.graph.nodes:
if not _is_conv(n):
continue
if len(n.args) < 3 or n.args[2] is None:
continue
conv_node = n
dims = len(_get_tensor_constant_from_node(conv_node.args[1], model).shape)
conv_bias_node = conv_node.args[2]
conv_bias_value = _get_tensor_constant_from_node(conv_bias_node, model)
args = list(n.args)
args[2] = None
conv_node.args = tuple(args)
with model.graph.inserting_after(conv_node):
new_conv_bias_node = create_getattr_from_value(
model,
model.graph,
conv_bias_node.name + "_",
conv_bias_value.reshape(
(
1,
-1,
)
+ (1,) * (dims - 2)
),
)
with model.graph.inserting_after(new_conv_bias_node):
add_node = model.graph.create_node(
"call_function", add_node_target, (conv_node, new_conv_bias_node), {}
)
for user in list(conv_node.users):
if user is add_node:
continue
user.replace_input_with(conv_node, add_node)

if "val" in conv_node.meta:
add_node.meta["val"] = conv_node.meta["val"]
model.graph.eliminate_dead_code()
model.recompile()

@staticmethod
def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
"""
Expand All @@ -82,10 +133,10 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
:return: NNCFGraph.
"""

# _fuse_conv_bn_(model)

# model.graph.eliminate_dead_code()
# model.recompile()
_fuse_conv_bn_(model)
# BN fuses to conv bias, conv+bias joined op
# needs to be splited for nncf
GraphConverter._separate_conv_and_bias(model)

nncf_graph = PTNNCFGraph()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,70 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional
from typing import Callable, List, Optional

import torch
import torch.fx
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.quantization.fake_quantize import FakeQuantize

from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.quantization.layers import PTQuantizerSpec


def quantizer_insertion_tranformation_builder(
qspec: PTQuantizerSpec, fq_params: FakeQuantizeParameters, axis: int, eps=1e-5
):
def stat_collectorts_insertion_tranformation_builder():
def stat_collectorts_insertion_tranformation(model: torch.fx.GraphModule, node: torch.fx.Node):
pass

return stat_collectorts_insertion_tranformation


def fake_quantize_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]):
def fake_quantize_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points)
graph = model.graph
for target_point in target_points:
target_node, ctx = FXModelTransformer._get_target_node_and_ctx(model.graph, target_point)
with ctx:
fq_node = graph.create_node(
"call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer"
)
for user in list(target_node.users):
if user is fq_node:
continue
user.replace_input_with(target_node, fq_node)

return fake_quantize_insertion_transformation


def _set_module_to_the_graph_module(
model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
) -> str:
"""
Sets given module to the given torch.fx.GraphModule with unique name.
"""
module_to_insert = module_to_insert
module_name_in_model = (
";".join(
"_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points
)
+ "_"
+ str(id(module_to_insert))
)
assert not hasattr(model, module_name_in_model)
setattr(model, module_name_in_model, module_to_insert)
return module_name_in_model


def qdq_insertion_tranformation_builder(qspec: PTQuantizerSpec, fq_params: FakeQuantizeParameters, axis: int, eps=1e-5):
# signed = bool(torch.any(fq_params.input_low.data < 0))
# Subtract eps from the scale to make quantizer parameters equal to
# original parameters on the forward call.
scale = (fq_params.input_high.data - eps).reshape(qspec.scale_shape)

def quantizer_insertion_tranformation(model: torch.fx.GraphModule, node: torch.fx.Node):
def qdq_insertion_tranformation(model: torch.fx.GraphModule, node: torch.fx.Node):
# 1. extract information for inserting q/dq node from activation_post_process
node_type = "call_function"
quantize_op: Optional[Callable] = None
Expand Down Expand Up @@ -87,4 +132,4 @@ def quantizer_insertion_tranformation(model: torch.fx.GraphModule, node: torch.f
for user, dq_node in user_dq_nodes:
user.replace_input_with(node, dq_node)

return quantizer_insertion_tranformation
return qdq_insertion_tranformation
21 changes: 15 additions & 6 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
from nncf.experimental.torch_fx.transformations import fake_quantize_insertion_tranformation_builder
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import StatisticsType
Expand All @@ -46,6 +47,7 @@
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import PTQuantizerSpec
from nncf.torch.quantization.layers import get_scale_shape
from nncf.torch.quantization.strip import convert_to_torch_fakequantizer
from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP
from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic
Expand Down Expand Up @@ -263,7 +265,9 @@ def _create_quantizer(

# Fill it with minmax
FXMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape)
return quantizer
# Convert to the torch fake quantizer
torch_fq = convert_to_torch_fakequantizer(quantizer)
return torch_fq

@staticmethod
def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None:
Expand Down Expand Up @@ -293,7 +297,8 @@ def create_quantizer_insertion_command(
quantizer = FXMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_point.target_type
)
return FXApplyTransformationCommand([target_point], quantizer)
transformation = fake_quantize_insertion_tranformation_builder(quantizer, [target_point])
return FXApplyTransformationCommand(transformation)

@staticmethod
def create_unified_scales_quantizers_insertion_commands(
Expand All @@ -309,7 +314,9 @@ def create_unified_scales_quantizers_insertion_commands(
quantizer = FXMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_points[0].target_type
)
return [FXApplyTransformationCommand(tp, quantizer) for tp in target_points]

transformation = fake_quantize_insertion_tranformation_builder(quantizer, target_points)
return [FXApplyTransformationCommand(transformation)]

@staticmethod
def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
Expand Down Expand Up @@ -347,7 +354,9 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> List[str]:
def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
retval = set()
for node in nncf_graph.get_all_nodes():
if node.metatype is om.PTConstNoopMetatype:
for node in nncf_graph.get_next_nodes(node):
retval.add(node)
if node.metatype in [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype]:
retval.add(node)
# if node.metatype is om.PTConstNoopMetatype:
# for node in nncf_graph.get_next_nodes(node):
# retval.add(node)
return list(retval)
5 changes: 5 additions & 0 deletions nncf/torch/graph/pattern_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
GraphPattern.LABEL_ATTR: "BATCH_NORMALIZATION",
}

GETITEM_OPERATIONS = {
GraphPattern.METATYPE_ATTR: ["index_select", "__getitem__", "gather", "index_select", "where"],
GraphPattern.LABEL_ATTR: "GETITEM",
}

GROUP_NORMALIZATION_OPERATIONS = {
GraphPattern.METATYPE_ATTR: ["group_norm"],
GraphPattern.LABEL_ATTR: "GROUP_NORMALIZATION",
Expand Down
8 changes: 7 additions & 1 deletion nncf/torch/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from nncf.torch.graph.pattern_operations import ARITHMETIC_OPERATIONS
from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS
from nncf.torch.graph.pattern_operations import BATCH_NORMALIZATION_OPERATIONS
from nncf.torch.graph.pattern_operations import GETITEM_OPERATIONS
from nncf.torch.graph.pattern_operations import GROUP_NORMALIZATION_OPERATIONS
from nncf.torch.graph.pattern_operations import LINEAR_OPERATIONS
from nncf.torch.graph.pattern_operations import RELU_OPERATIONS
Expand Down Expand Up @@ -199,7 +200,12 @@ def arithmetic_operations() -> GraphPattern:
def batch_norm_operations() -> GraphPattern:
pattern = GraphPattern()
pattern.add_node(**BATCH_NORMALIZATION_OPERATIONS)
return pattern
pattern_alt = GraphPattern()
bn = pattern_alt.add_node(**BATCH_NORMALIZATION_OPERATIONS)
get_item = pattern_alt.add_node(**GETITEM_OPERATIONS)
pattern_alt.add_edge(bn, get_item)
pattern.add_pattern_alternative(pattern_alt)
return pattern_alt


def activation_operations() -> GraphPattern:
Expand Down
2 changes: 2 additions & 0 deletions torch_compile_ex_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def quantize(model, example_inputs):

calibration_dataset = nncf.Dataset(example_inputs)
quantized_model = nncf.quantize(exported_model, calibration_dataset)
g = FxGraphDrawer(quantized_model, "resnet18_quantized_native_nncf")
g.get_dot_graph().write_svg("resnet18_quantized_native_nncf.svg")
return quantized_model

else:
Expand Down

0 comments on commit 857a255

Please sign in to comment.