From 5f161da88b0650dd58a34ea7618d9497ebbe6436 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Tue, 7 Jan 2025 15:39:00 +0200 Subject: [PATCH 1/2] OPLayerAttribute --- .../function_hook/graph/build_graph_mode.py | 14 +-- .../torch2/function_hook/graph/graph_utils.py | 8 +- .../graph/graph_visualization.py | 4 +- .../nncf_graph/layer_attributes.py | 32 +++++++ .../nncf_graph/nncf_graph_builder.py | 89 ++++++++++++++++-- nncf/torch/graph/operator_metatypes.py | 2 +- .../graph/test_build_graph_mode.py | 6 +- tests/torch2/function_hook/helpers.py | 33 +++++++ .../nncf_graph/test_layer_attributes.py | 92 +++++++++++++++++++ .../nncf_graph/test_nncf_graph.py | 2 +- 10 files changed, 257 insertions(+), 25 deletions(-) create mode 100644 nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py create mode 100644 tests/torch2/function_hook/nncf_graph/test_layer_attributes.py diff --git a/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py b/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py index a5659c631b1..1fd52922315 100644 --- a/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py +++ b/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Any, Dict, Tuple, Union, cast +from typing import Any, Callable, Dict, Optional, Tuple, Union, cast import networkx as nx # type: ignore[import-untyped] import torch @@ -185,23 +185,23 @@ def process_tensor_attributes(self, output: torch.Tensor, op_meta: OpMeta) -> No :param output: The output tensor. :param op_meta: Metadata about the operation. """ - fn_name = None + func: Optional[Callable[..., Any]] = None fn_kwargs = None if output.grad_fn is not None: if output.grad_fn.name() == "TransposeBackward0": - fn_name = "transpose" + func = torch.transpose # grad_fn collect arguments as _saved_dim0=18446744073709551614 fn_kwargs = { "dim0": -(2**64 - output.grad_fn._saved_dim0), # type: ignore[attr-defined] "dim1": -(2**64 - output.grad_fn._saved_dim1), # type: ignore[attr-defined] } if output.grad_fn.name() == "PermuteBackward0": - fn_name = "permute" + func = torch.permute fn_kwargs = {"dims": output.grad_fn._saved_dims} # type: ignore[attr-defined] - if fn_name is not None and fn_kwargs is not None: - self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].fn_name = fn_name + if func is not None and fn_kwargs is not None: + self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].func = func self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].kwargs = fn_kwargs def execute_post_hooks(self, outputs: Any, op_meta: OpMeta) -> Any: @@ -320,7 +320,7 @@ def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: Op self.graph.add_node( node_id, type=NodeType.fn_call, - meta=FunctionMeta(op_name=op_name, fn_name=op_meta.func.__name__, args=tuple(op_attrs), kwargs=op_kwargs), + meta=FunctionMeta(op_name=op_name, func=op_meta.func, args=tuple(op_attrs), kwargs=op_kwargs), ) logger.debug(f"GraphBuilderMode.process_op_inputs: {node_id=} {op_name=} {op_attrs=} {op_kwargs=}") diff --git a/nncf/experimental/torch2/function_hook/graph/graph_utils.py b/nncf/experimental/torch2/function_hook/graph/graph_utils.py index 1dcdb7deaa8..c04b4f4f87d 100644 --- a/nncf/experimental/torch2/function_hook/graph/graph_utils.py +++ b/nncf/experimental/torch2/function_hook/graph/graph_utils.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import torch @@ -75,10 +75,14 @@ def from_tensor(tensor: torch.Tensor, name: str) -> InOutMeta: @dataclass class FunctionMeta: op_name: str - fn_name: str + func: Callable[..., Any] args: Tuple[Any, ...] kwargs: Dict[str, Any] + @property + def func_name(self) -> str: + return self.func.__name__ + @dataclass class EdgeMeta: diff --git a/nncf/experimental/torch2/function_hook/graph/graph_visualization.py b/nncf/experimental/torch2/function_hook/graph/graph_visualization.py index 8d32f591162..44f47a9bcd5 100644 --- a/nncf/experimental/torch2/function_hook/graph/graph_visualization.py +++ b/nncf/experimental/torch2/function_hook/graph/graph_visualization.py @@ -108,7 +108,7 @@ def get_label_from_node_data(node_data: Dict[str, Any], style: PydotStyleTemplat rows = [ f"type: {node_type}", f"op_name: {meta.op_name}", - f"fn_name: {meta.fn_name}", + f"fn_name: {meta.func_name}", f"args: {args_to_label(meta.args)}", f"kwargs: {kwargs_to_label(meta.kwargs)}", ] @@ -195,7 +195,7 @@ def get_style(node: Dict[str, Any], style: PydotStyleTemplate) -> Dict[str, str] } if isinstance(meta, FunctionMeta): return { - "fillcolor": color_picker(meta.fn_name), + "fillcolor": color_picker(meta.func_name), "fontcolor": "#000000", "shape": "record", "style": '"filled,rounded"', diff --git a/nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py b/nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py new file mode 100644 index 00000000000..ccf786fc3e4 --- /dev/null +++ b/nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py @@ -0,0 +1,32 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Dict, Set, Tuple + +from nncf.common.graph.layer_attributes import BaseLayerAttributes + + +@dataclass(frozen=True) +class PT2OpLayerAttributes(BaseLayerAttributes): + """ + This class stores information about operation. + + :param func: Function that the operation represents. + :param op_args: Tuple of positional arguments for the operation. + :param op_kwargs: Dictionary of keyword arguments for the operation. + :param constant_port_ids: Set of input port indices with constants. + """ + + func: Callable[..., Any] + op_args: Tuple[Any, ...] + op_kwargs: Dict[str, Any] + constant_port_ids: Set[int] diff --git a/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py b/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py index 190d4959d95..9c7a5c1c164 100644 --- a/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py +++ b/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 Intel Corporation +# Copyright (c) 2024 Intel Corporation # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,7 +11,7 @@ from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast import networkx as nx # type: ignore import torch @@ -21,6 +21,7 @@ import nncf.torch.graph.operator_metatypes as om from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.layer_attributes import BaseLayerAttributes from nncf.common.graph.layer_attributes import Dtype from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta @@ -28,6 +29,7 @@ from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType +from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> str: @@ -45,7 +47,7 @@ def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta if isinstance(meta, ConstMeta): return "nncf_model_const" if isinstance(meta, FunctionMeta): - return meta.fn_name + return meta.func_name raise nncf.InternalError("Unexpected metadata type") @@ -77,20 +79,86 @@ def get_dtype(dtype: torch.dtype) -> Dtype: return Dtype.INTEGER -def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> om.PTOperatorMetatype: +def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> type[om.PTOperatorMetatype]: """ Converts the node type and metadata into a PTOperatorMetatype object. :param node_type: The type of the node. :param meta: The metadata associated with the node. :return: The PTOperatorMetatype object. """ - node_metatype = cast(om.PTOperatorMetatype, om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)) - node_sub_meta_type: Optional[om.PTOperatorMetatype] = None + node_metatype = cast( + type[om.PTOperatorMetatype], om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type) + ) + node_sub_meta_type: Optional[type[om.PTOperatorMetatype]] = None if node_metatype.get_subtypes() and isinstance(meta, FunctionMeta): node_sub_meta_type = node_metatype.determine_subtype(function_args=meta.args, functions_kwargs=meta.kwargs) return node_sub_meta_type or node_metatype +def is_constant_input_node(nx_graph: nx.MultiDiGraph, node: int) -> bool: + """ + Check if a node is a constant input node or constant subgraph: + + 1) constant + 2) quantize_function -> constant + + :param nx_graph: The graph to check the node from. + :param node: The node to check. + :return: True if the node is a constant input node, False otherwise. + """ + meta = nx_graph.nodes[node]["meta"] + + # 1) Input node is a constant node (parameter or buffer) + if isinstance(meta, ConstMeta): + return True + + # 2) Quantize node with constant input + if ( + isinstance(meta, FunctionMeta) + and meta.func_name in om.QUANTIZE_NODE_TYPES + and isinstance(nx_graph.nodes[node]["meta"], FunctionMeta) + ): + return all(isinstance(nx_graph.nodes[s_node]["meta"], ConstMeta) for s_node, _ in nx_graph.in_edges(node)) + + return False + + +def get_constant_port_ids(nx_graph: nx.MultiDiGraph, node: int) -> Set[int]: + """ + Get the indices of input ports corresponding to the constant node or subgraph. + + :param nx_graph: The graph to get the constant port IDs from. + :param node: The node to get the constant port IDs from. + :return: The list of input port indices with constants. + """ + constant_port_ids: Set[int] = set() + + for s_node, _, data in nx_graph.in_edges(node, data=True): + if is_constant_input_node(nx_graph, s_node): + meta = cast(EdgeMeta, data["meta"]) + constant_port_ids.add(meta.input_port) + + return constant_port_ids + + +def get_layer_attributes( + nx_graph: nx.MultiDiGraph, node: int, meta: Union[ConstMeta, FunctionMeta, InOutMeta] +) -> Optional[BaseLayerAttributes]: + """ + Get the layer attributes of a node in the graph. + + :param nx_graph: The graph to get the layer attributes from. + :param node: The node to get the layer attributes from. + :param meta: The metadata associated with the node. + :return: The layer attributes of the node. + """ + if isinstance(meta, FunctionMeta): + constant_port_ids = get_constant_port_ids(nx_graph, node) + return PT2OpLayerAttributes(meta.func, meta.args, meta.kwargs, constant_port_ids) + + return None + + def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph: """ Converts a graph to an NNCFGraph. @@ -102,15 +170,18 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph: map_nx_node_to_nncf_node: Dict[int, NNCFNode] = {} for node, data in nx_graph.nodes(data=True): - meta: Union[ConstMeta, FunctionMeta, InOutMeta] = data["meta"] + meta = data["meta"] + if not isinstance(meta, (ConstMeta, FunctionMeta, InOutMeta)): + raise nncf.InternalError(f"Unknown metadata type: {type(meta)}") node_name = get_name_of_node(meta) node_type = get_node_type(data["type"], meta) meta_type = get_meta_type(node_type, meta) - + layer_attributes = get_layer_attributes(nx_graph, node, meta) nncf_node = nncf_graph.add_nncf_node( node_name=node_name, node_type=node_type, - node_metatype=meta_type, # type: ignore[arg-type] + node_metatype=meta_type, + layer_attributes=layer_attributes, ) map_nx_node_to_nncf_node[node] = nncf_node diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index 2f8806d27c3..825c63c5765 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -84,7 +84,7 @@ def get_all_aliases(cls) -> List[str]: @classmethod def determine_subtype( cls, layer_attributes: Optional[BaseLayerAttributes] = None, function_args=None, functions_kwargs=None - ) -> Optional["PTOperatorSubtype"]: + ) -> Optional["type[PTOperatorSubtype]"]: matches = [] for subtype in cls.get_subtypes(): if subtype.matches(layer_attributes, function_args, functions_kwargs): diff --git a/tests/torch2/function_hook/graph/test_build_graph_mode.py b/tests/torch2/function_hook/graph/test_build_graph_mode.py index 0aa6dde1c2f..d1f1dd4c022 100644 --- a/tests/torch2/function_hook/graph/test_build_graph_mode.py +++ b/tests/torch2/function_hook/graph/test_build_graph_mode.py @@ -101,7 +101,7 @@ def test_execute_pre_hooks(): "type": NodeType.fn_call, "meta": FunctionMeta( op_name="/relu/0", - fn_name="relu", + func=torch.relu, args=( TensorMeta(dtype=torch.float32, shape=(1,), requires_grad=False), TensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True), @@ -190,14 +190,14 @@ def test_tensor_attributes(attr): if attr == ".T": ref_meta = FunctionMeta( op_name="/__get__/0", - fn_name="permute", + func=torch.permute, args=(TensorMeta(dtype=torch.float32, shape=(2, 3), requires_grad=True),), kwargs={"dims": (1, 0)}, ) else: ref_meta = FunctionMeta( op_name="/__get__/0", - fn_name="transpose", + func=torch.transpose, args=(TensorMeta(dtype=torch.float32, shape=(2, 3), requires_grad=True),), kwargs={"dim0": -2, "dim1": -1}, ) diff --git a/tests/torch2/function_hook/helpers.py b/tests/torch2/function_hook/helpers.py index 0b261253905..5b058c188ef 100644 --- a/tests/torch2/function_hook/helpers.py +++ b/tests/torch2/function_hook/helpers.py @@ -103,3 +103,36 @@ def __init__(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + x return x + + +class MatMulLeft(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(torch.tensor([1], dtype=torch.float32)) + + @staticmethod + def get_example_inputs(): + return torch.ones([1, 1]) + + def forward(self, x): + return torch.matmul(x, self.w) + + +class MatMulRight(MatMulLeft): + def forward(self, x): + return torch.matmul(self.w, x) + + +class QuantizedConvModel(nn.Module): + @staticmethod + def get_example_inputs(): + return torch.ones([1, 1, 3, 3]) + + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(1, 1, 1) + + def forward(self, x: torch.Tensor): + x = self.conv(x) + x = torch.relu(x) + return x diff --git a/tests/torch2/function_hook/nncf_graph/test_layer_attributes.py b/tests/torch2/function_hook/nncf_graph/test_layer_attributes.py new file mode 100644 index 00000000000..c2d972e663c --- /dev/null +++ b/tests/torch2/function_hook/nncf_graph/test_layer_attributes.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import pytest +import torch +from torch import nn + +from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorMeta +from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes +from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph +from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from tests.torch2.function_hook.helpers import ConvModel +from tests.torch2.function_hook.helpers import MatMulLeft +from tests.torch2.function_hook.helpers import MatMulRight + + +@dataclass +class ParamForLayerAttributes: + model_cls: type[nn.Module] + node_name: str + ref: PT2OpLayerAttributes + + def __str__(self) -> str: + return self.model_cls.__name__ + + +@pytest.mark.parametrize( + "param", + [ + ParamForLayerAttributes( + ConvModel, + "conv/conv2d/0", + PT2OpLayerAttributes( + func=torch.conv2d, + op_args=( + TensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=False), + TensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True), + TensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True), + [1, 1], + [0, 0], + [1, 1], + 1, + ), + op_kwargs={}, + constant_port_ids={1, 2}, + ), + ), + ParamForLayerAttributes( + MatMulLeft, + "/matmul/0", + PT2OpLayerAttributes( + func=torch.matmul, + op_args=( + TensorMeta(dtype=torch.float32, shape=(1, 1), requires_grad=False), + TensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True), + ), + op_kwargs={}, + constant_port_ids={1}, + ), + ), + ParamForLayerAttributes( + MatMulRight, + "/matmul/0", + PT2OpLayerAttributes( + func=torch.matmul, + op_args=( + TensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True), + TensorMeta(dtype=torch.float32, shape=(1, 1), requires_grad=False), + ), + op_kwargs={}, + constant_port_ids={0}, + ), + ), + ], + ids=str, +) +def test_op_layer_attribute(param: ParamForLayerAttributes): + # TODO(AlexanderDokuchaev): quantized model too + model = wrap_model(param.model_cls()) + nncf_graph = build_nncf_graph(model, model.get_example_inputs()) + op_node = nncf_graph.get_node_by_name(param.node_name) + assert op_node.layer_attributes == param.ref diff --git a/tests/torch2/function_hook/nncf_graph/test_nncf_graph.py b/tests/torch2/function_hook/nncf_graph/test_nncf_graph.py index bce5e0c1695..4fec6cb4075 100644 --- a/tests/torch2/function_hook/nncf_graph/test_nncf_graph.py +++ b/tests/torch2/function_hook/nncf_graph/test_nncf_graph.py @@ -62,7 +62,7 @@ def get_reference_graph(graph: NNCFGraph) -> nx.DiGraph: [ [NodeType.input, InOutMeta(torch.float32, (1), "input"), "nncf_model_input"], [NodeType.output, InOutMeta(torch.float32, (1), "output"), "nncf_model_output"], - [NodeType.output, FunctionMeta("op", "fn_name_ref", [], {}), "fn_name_ref"], + [NodeType.output, FunctionMeta("op", torch.relu, [], {}), "relu"], [NodeType.output, ConstMeta(torch.float32, (1), "model.bias"), "nncf_model_const"], ], ) From 3d5e074d40bf4af70dc262f8ce50669f5798b053 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Thu, 9 Jan 2025 02:20:14 +0200 Subject: [PATCH 2/2] 2025 --- .../torch2/function_hook/nncf_graph/layer_attributes.py | 2 +- .../torch2/function_hook/nncf_graph/nncf_graph_builder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py b/nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py index ccf786fc3e4..301077cf14c 100644 --- a/nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py +++ b/nncf/experimental/torch2/function_hook/nncf_graph/layer_attributes.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2025 Intel Corporation # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py b/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py index 9c7a5c1c164..374745f524c 100644 --- a/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py +++ b/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2025 Intel Corporation # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at