Skip to content

Commit

Permalink
[PT2] Introduce PT2OpLayerAttribute (openvinotoolkit#3178)
Browse files Browse the repository at this point in the history
### Changes

- Introduced `PT2OpLayerAttribute`, to collect called function,
attributes and constant ports
- `FunctionMeta` stored function instead of function name

### Reason for changes

Needs to implement subgraph extractor for FBC

### Related tickets

152996

### Tests

tests/torch2/function_hook/nncf_graph/test_layer_attributes.py
  • Loading branch information
AlexanderDokuchaev authored Jan 28, 2025
1 parent 1b6af84 commit cdf7208
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 24 deletions.
14 changes: 7 additions & 7 deletions nncf/experimental/torch2/function_hook/graph/build_graph_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from typing import Any, Dict, Tuple, Union, cast
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast

import networkx as nx # type: ignore[import-untyped]
import torch
Expand Down Expand Up @@ -185,12 +185,12 @@ def process_tensor_attributes(self, output: torch.Tensor, op_meta: OpMeta) -> No
:param output: The output tensor.
:param op_meta: Metadata about the operation.
"""
fn_name = None
func: Optional[Callable[..., Any]] = None
fn_kwargs = None

if output.grad_fn is not None:
if output.grad_fn.name() == "TransposeBackward0":
fn_name = "transpose"
func = torch.transpose
# grad_fn collect arguments as _saved_dim0=18446744073709551614
# Use static arguments for .mT
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.mT
Expand All @@ -199,11 +199,11 @@ def process_tensor_attributes(self, output: torch.Tensor, op_meta: OpMeta) -> No
"dim1": -1,
}
if output.grad_fn.name() == "PermuteBackward0":
fn_name = "permute"
func = torch.permute
fn_kwargs = {"dims": output.grad_fn._saved_dims} # type: ignore[attr-defined]

if fn_name is not None and fn_kwargs is not None:
self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].fn_name = fn_name
if func is not None and fn_kwargs is not None:
self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].func = func
self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].kwargs = fn_kwargs

def execute_post_hooks(self, outputs: Any, op_meta: OpMeta) -> Any:
Expand Down Expand Up @@ -322,7 +322,7 @@ def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: Op
self.graph.add_node(
node_id,
type=NodeType.fn_call,
meta=FunctionMeta(op_name=op_name, fn_name=op_meta.func.__name__, args=tuple(op_attrs), kwargs=op_kwargs),
meta=FunctionMeta(op_name=op_name, func=op_meta.func, args=tuple(op_attrs), kwargs=op_kwargs),
)

logger.debug(f"GraphBuilderMode.process_op_inputs: {node_id=} {op_name=} {op_attrs=} {op_kwargs=}")
Expand Down
8 changes: 6 additions & 2 deletions nncf/experimental/torch2/function_hook/graph/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

import torch

Expand Down Expand Up @@ -75,10 +75,14 @@ def from_tensor(tensor: torch.Tensor, name: str) -> InOutMeta:
@dataclass
class FunctionMeta:
op_name: str
fn_name: str
func: Callable[..., Any]
args: Tuple[Any, ...]
kwargs: Dict[str, Any]

@property
def func_name(self) -> str:
return self.func.__name__


@dataclass
class EdgeMeta:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_label_from_node_data(node_data: Dict[str, Any], style: PydotStyleTemplat
rows = [
f"type: {node_type}",
f"op_name: {meta.op_name}",
f"fn_name: {meta.fn_name}",
f"fn_name: {meta.func_name}",
f"args: {args_to_label(meta.args)}",
f"kwargs: {kwargs_to_label(meta.kwargs)}",
]
Expand Down Expand Up @@ -195,7 +195,7 @@ def get_style(node: Dict[str, Any], style: PydotStyleTemplate) -> Dict[str, str]
}
if isinstance(meta, FunctionMeta):
return {
"fillcolor": color_picker(meta.fn_name),
"fillcolor": color_picker(meta.func_name),
"fontcolor": "#000000",
"shape": "record",
"style": '"filled,rounded"',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Dict, Set, Tuple

from nncf.common.graph.layer_attributes import BaseLayerAttributes


@dataclass(frozen=True)
class PT2OpLayerAttributes(BaseLayerAttributes):
"""
This class stores information about operation.
:param func: Function that the operation represents.
:param op_args: Tuple of positional arguments for the operation.
:param op_kwargs: Dictionary of keyword arguments for the operation.
:param constant_port_ids: Set of input port indices with constants.
"""

func: Callable[..., Any]
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
constant_port_ids: Set[int]
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast

import networkx as nx # type: ignore
import torch
Expand All @@ -21,13 +21,15 @@
import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.layer_attributes import Dtype
from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import EdgeMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes


def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> str:
Expand All @@ -45,7 +47,7 @@ def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta
if isinstance(meta, ConstMeta):
return "nncf_model_const"
if isinstance(meta, FunctionMeta):
return meta.fn_name
return meta.func_name
raise nncf.InternalError("Unexpected metadata type")


Expand Down Expand Up @@ -77,20 +79,86 @@ def get_dtype(dtype: torch.dtype) -> Dtype:
return Dtype.INTEGER


def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> om.PTOperatorMetatype:
def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> type[om.PTOperatorMetatype]:
"""
Converts the node type and metadata into a PTOperatorMetatype object.
:param node_type: The type of the node.
:param meta: The metadata associated with the node.
:return: The PTOperatorMetatype object.
"""
node_metatype = cast(om.PTOperatorMetatype, om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type))
node_sub_meta_type: Optional[om.PTOperatorMetatype] = None
node_metatype = cast(
type[om.PTOperatorMetatype], om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
)
node_sub_meta_type: Optional[type[om.PTOperatorMetatype]] = None
if node_metatype.get_subtypes() and isinstance(meta, FunctionMeta):
node_sub_meta_type = node_metatype.determine_subtype(function_args=meta.args, functions_kwargs=meta.kwargs)
return node_sub_meta_type or node_metatype


def is_constant_input_node(nx_graph: nx.MultiDiGraph, node: int) -> bool:
"""
Check if a node is a constant input node or constant subgraph:
1) constant
2) quantize_function -> constant
:param nx_graph: The graph to check the node from.
:param node: The node to check.
:return: True if the node is a constant input node, False otherwise.
"""
meta = nx_graph.nodes[node]["meta"]

# 1) Input node is a constant node (parameter or buffer)
if isinstance(meta, ConstMeta):
return True

# 2) Quantize node with constant input
if (
isinstance(meta, FunctionMeta)
and meta.func_name in om.QUANTIZE_NODE_TYPES
and isinstance(nx_graph.nodes[node]["meta"], FunctionMeta)
):
return all(isinstance(nx_graph.nodes[s_node]["meta"], ConstMeta) for s_node, _ in nx_graph.in_edges(node))

return False


def get_constant_port_ids(nx_graph: nx.MultiDiGraph, node: int) -> Set[int]:
"""
Get the indices of input ports corresponding to the constant node or subgraph.
:param nx_graph: The graph to get the constant port IDs from.
:param node: The node to get the constant port IDs from.
:return: The list of input port indices with constants.
"""
constant_port_ids: Set[int] = set()

for s_node, _, data in nx_graph.in_edges(node, data=True):
if is_constant_input_node(nx_graph, s_node):
meta = cast(EdgeMeta, data["meta"])
constant_port_ids.add(meta.input_port)

return constant_port_ids


def get_layer_attributes(
nx_graph: nx.MultiDiGraph, node: int, meta: Union[ConstMeta, FunctionMeta, InOutMeta]
) -> Optional[BaseLayerAttributes]:
"""
Get the layer attributes of a node in the graph.
:param nx_graph: The graph to get the layer attributes from.
:param node: The node to get the layer attributes from.
:param meta: The metadata associated with the node.
:return: The layer attributes of the node.
"""
if isinstance(meta, FunctionMeta):
constant_port_ids = get_constant_port_ids(nx_graph, node)
return PT2OpLayerAttributes(meta.func, meta.args, meta.kwargs, constant_port_ids)

return None


def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
"""
Converts a graph to an NNCFGraph.
Expand All @@ -102,15 +170,18 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:

