Skip to content

Commit

Permalink
Smooth quant algorithm implementation
Browse files Browse the repository at this point in the history
Swin transformer conformance test

FXSQMultiply

Refereces update
  • Loading branch information
daniil-lyakhov committed Aug 13, 2024
1 parent 55e653b commit 10a8997
Show file tree
Hide file tree
Showing 16 changed files with 9,105 additions and 8,638 deletions.
4 changes: 0 additions & 4 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,6 @@ def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
"""

def __call__(self, x: List[Tensor]):
# try:
# any(t.isempty() for t in x)
# except:
# breakpoint()
if any(t.isempty() for t in x):
return None

Expand Down
26 changes: 5 additions & 21 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
output_port_id=output_port_id,
dtype=Dtype.FLOAT,
)
if source_nncf_node.metatype in [om.PTMinMetatype, om.PTMaxMetatype]:
offset = len(source_node.users)
output_tensors = source_node.meta.get("val", [])
output_tensors = (output_tensors,) if isinstance(output_tensors, torch.Tensor) else output_tensors

for idx, tensor in enumerate(output_tensors[offset:]):
curr_idx = offset + idx
result = nncf_graph.add_nncf_node(
f"{source_nncf_node.node_name}_output_{curr_idx}", "output", om.PTOutputNoopMetatype
)
nncf_graph.add_edge_between_nncf_nodes(
source_nncf_node.node_id,
result.node_id,
tensor_shape=tuple(tensor.shape),
input_port_id=0,
output_port_id=curr_idx,
dtype=Dtype.FLOAT,
)

return nncf_graph

@staticmethod
Expand All @@ -139,6 +120,7 @@ def get_edge_params(
edge tensor shape.
"""
output_port_id = 0
tensor_shape = None
if source_node.op in ("get_attr",):
tensor_shape = tuple(getattr(model, source_node.target).shape)
elif "val" in source_node.meta:
Expand All @@ -150,8 +132,10 @@ def get_edge_params(
output_port_id = output_idx
else:
tensor = source_node.meta["val"]
tensor_shape = tuple(tensor.shape)
else:
if isinstance(tensor, torch.Tensor):
tensor_shape = tuple(tensor.shape)

if tensor_shape is None:
# TODO(dlyakhov): Refactor algorithms to always have knowns edges shapes.
nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.")
tensor_shape = None
Expand Down
9 changes: 8 additions & 1 deletion nncf/experimental/torch/fx/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,15 @@ def _get_transformation_layout_extra_outputs(
for _statistic_point in _statistic_points:
for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
for collector in collectors:
tp = _statistic_point.target_point
module_to_insert = TensorCollectorModule(collector)
target_module_name = (
"_".join([tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value)])
+ "_"
+ str(id(module_to_insert))
)
transformation = leaf_module_insertion_transformation_builder(
TensorCollectorModule(collector), [_statistic_point.target_point]
module_to_insert, [tp], target_module_name
)
transformation_commands.append(
FXApplyTransformationCommand(
Expand Down
102 changes: 56 additions & 46 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,27 @@
TransformationFNType = Callable[[torch.fx.GraphModule], None]


def _get_node_tesnor_output(node: torch.fx.Node) -> List[torch.Tensor]:
val = node.meta["val"]
def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module):
"""
Sets correct meta \"val\" value to the new node.
:param new_node: The new node.
:param prev_node: Input node of the new node.
New node expected to have only one input node.
:param target_module: Module which is being called by the new node.
"""
val = prev_node.meta["val"]
val = val if isinstance(val, tuple) else (val,)
retval = []
for t in val:
retval.append(torch.ones(t.shape))
return retval


def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module):
with torch.no_grad():
new_node.meta["val"] = target_module(*_get_node_tesnor_output(prev_node))
new_node.meta["val"] = target_module(*val)


def module_insertion_transformation_builder(
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint], target_module_name: str
) -> TransformationFNType:
"""
Returns transformation which inserts given module to a target model
Expand All @@ -50,16 +55,17 @@ def module_insertion_transformation_builder(
:param module_to_insert: Given torch.nn.Module to insert.
:param target_points: Target points to insert the target module.
:param target_module_name: Target model attribute name for the module_to_insert.
:returns: Transformation which which inserts given module to a target model
and calls given module after each target points.
"""

