
[TorchFX] SmoothQuant algorithm implementation #2875

Merged: 4 commits, Aug 16, 2024
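For context, SmoothQuant migrates quantization difficulty from activations to weights: each input channel j of a weighted op is rescaled by s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), the activation is divided by s, and the weight is multiplied by it. A minimal sketch of that formula (the function name, the alpha default, and the epsilon guard are illustrative assumptions, not NNCF's exact implementation):

```python
import torch

def smooth_quant_scales(act_absmax: torch.Tensor, w_absmax: torch.Tensor, alpha: float = 0.95) -> torch.Tensor:
    # Per-channel smoothing factors: s_j = max|X_j|^alpha / max|W_j|^(1 - alpha).
    scales = act_absmax.pow(alpha) / w_absmax.pow(1.0 - alpha)
    # Guard against zero statistics producing inf/NaN factors.
    return torch.clamp(scales, min=torch.finfo(act_absmax.dtype).eps)
```

The diff below adds the TorchFX plumbing this needs: statistic-collection hooks, a generic module-insertion transformation for the activation-side multiply, and a constant-update transformation for the weight-side rescale.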
13 changes: 7 additions & 6 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
@@ -98,7 +98,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
output_port_id=output_port_id,
dtype=Dtype.FLOAT,
)

return nncf_graph

@staticmethod
@@ -121,22 +120,24 @@ def get_edge_params(
edge tensor shape.
"""
output_port_id = 0
tensor_shape = None
if source_node.op in ("get_attr",):
tensor_shape = tuple(getattr(model, source_node.target).shape)
elif "val" in source_node.meta:
if source_nncf_node.metatype is om.PTBatchNormMetatype:
tensor = source_node.meta["val"][0]
elif source_nncf_node.metatype is om.PTSplitMetatype:
elif source_nncf_node.metatype in [om.PTSplitMetatype, om.PTMaxMetatype, om.PTMinMetatype]:
tensor = source_node.meta["val"][output_idx]
# Assume every split outputs corresponds to an unique output_port_id
# Assume every output corresponds to a unique output_port_id
output_port_id = output_idx
else:
tensor = source_node.meta["val"]
tensor_shape = tuple(tensor.shape)
else:
if isinstance(tensor, torch.Tensor):
tensor_shape = tuple(tensor.shape)

if tensor_shape is None:
# TODO(dlyakhov): Refactor algorithms to always have known edge shapes.
nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.")
tensor_shape = None

input_port_id = dist_node.all_input_nodes.index(source_node)
return input_port_id, output_port_id, tensor_shape
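The multi-output branch above assumes node.meta["val"] stores one tensor per output, indexed by output_idx. A standalone sketch of how that meta is populated and read (FakeTensorProp here stands in for whatever capture pipeline produces the meta; the toy model is illustrative):

```python
import torch
import torch.fx
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.split(x, 2, dim=0)

gm = torch.fx.symbolic_trace(Toy())
FakeTensorProp(gm).propagate(torch.ones(4, 3))  # fills node.meta["val"]

for node in gm.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, (tuple, list)):
        # Multi-output op (split/max/min): each element maps to its own output_port_id.
        print(node.name, [tuple(t.shape) for t in val])
    elif val is not None:
        print(node.name, tuple(val.shape))
```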
23 changes: 22 additions & 1 deletion nncf/experimental/torch/fx/node_utils.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.fx


# TODO(dlyakhov): Use torch.fx.graph.find_nodes method instead after
@@ -28,3 +28,24 @@ def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node:
if node.name == name:
return node
raise RuntimeError(f"Node with name {name} is not found")


def get_tensor_constant_from_node(constant_node: torch.fx.Node, model: torch.fx.GraphModule) -> torch.nn.Parameter:
"""
Retrieves a tensor from the given constant node.

:param constant_node: Given constant node.
:param model: Given model.
:return: Torch tensor referenced by the given constant node.
"""
if constant_node is None:
return None
if constant_node.op != "get_attr":
raise RuntimeError(f"Given node op == {constant_node.op}, but get_attr is expected.")
target_atoms = constant_node.target.split(".")
attr_itr = model
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
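A usage sketch for the new helper (the node name is assumed; for FX call_function conv/linear nodes the weight get_attr typically arrives on argument 1):

```python
# gm: a torch.fx.GraphModule; "linear" is a hypothetical node name.
linear_node = get_graph_node_by_name(gm.graph, "linear")
weight = get_tensor_constant_from_node(linear_node.args[1], gm)
print(tuple(weight.shape))
```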
24 changes: 23 additions & 1 deletion nncf/experimental/torch/fx/statistics/aggregator.py
@@ -24,6 +24,7 @@
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.experimental.torch.fx.transformations import leaf_module_insertion_transformation_builder
from nncf.tensor import Tensor
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.return_types import maybe_get_values_from_torch_return_type

@@ -65,6 +66,24 @@ def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None:
return

@staticmethod
def _get_statistic_collector_name(tp: PTTargetPoint, module_to_insert: torch.nn.Module) -> str:
"""
Composes a unique statistic collector name from the given target point and module.

:param tp: Given target point.
:param module_to_insert: Given statistic collection module.
:return: Unique statistic collector name for the given target point and module.
"""
return "_".join(
[
tp.target_node_name,
str(tp.input_port_id),
str(tp.target_type.value),
str(id(module_to_insert)),
]
)
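A sketch of the name this composes (the aggregator class name FXStatisticsAggregator and the in-scope collector are assumptions):

```python
from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.graph.transformations.commands import PTTargetPoint

tp = PTTargetPoint(TargetType.OPERATOR_POST_HOOK, "conv2d", input_port_id=0)
module = TensorCollectorModule(collector)  # collector: a tensor collector instance
name = FXStatisticsAggregator._get_statistic_collector_name(tp, module)
# e.g. "conv2d_0_<target-type value>_<id(module)>"; id(module) keeps the name
# unique when several collectors hook the same target point.
```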

def _get_transformation_layout_extra_outputs(
self, statistic_points: StatisticPointsContainer
) -> TransformationLayout:
@@ -75,8 +94,11 @@ def _get_transformation_layout_extra_outputs(
for _statistic_point in _statistic_points:
for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
for collector in collectors:
tp = _statistic_point.target_point
module_to_insert = TensorCollectorModule(collector)
target_module_name = self._get_statistic_collector_name(tp, module_to_insert)
transformation = leaf_module_insertion_transformation_builder(
TensorCollectorModule(collector), [_statistic_point.target_point]
module_to_insert, [tp], target_module_name
)
transformation_commands.append(
FXApplyTransformationCommand(
162 changes: 122 additions & 40 deletions nncf/experimental/torch/fx/transformations.py
@@ -14,49 +14,108 @@
import torch
import torch.fx
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node
from torch.ao.quantization.pt2e.utils import fold_bn_weights_into_conv_node
from torch.quantization.fake_quantize import FakeQuantize

import nncf
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.torch.graph.transformations.commands import PTTargetPoint

TransformationFNType = Callable[[torch.fx.GraphModule], None]


def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module):
"""
Sets the correct meta \"val\" value on the new node.

:param new_node: The new node.
:param prev_node: Input node of the new node.
The new node is expected to have only one input node.
:param target_module: Module which is being called by the new node.
"""
val = prev_node.meta["val"]
val = val if isinstance(val, tuple) else (val,)
retval = []
for t in val:
retval.append(torch.ones(t.shape))

with torch.no_grad():
new_node.meta["val"] = target_module(*val)
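For intuition, the meta propagation above just executes the inserted module on the predecessor's recorded output(s); a toy standalone equivalent with made-up shapes:

```python
import torch

prev_val = (torch.ones(1, 8, 16, 16),)   # stand-in for prev_node.meta["val"]
target_module = torch.nn.Identity()      # stand-in for the inserted module

with torch.no_grad():
    new_val = target_module(*prev_val)   # becomes new_node.meta["val"]
print(new_val.shape)                     # torch.Size([1, 8, 16, 16])
```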


def module_insertion_transformation_builder(
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint], target_module_name: str
) -> TransformationFNType:
"""
Returns a transformation which inserts the given module into the target model
and calls it after each target point, replacing the inputs/outputs
of the target node.

:param module_to_insert: Given torch.nn.Module to insert.
:param target_points: Target points to insert the target module.
:param target_module_name: Target model attribute name for the module_to_insert.
:returns: Transformation which inserts the given module into the target model
and calls it after each target point.
"""

def module_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_module_name)
# Insert call_module nodes to the model
graph = model.graph
for idx, target_point in enumerate(target_points):
new_node = _insert_call_module(graph, target_point, module_attr_name, f"{module_attr_name}_{idx}")
target_node = get_graph_node_by_name(graph, target_point.target_node_name)

if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
_set_new_node_meta(new_node, target_node, module_to_insert)
with graph.inserting_after(target_node):
for user in target_node.users:
if user is new_node:
continue
user.replace_input_with(target_node, new_node)

else:
prev_node = target_node.args[target_point.input_port_id]
_set_new_node_meta(new_node, prev_node, module_to_insert)
target_node.replace_input_with(prev_node, new_node)

return module_insertion_transformation
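A usage sketch in the SmoothQuant spirit: put a per-channel multiply in front of a linear node's activation input (the node name, channel count, and scale values are assumptions; gm is a torch.fx.GraphModule with node.meta["val"] populated):

```python
import torch
from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.graph.transformations.commands import PTTargetPoint

class ScaleModule(torch.nn.Module):
    def __init__(self, scale: torch.Tensor):
        super().__init__()
        self.register_buffer("scale", scale)

    def forward(self, x):
        return x * self.scale

tp = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, "linear", input_port_id=0)
transform = module_insertion_transformation_builder(
    ScaleModule(torch.full((8,), 0.5)), [tp], "smooth_quant_scale_linear"
)
transform(gm)   # rewires port 0 of "linear" through the new call_module node
gm.recompile()  # regenerate forward() after graph surgery
```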


def leaf_module_insertion_transformation_builder(
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint], target_module_name: str
) -> TransformationFNType:
"""
Returns a transformation which inserts the given module into the target model
and calls it after each target point.

:param module_to_insert: Given torch.nn.Module to insert.
:param target_points: Target points to insert the target module.
:param target_module_name: Target model attribute name for the module_to_insert.
:returns: Transformation which inserts the given module into the target model
and calls it after each target point.
"""

def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points)
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_module_name)
# Insert call_module nodes to the model
graph = model.graph
for target_point in target_points:
_insert_call_module(graph, target_point, module_attr_name)
for idx, target_point in enumerate(target_points):
_insert_call_module(graph, target_point, module_attr_name, f"{module_attr_name}_{idx}")

return leaf_module_insertion_transformation


def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
"""
Return transformation which updates constant of the given bias node to the given value.
Return a transformation which updates the constant of the given node with bias to the given value.

:param node: Bias node which requires bias constant update.
:param node: Node with bias which requires a bias constant update.
:param value: New value to use as the bias constant.
:return: Transformation which updates constant of the given bias node to the given value.
:return: Transformation which updates the constant of the given node with bias to the given value.
"""

def bias_update_transformation(model: torch.fx.GraphModule):
@@ -67,18 +126,51 @@ def bias_update_transformation(model: torch.fx.GraphModule):
raise nncf.InternalError(f"Node with bias has {len(graph_node.users)} users, 1 expected.")

bias_node = next(iter(graph_node.users))
with graph.inserting_before(bias_node):
new_constant = create_getattr_from_value(model, graph, target_node_name + "_shifted_bias", value)

args = list(bias_node.args)
# A bias node suppose to have constant on the second input port.
args[1] = new_constant
bias_node.args = tuple(args)
graph.eliminate_dead_code()
constant_update_fn(model, bias_node, value, input_port_id=1)

return bias_update_transformation


def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
"""
Return a transformation which updates the constant of the given node to the given value.

:param node: Node which requires a constant update.
:param value: New value to use as the node constant.
:return: Transformation which updates the constant of the given node to the given value.
"""

def constant_update_transformation(model: torch.fx.GraphModule):
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=1)

return constant_update_transformation


def constant_update_fn(model: torch.fx.GraphModule, node: torch.fx.Node, value: torch.Tensor, input_port_id: int = 1):
"""
Updates the constant of the given node on the given input port with the given value.

:param model: Target torch GraphModule.
:param node: Given graph node.
:param value: New value to use as the node constant.
:param input_port_id: Target constant input port id.
"""
graph = model.graph
with graph.inserting_before(node):
new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value)

args = list(node.args)
# The node is expected to have a constant on the given input port.
if args[input_port_id].op != "get_attr":
raise nncf.InternalError(
f"Constant on input port {input_port_id} for {node} is expected,"
f" but node {args[input_port_id]} is present."
)
args[input_port_id] = new_constant
node.args = tuple(args)
graph.eliminate_dead_code()
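And the weight-side counterpart, sketched with the same assumed node name: read the current constant with the new node_utils helper, rescale it, and swap it in place (FX linear/conv calls keep the weight on input port 1):

```python
linear_node = get_graph_node_by_name(gm.graph, "linear")
weight = get_tensor_constant_from_node(linear_node.args[1], gm)
constant_update_fn(gm, linear_node, weight * 2.0, input_port_id=1)
gm.recompile()
```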


def qdq_insertion_transformation_builder(
quantizer: FakeQuantize, target_points: List[PTTargetPoint]
) -> TransformationFNType:
@@ -200,25 +292,23 @@ def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, qua
raise nncf.InternalError(f"Unexpected target type: {target_point.target_type}")


def _insert_call_module(graph: torch.fx.Graph, target_point: PTTargetPoint, module_attr_name: str):
def _insert_call_module(
graph: torch.fx.Graph, target_point: PTTargetPoint, module_attr_name: str, graph_node_name: str
):
"""
Inserts a module call node into the graph after the target node.

:param graph: Graph to insert the module call node into.
:param target_point: Target point; the module call node is inserted just after its target node.
:param module_attr_name: The name of the graph attribute which keeps the target module.
:param graph_node_name: Target name for the module call node.
:return: Inserted module call node.
"""
target_node = get_graph_node_by_name(graph, target_point.target_node_name)
input_node = get_input_node(target_point, target_node)
ctx_manager = get_ctx_manager(graph, target_point)
with ctx_manager(target_node):
return graph.create_node(
"call_module",
module_attr_name,
(input_node,),
{},
name=f"{module_attr_name}_{str(target_point.target_type)}_graph_node",
)
return graph.create_node("call_module", module_attr_name, (input_node,), {}, name=graph_node_name)


def get_input_node(target_point: PTTargetPoint, target_node: torch.fx.Node) -> torch.fx.Node:
@@ -264,26 +354,18 @@ def get_ctx_manager(graph: torch.fx.Graph, target_point: PTTargetPoint) -> Calla


def _set_module_to_the_graph_module(
model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
model: torch.fx.GraphModule,
module_to_insert: torch.nn.Module,
module_name_in_model: str,
) -> str:
"""
Sets the given module on the given torch.fx.GraphModule under a unique name.

:param model: Target torch.fx.GraphModule.
:param module_to_insert: Module to insert into the target graph.
:param target_points: Target points which will be used to insert target module
to the graph.
:param module_name_in_model: Target model attribute name for the module_to_insert.
:return: A graph module attribute name which keeps the given module.
"""
module_to_insert = module_to_insert
# TODO(dlyakhov) Make module name human readable.
module_name_in_model = (
"__".join(
"_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points
)
+ "_"
+ str(id(module_to_insert))
)
assert not hasattr(model, module_name_in_model)
setattr(model, module_name_in_model, module_to_insert)
return module_name_in_model
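With naming moved to the caller, the helper reduces to a guarded setattr; a quick sketch of the contract (gm and the attribute name are illustrative):

```python
attr_name = _set_module_to_the_graph_module(gm, torch.nn.Identity(), "my_collector")
assert attr_name == "my_collector"
# Re-using a name trips the assert inside the helper, so callers such as the
# statistics aggregator must compose unique names up front.
```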
@@ -397,7 +479,7 @@ def separate_linear_and_bias(model: torch.fx.GraphModule):
while linear_bias_node.op != "get_attr":
# Assume zero argument is on a path to the constant
linear_bias_node = linear_bias_node.args[0]
linear_bias_value = _get_tensor_constant_from_node(linear_bias_node, model)
linear_bias_value = get_tensor_constant_from_node(linear_bias_node, model)
args = list(n.args)
args[2] = None
linear_node.args = tuple(args)
@@ -436,9 +518,9 @@ def separate_conv_and_bias(model: torch.fx.GraphModule):
if len(n.args) < 3 or n.args[2] is None:
continue
conv_node = n
dims = len(_get_tensor_constant_from_node(conv_node.args[1], model).shape)
dims = len(get_tensor_constant_from_node(conv_node.args[1], model).shape)
conv_bias_node = conv_node.args[2]
conv_bias_value = _get_tensor_constant_from_node(conv_bias_node, model)
conv_bias_value = get_tensor_constant_from_node(conv_bias_node, model)
args = list(n.args)
args[2] = None
conv_node.args = tuple(args)
@@ -502,7 +584,7 @@ def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[
const_node = node
break
assert const_node is not None
bias_value = _get_tensor_constant_from_node(const_node, model).squeeze()
bias_value = get_tensor_constant_from_node(const_node, model).squeeze()
with model.graph.inserting_before(conv_node):
new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value)
args = list(conv_node.args)