[PT2] Introduce PT2OpLayerAttribute #3178

Open · wants to merge 2 commits into base: develop
@@ -11,7 +11,7 @@

from __future__ import annotations

from typing import Any, Dict, Tuple, Union, cast
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast

import networkx as nx # type: ignore[import-untyped]
import torch
@@ -185,23 +185,23 @@ def process_tensor_attributes(self, output: torch.Tensor, op_meta: OpMeta) -> None:
:param output: The output tensor.
:param op_meta: Metadata about the operation.
"""
fn_name = None
func: Optional[Callable[..., Any]] = None
fn_kwargs = None

if output.grad_fn is not None:
if output.grad_fn.name() == "TransposeBackward0":
fn_name = "transpose"
func = torch.transpose
# grad_fn stores negative dims as unsigned 64-bit values, e.g. _saved_dim0=18446744073709551614 for dim -2
fn_kwargs = {
"dim0": -(2**64 - output.grad_fn._saved_dim0), # type: ignore[attr-defined]
"dim1": -(2**64 - output.grad_fn._saved_dim1), # type: ignore[attr-defined]
}
if output.grad_fn.name() == "PermuteBackward0":
fn_name = "permute"
func = torch.permute
fn_kwargs = {"dims": output.grad_fn._saved_dims} # type: ignore[attr-defined]

if fn_name is not None and fn_kwargs is not None:
self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].fn_name = fn_name
if func is not None and fn_kwargs is not None:
self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].func = func
self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].kwargs = fn_kwargs

def execute_post_hooks(self, outputs: Any, op_meta: OpMeta) -> Any:
@@ -320,7 +320,7 @@ def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: OpMeta) -> None:
self.graph.add_node(
node_id,
type=NodeType.fn_call,
meta=FunctionMeta(op_name=op_name, fn_name=op_meta.func.__name__, args=tuple(op_attrs), kwargs=op_kwargs),
meta=FunctionMeta(op_name=op_name, func=op_meta.func, args=tuple(op_attrs), kwargs=op_kwargs),
)

logger.debug(f"GraphBuilderMode.process_op_inputs: {node_id=} {op_name=} {op_attrs=} {op_kwargs=}")
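The unsigned-to-signed conversion above can be checked directly against PyTorch's autograd metadata; a minimal sketch, assuming a PyTorch version that exposes `grad_fn._saved_dim0` (recent releases do):

```python
import torch

x = torch.ones(2, 3, requires_grad=True)
y = x.transpose(-2, -1)  # records TransposeBackward0

# Negative dims are saved as unsigned 64-bit values:
# -2 is stored as 2**64 - 2 = 18446744073709551614.
saved = y.grad_fn._saved_dim0
dim0 = saved - 2**64 if saved >= 2**63 else saved  # same as -(2**64 - saved) for negative dims
print(dim0)  # -2
```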
8 changes: 6 additions & 2 deletions nncf/experimental/torch2/function_hook/graph/graph_utils.py
@@ -13,7 +13,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

import torch

@@ -75,10 +75,14 @@ def from_tensor(tensor: torch.Tensor, name: str) -> InOutMeta:
@dataclass
class FunctionMeta:
op_name: str
fn_name: str
func: Callable[..., Any]
args: Tuple[Any, ...]
kwargs: Dict[str, Any]

@property
def func_name(self) -> str:
return self.func.__name__


@dataclass
class EdgeMeta:
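With the callable stored instead of a bare string, the display name is derived on demand; a quick sketch of the new field in use:

```python
import torch

from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta

meta = FunctionMeta(op_name="/relu/0", func=torch.relu, args=(), kwargs={})
assert meta.func_name == "relu"  # resolved from func.__name__
```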
@@ -108,7 +108,7 @@ def get_label_from_node_data(node_data: Dict[str, Any], style: PydotStyleTemplate
rows = [
f"type: {node_type}",
f"op_name: {meta.op_name}",
f"fn_name: {meta.fn_name}",
f"fn_name: {meta.func_name}",
f"args: {args_to_label(meta.args)}",
f"kwargs: {kwargs_to_label(meta.kwargs)}",
]
@@ -195,7 +195,7 @@ def get_style(node: Dict[str, Any], style: PydotStyleTemplate) -> Dict[str, str]
}
if isinstance(meta, FunctionMeta):
return {
"fillcolor": color_picker(meta.fn_name),
"fillcolor": color_picker(meta.func_name),
"fontcolor": "#000000",
"shape": "record",
"style": '"filled,rounded"',
@@ -0,0 +1,32 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Dict, Set, Tuple

from nncf.common.graph.layer_attributes import BaseLayerAttributes


@dataclass(frozen=True)
class PT2OpLayerAttributes(BaseLayerAttributes):
"""
This class stores information about an operation.

:param func: Function that the operation represents.
:param op_args: Tuple of positional arguments for the operation.
:param op_kwargs: Dictionary of keyword arguments for the operation.
:param constant_port_ids: Set of input port indices with constants.
"""

func: Callable[..., Any]
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
constant_port_ids: Set[int]
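As an illustration, a matmul whose weight arrives on input port 1 could be described like this (hypothetical values, not taken from the diff):

```python
import torch

from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes

attrs = PT2OpLayerAttributes(
    func=torch.matmul,
    op_args=(),             # positional arguments recorded for the call
    op_kwargs={},
    constant_port_ids={1},  # the weight feeds input port 1
)
assert 1 in attrs.constant_port_ids
```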
@@ -11,7 +11,7 @@


from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast

import networkx as nx # type: ignore
import torch
@@ -21,13 +21,15 @@
import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.layer_attributes import Dtype
from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import EdgeMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes


def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> str:
Expand All @@ -45,7 +47,7 @@ def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta
if isinstance(meta, ConstMeta):
return "nncf_model_const"
if isinstance(meta, FunctionMeta):
return meta.fn_name
return meta.func_name
raise nncf.InternalError("Unexpected metadata type")


@@ -77,20 +79,86 @@ def get_dtype(dtype: torch.dtype) -> Dtype:
return Dtype.INTEGER


def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> om.PTOperatorMetatype:
def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> type[om.PTOperatorMetatype]:
"""
Converts the node type and metadata into a PTOperatorMetatype object.
:param node_type: The type of the node.
:param meta: The metadata associated with the node.
:return: The PTOperatorMetatype object.
"""
node_metatype = cast(om.PTOperatorMetatype, om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type))
node_sub_meta_type: Optional[om.PTOperatorMetatype] = None
node_metatype = cast(
type[om.PTOperatorMetatype], om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
)
node_sub_meta_type: Optional[type[om.PTOperatorMetatype]] = None
if node_metatype.get_subtypes() and isinstance(meta, FunctionMeta):
node_sub_meta_type = node_metatype.determine_subtype(function_args=meta.args, functions_kwargs=meta.kwargs)
return node_sub_meta_type or node_metatype


def is_constant_input_node(nx_graph: nx.MultiDiGraph, node: int) -> bool:
"""
Check if a node is a constant input node or constant subgraph:

1) constant
2) quantize_function -> constant

:param nx_graph: The graph to check the node from.
:param node: The node to check.
:return: True if the node is a constant input node, False otherwise.
"""
meta = nx_graph.nodes[node]["meta"]

# 1) Input node is a constant node (parameter or buffer)
if isinstance(meta, ConstMeta):
return True

# 2) Quantize node with constant input
if isinstance(meta, FunctionMeta) and meta.func_name in om.QUANTIZE_NODE_TYPES:
return all(isinstance(nx_graph.nodes[s_node]["meta"], ConstMeta) for s_node, _ in nx_graph.in_edges(node))

return False


def get_constant_port_ids(nx_graph: nx.MultiDiGraph, node: int) -> Set[int]:
"""
Get the indices of input ports corresponding to the constant node or subgraph.

:param nx_graph: The graph to get the constant port IDs from.
:param node: The node to get the constant port IDs from.
:return: The set of input port indices with constants.
"""
constant_port_ids: Set[int] = set()

for s_node, _, data in nx_graph.in_edges(node, data=True):
if is_constant_input_node(nx_graph, s_node):
meta = cast(EdgeMeta, data["meta"])
constant_port_ids.add(meta.input_port)

return constant_port_ids


def get_layer_attributes(
nx_graph: nx.MultiDiGraph, node: int, meta: Union[ConstMeta, FunctionMeta, InOutMeta]
) -> Optional[BaseLayerAttributes]:
"""
Get the layer attributes of a node in the graph.

:param nx_graph: The graph to get the layer attributes from.
:param node: The node to get the layer attributes from.
:param meta: The metadata associated with the node.
:return: The layer attributes of the node.
"""
if isinstance(meta, FunctionMeta):
constant_port_ids = get_constant_port_ids(nx_graph, node)
return PT2OpLayerAttributes(meta.func, meta.args, meta.kwargs, constant_port_ids)

return None


def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
"""
Converts a graph to an NNCFGraph.
@@ -102,15 +170,18 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:

map_nx_node_to_nncf_node: Dict[int, NNCFNode] = {}
for node, data in nx_graph.nodes(data=True):
meta: Union[ConstMeta, FunctionMeta, InOutMeta] = data["meta"]
meta = data["meta"]
if not isinstance(meta, (ConstMeta, FunctionMeta, InOutMeta)):
raise nncf.InternalError(f"Unknown metadata type: {type(meta)}")
node_name = get_name_of_node(meta)
node_type = get_node_type(data["type"], meta)
meta_type = get_meta_type(node_type, meta)

layer_attributes = get_layer_attributes(nx_graph, node, meta)
nncf_node = nncf_graph.add_nncf_node(
node_name=node_name,
node_type=node_type,
node_metatype=meta_type, # type: ignore[arg-type]
node_metatype=meta_type,
layer_attributes=layer_attributes,
)
map_nx_node_to_nncf_node[node] = nncf_node

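`get_constant_port_ids` relies only on `in_edges` and the `input_port` field of the edge metadata, so its behaviour can be sketched with stand-in metadata classes (hypothetical stubs, not the real ConstMeta/EdgeMeta; only the plain-constant case is mirrored, not the quantize subgraph):

```python
from dataclasses import dataclass

import networkx as nx


@dataclass
class FakeConstMeta:  # stand-in for ConstMeta
    name: str


@dataclass
class FakeEdgeMeta:  # stand-in for EdgeMeta; only input_port is used here
    input_port: int


g = nx.MultiDiGraph()
g.add_node(0, meta=FakeConstMeta("conv.weight"))
g.add_node(1, meta=None)  # activation producer
g.add_node(2, meta=None)  # op under inspection
g.add_edge(1, 2, meta=FakeEdgeMeta(input_port=0))
g.add_edge(0, 2, meta=FakeEdgeMeta(input_port=1))

# Same collection logic as get_constant_port_ids() for plain constants:
ports = {
    data["meta"].input_port
    for s_node, _, data in g.in_edges(2, data=True)
    if isinstance(g.nodes[s_node]["meta"], FakeConstMeta)
}
print(ports)  # {1}
```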
2 changes: 1 addition & 1 deletion nncf/torch/graph/operator_metatypes.py
@@ -84,7 +84,7 @@ def get_all_aliases(cls) -> List[str]:
@classmethod
def determine_subtype(
cls, layer_attributes: Optional[BaseLayerAttributes] = None, function_args=None, functions_kwargs=None
) -> Optional["PTOperatorSubtype"]:
) -> Optional["type[PTOperatorSubtype]"]:
matches = []
for subtype in cls.get_subtypes():
if subtype.matches(layer_attributes, function_args, functions_kwargs):
6 changes: 3 additions & 3 deletions tests/torch2/function_hook/graph/test_build_graph_mode.py
@@ -101,7 +101,7 @@ def test_execute_pre_hooks():
"type": NodeType.fn_call,
"meta": FunctionMeta(
op_name="/relu/0",
fn_name="relu",
func=torch.relu,
args=(
TensorMeta(dtype=torch.float32, shape=(1,), requires_grad=False),
TensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),
@@ -190,14 +190,14 @@ def test_tensor_attributes(attr):
if attr == ".T":
ref_meta = FunctionMeta(
op_name="/__get__/0",
fn_name="permute",
func=torch.permute,
args=(TensorMeta(dtype=torch.float32, shape=(2, 3), requires_grad=True),),
kwargs={"dims": (1, 0)},
)
else:
ref_meta = FunctionMeta(
op_name="/__get__/0",
fn_name="transpose",
func=torch.transpose,
args=(TensorMeta(dtype=torch.float32, shape=(2, 3), requires_grad=True),),
kwargs={"dim0": -2, "dim1": -1},
)
33 changes: 33 additions & 0 deletions tests/torch2/function_hook/helpers.py
@@ -103,3 +103,36 @@ def __init__(self) -> None:
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
return x


class MatMulLeft(torch.nn.Module):
def __init__(self):
super().__init__()
self.w = torch.nn.Parameter(torch.tensor([1], dtype=torch.float32))

@staticmethod
def get_example_inputs():
return torch.ones([1, 1])

def forward(self, x):
return torch.matmul(x, self.w)


class MatMulRight(MatMulLeft):
def forward(self, x):
return torch.matmul(self.w, x)


class QuantizedConvModel(nn.Module):
@staticmethod
def get_example_inputs():
return torch.ones([1, 1, 3, 3])

def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)

def forward(self, x: torch.Tensor):
x = self.conv(x)
x = torch.relu(x)
return x
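These helpers pin down the constant-port bookkeeping from two sides: `MatMulLeft` places the parameter on input port 1 of `torch.matmul`, `MatMulRight` on port 0, so the resulting `PT2OpLayerAttributes.constant_port_ids` would be expected to come out as `{1}` and `{0}` respectively, while `QuantizedConvModel` provides a conv whose weight can later sit behind a quantizer, exercising the quantize-subgraph branch of `is_constant_input_node`. A quick sanity check of the modules themselves (run with the classes above in scope):

```python
import torch

left, right = MatMulLeft(), MatMulRight()
x = MatMulLeft.get_example_inputs()  # shape [1, 1]
# w has shape [1]; matmul treats it as a vector in both orientations.
print(left(x).shape, right(x).shape)
```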