Skip to content

Commit

Permalink
Add a pass to replace nodes with empty tensors with full.
Browse files Browse the repository at this point in the history
Differential Revision: D68907459

Pull Request resolved: #8130
  • Loading branch information
hsharma35 authored Feb 4, 2025
1 parent b02c692 commit 9b3f2ba
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 5 deletions.
10 changes: 8 additions & 2 deletions backends/cadence/aot/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
from typing import Optional, Sequence, Union

import torch
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.pass_base import (
Argument,
ExportPass,
NodeMetadata,
PassResult,
ProxyValue,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.node import Argument, Target
from torch.fx.node import Target
from torch.utils import _pytree as pytree


Expand Down
22 changes: 22 additions & 0 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,10 +2071,32 @@ def call_operator(
)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceEmptyTensorsWithFullPass(ExportPass):
"""Replaces nodes that produce empty tensors with full nodes."""

def call_operator(self, op, args, kwargs, meta):
val = meta.data.get("val", None)
if isinstance(val, torch.Tensor) and val.numel() == 0:
return super().call_operator(
exir_ops.edge.aten.full.default,
args=(val.shape, 0),
kwargs={"dtype": val.dtype},
meta=meta,
)
return super().call_operator(op, args, kwargs, meta)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
ret = super().call(graph_module)
modified = ret.graph_module.graph.eliminate_dead_code() or ret.modified
return PassResult(ret.graph_module, modified)


# This class encapsulates all the functions that replace/switch one op in the
# graph with another.
class CadenceReplaceOpsInGraph:
passes = [
ReplaceEmptyTensorsWithFullPass,
ReplaceFunctionallyEquivalentOpTargets,
ReplaceTCopyWithTransposePass,
ReplacePermuteWithTransposePass,
Expand Down
3 changes: 1 addition & 2 deletions backends/cadence/aot/tests/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def test_graph_with_single_im2row(self) -> None:
channels_last = False
im2row = builder.call_operator(
exir_ops.edge.cadence.im2row.default,
# pyre-ignore
(
x,
(2, 2),
Expand Down Expand Up @@ -80,7 +79,7 @@ def _get_inner_graph(self, x_shape: Sequence[int]) -> torch.fx.GraphModule:
x = builder.placeholder("x", torch.randn(*x_shape))
add = builder.call_operator(
exir_ops.edge.aten.add.Tensor,
(x, x), # pyre-ignore
(x, x),
)
builder.output([x, add])
gm = builder.get_graph_module()
Expand Down
60 changes: 59 additions & 1 deletion backends/cadence/aot/tests/test_replace_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import torch.nn.functional as F
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.backends.cadence.aot.graph_builder import (
GraphBuilder,
single_op_builder,
)
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.replace_ops import (
ForceChannelLastForConvPass,
Expand All @@ -18,6 +21,7 @@
ReplaceConstantPadNdWithSlicePass,
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
ReplaceConvWithIm2RowAndLinear,
ReplaceEmptyTensorsWithFullPass,
ReplaceFunctionallyEquivalentOpTargets,
ReplaceIm2RowWithViewPass,
ReplaceLinearWithFullyConnectedOpPass,
Expand Down Expand Up @@ -1681,3 +1685,57 @@ def test_cat_insert_transpose(self):
count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
3,
)


class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase):
def _get_slice_empty_gm(self) -> torch.fx.GraphModule:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4))
# This is empty (numel == 0).
slice0 = builder.call_operator(
exir_ops.edge.aten.slice_copy.Tensor, (x, 0, 0, 0)
)
# Copy of x.
slice1 = builder.call_operator(exir_ops.edge.aten.slice_copy.Tensor, (x,))
cat = builder.call_operator(
exir_ops.edge.aten.cat.default,
((slice0, slice1),),
)
builder.output([cat])
return builder.get_graph_module()

def test_empty_slice(self):
gm = self._get_slice_empty_gm()
self.assertEqual(
len(
gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
),
2,
)
self.assertEqual(
len(
gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.full.default
)
),
0,
)
updated_gm = ReplaceEmptyTensorsWithFullPass()(gm).graph_module
self.assertEqual(
len(
updated_gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
),
1,
)
self.assertEqual(
len(
updated_gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.full.default
)
),
1,
)

0 comments on commit 9b3f2ba

Please sign in to comment.