def module_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points)
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_module_name)
# Insert call_module nodes to the model
graph = model.graph
for target_point in target_points:
new_node = _insert_call_module(graph, target_point, module_attr_name)
for idx, target_point in enumerate(target_points):
new_node = _insert_call_module(graph, target_point, module_attr_name, f"{module_attr_name}_{idx}")
target_node = get_graph_node_by_name(graph, target_point.target_node_name)

if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
Expand All @@ -79,35 +85,36 @@ def module_insertion_transformation(model: torch.fx.GraphModule):


def leaf_module_insertion_transformation_builder(
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint], target_module_name: str
) -> TransformationFNType:
"""
Returns transformation which inserts given module to a target model
and calls given module after each target points.
:param module_to_insert: Given torch.nn.Module to insert.
:param target_points: Target points to insert the target module.
:param target_module_name: Target model attribute name for the module_to_insert.
:returns: Transformation which which inserts given module to a target model
and calls given module after each target points.
"""

def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points)
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_module_name)
# Insert call_module nodes to the model
graph = model.graph
for target_point in target_points:
_insert_call_module(graph, target_point, module_attr_name)
for idx, target_point in enumerate(target_points):
_insert_call_module(graph, target_point, module_attr_name, f"{module_attr_name}_{idx}")

return leaf_module_insertion_transformation


def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
"""
Return transformation which updates constant of the given bias node to the given value.
Return transformation which updates constant of the given node with bias to the given value.
:param node: Bias node which requires bias constant update.
:param node: Node with bias which requires bias constant update.
:param value: New value to use as the bias constant.
:return: Transformation which updates constant of the given bias node to the given value.
:return: Transformation which updates constant of the given node with bias to the given value.
"""

def bias_update_transformation(model: torch.fx.GraphModule):
Expand All @@ -124,28 +131,42 @@ def bias_update_transformation(model: torch.fx.GraphModule):


def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
"""
Return transformation which updates constant of the given node to the given value.
:param node: Node which requires bias constant update.
:param value: New value to use as the node constant.
:return: Transformation which updates constant of the given node to the given value.
"""

def constant_update_transformation(model: torch.fx.GraphModule):
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=1)

return constant_update_transformation


def constant_update_fn(
model: torch.fx.GraphModule, graph_node: torch.fx.Node, value: torch.Tensor, input_port_id: int = 1
):
def constant_update_fn(model: torch.fx.GraphModule, node: torch.fx.Node, value: torch.Tensor, input_port_id: int = 1):
"""
Updates constant of given node on the given input port id with given value.
:param model: Target torch GraphModule.
:param node: Given graph node.
:param value: New value to use as the node constant.
:param input_port_id: Target constant input port id.
"""
graph = model.graph
with graph.inserting_before(graph_node):
new_constant = create_getattr_from_value(model, graph, graph_node.name + "_updated_constant", value)
with graph.inserting_before(node):
new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value)

args = list(graph_node.args)
args = list(node.args)
# A bias node suppose to have constant on the second input port.
if args[input_port_id].op != "get_attr":
raise nncf.InternalError(
f"Constant on input port {input_port_id} for {graph_node} is expected,"
f"Constant on input port {input_port_id} for {node} is expected,"
f" but node {args[input_port_id]} is present."
)
args[input_port_id] = new_constant
graph_node.args = tuple(args)
node.args = tuple(args)
graph.eliminate_dead_code()


Expand Down Expand Up @@ -270,26 +291,23 @@ def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, qua
raise nncf.InternalError(f"Unexpected target type: {target_point.target_type}")


def _insert_call_module(graph: torch.fx.Graph, target_point: PTTargetPoint, module_attr_name: str):
def _insert_call_module(
graph: torch.fx.Graph, target_point: PTTargetPoint, module_attr_name: str, graph_node_name: str
):
"""
Inserts module call node to the graph after the target node.
:param graph: Graph to insert module call node.
:param target_node: Target node, module call node is being iserted just after the target node.
:param module_attr_name: The name of the graph attribute which keeps the target module.
:return: Target node used
:param graph_node_name: Target name for module call node.
:return: Inserted module call node.
"""
target_node = get_graph_node_by_name(graph, target_point.target_node_name)
input_node = get_input_node(target_point, target_node)
ctx_manager = get_ctx_manager(graph, target_point)
with ctx_manager(target_node):
return graph.create_node(
"call_module",
module_attr_name,
(input_node,),
{},
name=f"{module_attr_name}_{str(target_point.target_type)}_graph_node",
)
return graph.create_node("call_module", module_attr_name, (input_node,), {}, name=graph_node_name)