map_nx_node_to_nncf_node: Dict[int, NNCFNode] = {}
for node, data in nx_graph.nodes(data=True):
meta: Union[ConstMeta, FunctionMeta, InOutMeta] = data["meta"]
meta = data["meta"]
if not isinstance(meta, (ConstMeta, FunctionMeta, InOutMeta)):
raise nncf.InternalError(f"Unknown metadata type: {type(meta)}")
node_name = get_name_of_node(meta)
node_type = get_node_type(data["type"], meta)
meta_type = get_meta_type(node_type, meta)

layer_attributes = get_layer_attributes(nx_graph, node, meta)
nncf_node = nncf_graph.add_nncf_node(
node_name=node_name,
node_type=node_type,
node_metatype=meta_type, # type: ignore[arg-type]
node_metatype=meta_type,
layer_attributes=layer_attributes,
)
map_nx_node_to_nncf_node[node] = nncf_node

Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_all_aliases(cls) -> List[str]:
@classmethod
def determine_subtype(
cls, layer_attributes: Optional[BaseLayerAttributes] = None, function_args=None, functions_kwargs=None
) -> Optional["PTOperatorSubtype"]:
) -> Optional["type[PTOperatorSubtype]"]:
matches = []
for subtype in cls.get_subtypes():
if subtype.matches(layer_attributes, function_args, functions_kwargs):
Expand Down
6 changes: 3 additions & 3 deletions tests/torch2/function_hook/graph/test_build_graph_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_execute_pre_hooks():
"type": NodeType.fn_call,
"meta": FunctionMeta(
op_name="/relu/0",
fn_name="relu",
func=torch.relu,
args=(
TensorMeta(dtype=torch.float32, shape=(1,), requires_grad=False),
TensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),
Expand Down Expand Up @@ -190,14 +190,14 @@ def test_tensor_attributes(attr):
if attr == ".T":
ref_meta = FunctionMeta(
op_name="/__get__/0",
fn_name="permute",
func=torch.permute,
args=(TensorMeta(dtype=torch.float32, shape=(2, 3), requires_grad=True),),
kwargs={"dims": (1, 0)},
)
else:
ref_meta = FunctionMeta(
op_name="/__get__/0",
fn_name="transpose",
func=torch.transpose,
args=(TensorMeta(dtype=torch.float32, shape=(2, 3), requires_grad=True),),
kwargs={"dim0": -2, "dim1": -1},
)
Expand Down
33 changes: 33 additions & 0 deletions tests/torch2/function_hook/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,36 @@ def __init__(self) -> None:
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
return x


class MatMulLeft(torch.nn.Module):
def __init__(self):
super().__init__()
self.w = torch.nn.Parameter(torch.tensor([1], dtype=torch.float32))

@staticmethod
def get_example_inputs():
return torch.ones([1, 1])

def forward(self, x):
return torch.matmul(x, self.w)


class MatMulRight(MatMulLeft):
def forward(self, x):
return torch.matmul(self.w, x)


class QuantizedConvModel(nn.Module):
@staticmethod
def get_example_inputs():
return torch.ones([1, 1, 3, 3])

def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)

def forward(self, x: torch.Tensor):
x = self.conv(x)
x = torch.relu(x)
return x
Loading

0 comments on commit cdf7208

Please sign in to comment.