Added tests for fx graph tracing
Jegp committed Dec 15, 2024
1 parent ad31333 commit 92b4fd0
Showing 4 changed files with 350 additions and 42 deletions.
2 changes: 1 addition & 1 deletion nirtorch/__init__.py
@@ -1,5 +1,5 @@
from .from_nir import load # noqa F401
from .graph import extract_torch_graph # noqa F401
from .to_nir import extract_nir_graph # noqa F401
from .to_nir import extract_nir_graph, trace_nir_graph # noqa F401

__version__ = version = "1.0"
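
With this change, downstream code can import the new tracing entry point alongside the old one; a trivial sketch of the changed import surface:

from nirtorch import extract_nir_graph, trace_nir_graph  # trace_nir_graph is the new export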
156 changes: 116 additions & 40 deletions nirtorch/graph_fx.py
@@ -1,25 +1,14 @@
from typing import Any, Callable, Dict, Set, Tuple, TypeAlias
from typing import Any, Callable, Dict, Set, Tuple, TypeAlias, Optional

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/graph_fx.py:1:53: `typing.TypeAlias` imported but unused
import operator

import numpy as np
from torch.nn.modules import Module

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/graph_fx.py:5:30: `torch.nn.modules.Module` imported but unused

import nir
import torch
from torch.fx import GraphModule, Tracer, Transformer
from torch.fx import Graph, GraphModule, Node, Tracer, Transformer

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/graph_fx.py:9:22: `torch.fx.Graph` imported but unused
from torch.fx.passes import shape_prop

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/graph_fx.py:10:29: `torch.fx.passes.shape_prop` imported but unused

NIRTORCH_MAPPING: TypeAlias = Dict[
    torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]
]

DEFAULT_DICT: NIRTORCH_MAPPING = {
    torch.nn.Linear: (
        lambda module: nir.Affine(
            module.weight.detach().numpy(), module.bias.detach().numpy()
        )
    )
}


class NIRTorchTracer(Tracer):

@@ -55,11 +44,9 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
        if hasattr(m, "_is_leaf_module") and m._is_leaf_module:
            return True

        # return m.__module__.startswith("torch.nn") and not isinstance(
        #     m, torch.nn.Sequential
        # )
        return super().is_leaf_module(m, module_qualified_name)
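
For reference, the `_is_leaf_module` escape hatch checked above lets a module opt out of further tracing; a minimal sketch (the `MyStatefulNeuron` class is illustrative, not part of this commit):

class MyStatefulNeuron(torch.nn.Module):
    # NIRTorchTracer.is_leaf_module returns True for this module, so the
    # tracer records it as a single node instead of descending into it.
    _is_leaf_module = True

    def forward(self, x):
        return x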


class NIRTorchTransformer(Transformer):
    def call_function(self, target: str, args: Tuple, kwargs: Dict) -> Any:
        print("sup", target)
@@ -73,58 +60,147 @@ def call_module(self, target, args, kwargs):
        return super().call_module(target, args, kwargs)


def trace_pytorch_graph(
    module: torch.nn.Module, module_map: NIRTORCH_MAPPING, use_default_dict: bool = True
def trace_torch_graph(
    module: torch.nn.Module,
    module_map: Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]],
    default_dict: Optional[
        Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]
    ] = None,
) -> nir.NIRGraph:
    # Merge the default dictionary, if requested
    if use_default_dict:
        module_map = module_map | DEFAULT_DICT
    # Cover the edge case that the
"""
Traces a PyTorch module and converts it to a NIR graph using the specified module map.
Args:
module (torch.nn.Module): The module of interest
module_map (Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]): A dictionary that maps
a given module type to a function that can convert the model to an NIRNode type
default_dict (Optional[Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]]): An optional dictionary that maps
a given module type to a function that can convert the model to an NIRNode type. This dictionary is merged
with the module_map dictionary.
"""
    # Merge the default dictionary, if it exists
    if default_dict is not None:
        module_map = module_map | default_dict

    # Cover the edge case that the incoming module is a leaf node
    if module.__class__ in module_map:
        return module_map[module.__class__](module)

    # Trace the graph
    tracer = NIRTorchTracer(module_map.keys())
    traced = tracer.trace(module)

    graph_module = GraphModule(tracer.root, traced)
    if len(traced.nodes) == 2 and len(list(tracer.root.children())) == 0:
        raise ValueError(
            "The module is a leaf node, but does not appear in the module map. We cannot trace it further"
        )

    # transformer = NIRTorchTransformer(graph_module)
    # transformed = transformer.transform()
    # print(transformed)
    shapes = shape_prop.ShapeProp(graph_module)
    graph_module = GraphModule(tracer.root, traced)

    # Create NIR nodes
    nodes = {}
    edges = []
    ignored_nodes = set()
    skipped_nodes = set()

    def _find_users(node: Node) -> Set[Node]:
        """
        Finds all the users (outputs) of a given node, recursively if the node is registered as a skipped node
        """
        nodes = set()
        for user in node.users:
            if user in ignored_nodes:
                continue
            elif user in skipped_nodes:
                nodes |= _find_users(user)
            else:
                nodes.add(user)
        return nodes

    def _find_inputs(node: Node) -> Set[Node]:
        """
        Finds all the inputs (sources) of a given node, recursively if the node is registered as a skipped node
        """
        nodes = set()
        for in_node in node.all_input_nodes:
            if in_node in ignored_nodes:
                continue
            elif in_node in skipped_nodes:
                nodes |= _find_inputs(in_node)
            else:
                nodes.add(in_node)
        return nodes

    for node in traced.nodes:
        # Add Node
        if node.op == "placeholder":
            if node.target == "input":
                module = nir.Input(np.array([1]))
            if node.target == "input" or node.prev.op == "root":
                nodes[str(node.name)] = nir.Input(np.array([1]))
            else:
                ignored_nodes.add(node)
                continue
        elif node.op == "output":
            module = nir.Output(np.array([1]))
            nodes[str(node.name)] = nir.Output(np.array([1]))
        elif node.op == "call_function":
            # Ensure that we skip add nodes
            # TODO: Consider using transformations for this
            # https://pytorch.org/docs/stable/fx.html#torch.fx.Transformer
            if node.target == operator.add:
                skipped_nodes.add(node)
            # Raise an error if we encounter functions other than addition
            else:
                raise ValueError(
                    "The only supported function is addition. Please modify your model or raise an issue on GitHub"
                )
        elif node.op == "call_method":
            # Skip add methods
            if node.target == "add":
                skipped_nodes.add(node)
            else:
                raise ValueError(
                    "The only supported method is addition. Please modify your model or raise an issue on GitHub"
                )
        elif node.op == "call_module":
            torch_module = graph_module.get_submodule(node.target)
            nir_module = module_map[torch_module.__class__](torch_module)
            nodes[str(node.name)] = nir_module
        elif node.op == "get_attr":
            # Skip attribute
            skipped_nodes.add(node)
        else:
            module = graph_module.get_submodule(node.target)
            module = module_map[module.__class__](module)
            nodes[node.name] = module
            raise ValueError(
                f"Unsupported operation {node.op}. Please modify your model or raise an issue on GitHub"
            )

    # Create edges
    # - This is done in a separate loop to ensure that we correctly ignore the edges in case the nodes
    #   are ignored out-of-order
    for node in traced.nodes:
        if node in ignored_nodes:
            continue

        # Add edges
        for in_edge in node.all_input_nodes:
            edges.append((in_edge, node.name))
        for in_node in node.all_input_nodes:
            if in_node in ignored_nodes or in_node in skipped_nodes:
                continue
            # If the function is set to be skipped, we simply forward the input to all the outputs
            if node in skipped_nodes:
                for next_node in _find_users(node):
                    edges.append((in_node.name, next_node.name))
            # Ignore additions as incoming edges
            elif in_node.op == "call_function" and in_node.target == operator.add:
                break
            # Otherwise, add an edge
            elif in_node not in ignored_nodes:
                edges.append((in_node.name, node.name))
    graph = nir.NIRGraph(nodes=nodes, edges=edges)
    graph.infer_types()
    return graph

    # Create NIR edges
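
To illustrate the add-handling above, consider a residual-style module: the `+` is traced as a call_function node targeting operator.add, registered as skipped, and bypassed when edges are created, so its inputs connect directly to its users. A sketch (the `Residual` class is illustrative, not part of this commit):

class Residual(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2, 2)

    def forward(self, x):
        # The addition below does not become a NIR node; edges are routed
        # from both `x` and `self.lin` around the skipped add node.
        return self.lin(x) + x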



if __name__ == "__main__":
    module = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Linear(1, 1))
    graph = trace_pytorch_graph(module, DEFAULT_DICT)
    # module_map is a required argument; pass the Linear -> Affine conversion explicitly
    graph = trace_torch_graph(
        module, {torch.nn.Linear: lambda m: nir.Affine(m.weight.detach().numpy(), m.bias.detach().numpy())}
    )

    import pprint

44 changes: 43 additions & 1 deletion nirtorch/to_nir.py
@@ -1,11 +1,22 @@
import logging
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Dict, Optional, Sequence
import warnings

import nir
import numpy as np
import torch.nn as nn

from .graph import extract_torch_graph
from .graph_fx import trace_torch_graph

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/to_nir.py:10:23: `.graph_fx.trace_torch_graph` imported but unused


DEFAULT_MAPS: Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]] = {
    nn.Linear: (
        lambda module: nir.Affine(
            module.weight.detach().numpy(), module.bias.detach().numpy()
        )
    )
}
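
Additional conversions can be layered on top of these defaults by merging dictionaries; a sketch (the `ScaleByTwo` module and its 1x1 Affine stand-in are illustrative assumptions, not part of this commit):

class ScaleByTwo(nn.Module):
    def forward(self, x):
        return 2 * x

custom_maps = {
    **DEFAULT_MAPS,
    # Represent the scaling as a 1x1 affine transform with zero bias
    ScaleByTwo: lambda module: nir.Affine(np.array([[2.0]]), np.array([0.0])),
}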


def extract_nir_graph(
@@ -39,6 +50,11 @@ def extract_nir_graph(
    Returns:
        nir.NIR: Returns the generated NIR graph representation.
    """
    warnings.warn(
        "extract_nir_graph is deprecated, use trace_nir_graph instead",
        DeprecationWarning,
        stacklevel=2,
    )

    if len(list(model.children())):
        # If the model has submodules, ignore the top level module
@@ -152,3 +168,29 @@ def extract_nir_graph(
            nir_edges.remove(edge)

    return nir.NIRGraph(nir_nodes, nir_edges)


def trace_nir_graph(
    model: nn.Module,
    model_map: Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]],
    default_map: Optional[
        Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]]
    ] = DEFAULT_MAPS,
    model_name: Optional[str] = "model",
    ignore_submodules_of=None,
    model_fwd_args=[],
    ignore_dims: Optional[Sequence[int]] = None,
) -> nir.NIRNode:
    """
    Given a PyTorch `model`, we trace it and recreate a NIR graph using the specified `model_map`.

    Args:
        model (nn.Module): The model of interest
        model_map (Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]]): A dictionary that maps
            a given module type to a function that can convert the module to an NIRNode type
        default_map (Optional[Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]]]): An optional dictionary of
            default conversions, merged with `model_map`. Defaults to DEFAULT_MAPS.
        model_name (Optional[str], optional): The name of the top level module.
            Defaults to "model".
        ignore_submodules_of (Optional[Sequence[nn.Module]]): If specified,
            the corresponding modules' children will not be traversed for the graph.
        ignore_dims (Optional[Sequence[int]]): Dimensions of data to be ignored for
            type/shape inference. Typically the dimensions that you will want to ignore
            are for batch and time.
    """
(remaining diff, including the fourth changed file, did not render)