def get_input_node(target_point: PTTargetPoint, target_node: torch.fx.Node) -> torch.fx.Node:
Expand Down Expand Up @@ -335,26 +353,18 @@ def get_ctx_manager(graph: torch.fx.Graph, target_point: PTTargetPoint) -> Calla


def _set_module_to_the_graph_module(
model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
model: torch.fx.GraphModule,
module_to_insert: torch.nn.Module,
module_name_in_model: str,
) -> str:
"""
Sets given module to the given torch.fx.GraphModule with unique name.
:param graph: Target torch.fx.Graph.
:param module_to_insert: Module to insert to the target graph.
:param target_points: Target points which will be used to insert target module
to the graph.
:param module_name_in_model: Target model attribute name for the module_to_insert.
:return: A graph module attribute name which keep given module.
"""
module_to_insert = module_to_insert
# TODO(dlyakhov) Make module name human readable.
module_name_in_model = (
"__".join(
"_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points
)
+ "_"
+ str(id(module_to_insert))
)
assert not hasattr(model, module_name_in_model)
setattr(model, module_name_in_model, module_to_insert)
return module_name_in_model
Expand Down
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder
from nncf.experimental.torch.fx.transformations import qdq_insertion_tranformation_builder
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import StatisticsType
Expand Down Expand Up @@ -288,7 +288,7 @@ def create_quantizer_insertion_command(
quantizer = FXMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_point.target_type
)
transformation = qdq_insertion_transformation_builder(quantizer, [target_point])
transformation = qdq_insertion_tranformation_builder(quantizer, [target_point])
return FXApplyTransformationCommand(transformation)

@staticmethod
Expand All @@ -308,7 +308,7 @@ def create_unified_scales_quantizers_insertion_commands(

transformations = []
for tp in target_points:
transformation = qdq_insertion_transformation_builder(quantizer, [tp])
transformation = qdq_insertion_tranformation_builder(quantizer, [tp])
transformations.append(FXApplyTransformationCommand(transformation))
return transformations

Expand Down
21 changes: 16 additions & 5 deletions nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend
from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply
from nncf.tensor import Tensor
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import get_const_node
Expand All @@ -40,6 +39,15 @@
PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK


class FXSQMultiply(torch.nn.Module):
def __init__(self, scale: torch.Tensor):
super().__init__()
self._scale_value = scale

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.mul(x, self._scale_value)


class FXSmoothQuantAlgoBackend(SmoothQuantAlgoBackend):
@property
def convolution_metatypes(self) -> List[OperatorMetatype]:
Expand Down Expand Up @@ -112,9 +120,10 @@ def scale_insertion_command(
for node in nodes:
target_points.append(PTTargetPoint(PT_PRE_LAYER_TARGET_TYPE, node.node_name, input_port_id=input_port_id))

sq_multiply = SQMultiply(scale_value.shape)
sq_multiply.scale = scale_value
return FXApplyTransformationCommand(module_insertion_transformation_builder(sq_multiply, target_points))
sq_multiply = FXSQMultiply(scale_value)
return FXApplyTransformationCommand(
module_insertion_transformation_builder(sq_multiply, target_points, scale_node_name)
)

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
Expand All @@ -130,7 +139,9 @@ def get_weight_channel_axis(node: NNCFNode) -> int:

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return node.is_shared()
# TODO(dlyakvho): Support shared layers in TorchFX.
# Ref: 149316
return False

@staticmethod
def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
Expand Down
9 changes: 4 additions & 5 deletions tests/openvino/native/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import torch

import nncf
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.layout import OVLayoutElem
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
Expand Down Expand Up @@ -68,6 +67,10 @@


class TestOVSQAlgorithm(TemplateTestSQAlgorithm):
@staticmethod
def backend_supports_shared_layers() -> bool:
return True

@staticmethod
def fn_to_type(tensor) -> np.ndarray:
return np.array(tensor)
Expand All @@ -85,10 +88,6 @@ def get_node_name_map(self, model_cls) -> Dict[str, str]:
return {}
raise NotImplementedError

@staticmethod
def get_target_node_name(command: TransformationCommand):
return command.target_point.target_node_name

@staticmethod
def get_transform_fn() -> Callable:
def transform_fn(data_item):
Expand Down
Loading

0 comments on commit 10a8997

Please sign in to comment.