Skip to content

Commit

Permalink
[Fix] fix triu (#792)
Browse files Browse the repository at this point in the history
* fix triu

* triu -> triu_default
  • Loading branch information
AllentDan authored Jul 22, 2022
1 parent 36b3ca4 commit 36c35b6
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 23 deletions.
1 change: 1 addition & 0 deletions mmdeploy/backend/ncnn/init_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_onnx2ncnn_path() -> str:

if onnx2ncnn_path is None or not os.path.exists(onnx2ncnn_path):
onnx2ncnn_path = shutil.which('mmdeploy_onnx2ncnn')
onnx2ncnn_path = '' if onnx2ncnn_path is None else onnx2ncnn_path

return onnx2ncnn_path

Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/pytorch/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from .size import tensor__size__ncnn
from .tensor_setitem import tensor__setitem__default
from .topk import topk__dynamic, topk__tensorrt
from .triu import triu
from .triu import triu__default

__all__ = [
'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn',
'triu', 'atan2__default', 'normalize__ncnn', 'expand__ncnn',
'triu__default', 'atan2__default', 'normalize__ncnn', 'expand__ncnn',
'chunk__torchscript', 'masked_fill__onnxruntime',
'tensor__setitem__default'
]
12 changes: 6 additions & 6 deletions mmdeploy/pytorch/functions/triu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@


@FUNCTION_REWRITER.register_rewriter(func_name='torch.triu')
def triu(ctx,
input: torch.Tensor,
diagonal: int = 0,
*args,
**kwargs) -> torch.Tensor:
def triu__default(ctx,
input: torch.Tensor,
diagonal: int = 0,
*args,
**kwargs) -> torch.Tensor:
"""Rewrite `triu` for exporting model to ONNX."""
assert len(input.shape) >= 2
height, width = input.shape[-2:]
arange0 = torch.arange(width, device=input.device).unsqueeze(0)
arange1 = torch.arange(height, device=input.device).unsqueeze(-1)
mask = arange0 >= arange1
mask = arange0 >= torch.add(arange1, diagonal)
return input * mask
30 changes: 15 additions & 15 deletions tests/test_pytorch/test_pytorch_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,28 +239,28 @@ def model_func(input):

@backend_checker(Backend.TENSORRT)
@pytest.mark.parametrize('shape', [[2, 2], [4, 2], [2, 4], [2, 4, 2]])
def test_triu_trt(shape):
@pytest.mark.parametrize('diagonal', [0, 1, -1])
def test_triu_trt(shape, diagonal):

input = torch.rand(shape)
model_output = torch.triu(input=input, diagonal=diagonal)

def triu_caller(*arg, **kwargs):
return torch.triu(*arg, **kwargs)

wrapped_func = WrapFunction(triu_caller, diagonal=1)
import tempfile

import onnx
wrapped_func = WrapFunction(triu_caller, diagonal=diagonal)
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_func,
model_inputs={'input': input},
deploy_cfg=get_trt_config(['output'], shape=shape),
run_with_backend=True)
if is_backend_output:
rewrite_outputs = rewrite_outputs[0].detach().cpu()

from mmdeploy.core import RewriterContext
onnx_file = tempfile.NamedTemporaryFile(suffix='onnx').name
with RewriterContext(
cfg=get_trt_config('output', shape),
backend=Backend.TENSORRT.value,
opset=11), torch.no_grad():
torch.onnx.export(wrapped_func, input, onnx_file, opset_version=11)
onnx_model = onnx.load(onnx_file)
nodes = onnx_model.graph.node
assert nodes is not None
assert np.allclose(
model_output, rewrite_outputs, rtol=1e-03, atol=1e-05)
else:
assert rewrite_outputs is not None


@backend_checker(Backend.NCNN)
Expand Down

0 comments on commit 36c35b6

Please sign in to comment.