From 36c35b6e882cb082dd8d49a82a6a6f2d98b5b266 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Fri, 22 Jul 2022 16:31:00 +0800 Subject: [PATCH] [Fix] fix triu (#792) * fix triu * triu -> triu_default --- mmdeploy/backend/ncnn/init_plugins.py | 1 + mmdeploy/pytorch/functions/__init__.py | 4 +-- mmdeploy/pytorch/functions/triu.py | 12 ++++---- tests/test_pytorch/test_pytorch_functions.py | 30 ++++++++++---------- 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/mmdeploy/backend/ncnn/init_plugins.py b/mmdeploy/backend/ncnn/init_plugins.py index 519384be94..e721a3f11a 100644 --- a/mmdeploy/backend/ncnn/init_plugins.py +++ b/mmdeploy/backend/ncnn/init_plugins.py @@ -31,6 +31,7 @@ def get_onnx2ncnn_path() -> str: if onnx2ncnn_path is None or not os.path.exists(onnx2ncnn_path): onnx2ncnn_path = shutil.which('mmdeploy_onnx2ncnn') + onnx2ncnn_path = '' if onnx2ncnn_path is None else onnx2ncnn_path return onnx2ncnn_path diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index 6a2ac22858..4201942476 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -12,13 +12,13 @@ from .size import tensor__size__ncnn from .tensor_setitem import tensor__setitem__default from .topk import topk__dynamic, topk__tensorrt -from .triu import triu +from .triu import triu__default __all__ = [ 'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn', 'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn', - 'triu', 'atan2__default', 'normalize__ncnn', 'expand__ncnn', + 'triu__default', 'atan2__default', 'normalize__ncnn', 'expand__ncnn', 'chunk__torchscript', 'masked_fill__onnxruntime', 'tensor__setitem__default' ] diff --git a/mmdeploy/pytorch/functions/triu.py b/mmdeploy/pytorch/functions/triu.py index 4a9dc4da06..025b2029ff 100644 --- a/mmdeploy/pytorch/functions/triu.py +++ b/mmdeploy/pytorch/functions/triu.py @@ -5,15 +5,15 @@ @FUNCTION_REWRITER.register_rewriter(func_name='torch.triu') -def triu(ctx, - input: torch.Tensor, - diagonal: int = 0, - *args, - **kwargs) -> torch.Tensor: +def triu__default(ctx, + input: torch.Tensor, + diagonal: int = 0, + *args, + **kwargs) -> torch.Tensor: """Rewrite `triu` for exporting model to ONNX.""" assert len(input.shape) >= 2 height, width = input.shape[-2:] arange0 = torch.arange(width, device=input.device).unsqueeze(0) arange1 = torch.arange(height, device=input.device).unsqueeze(-1) - mask = arange0 >= arange1 + mask = arange0 >= torch.add(arange1, diagonal) return input * mask diff --git a/tests/test_pytorch/test_pytorch_functions.py b/tests/test_pytorch/test_pytorch_functions.py index 2508556cba..65bda0ecf4 100644 --- a/tests/test_pytorch/test_pytorch_functions.py +++ b/tests/test_pytorch/test_pytorch_functions.py @@ -239,28 +239,28 @@ def model_func(input): @backend_checker(Backend.TENSORRT) @pytest.mark.parametrize('shape', [[2, 2], [4, 2], [2, 4], [2, 4, 2]]) -def test_triu_trt(shape): +@pytest.mark.parametrize('diagonal', [0, 1, -1]) +def test_triu_trt(shape, diagonal): input = torch.rand(shape) + model_output = torch.triu(input=input, diagonal=diagonal) def triu_caller(*arg, **kwargs): return torch.triu(*arg, **kwargs) - wrapped_func = WrapFunction(triu_caller, diagonal=1) - import tempfile - - import onnx + wrapped_func = WrapFunction(triu_caller, diagonal=diagonal) + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_func, + model_inputs={'input': input}, + deploy_cfg=get_trt_config(['output'], shape=shape), + run_with_backend=True) + if is_backend_output: + rewrite_outputs = rewrite_outputs[0].detach().cpu() - from mmdeploy.core import RewriterContext - onnx_file = tempfile.NamedTemporaryFile(suffix='onnx').name - with RewriterContext( - cfg=get_trt_config('output', shape), - backend=Backend.TENSORRT.value, - opset=11), torch.no_grad(): - torch.onnx.export(wrapped_func, input, onnx_file, opset_version=11) - onnx_model = onnx.load(onnx_file) - nodes = onnx_model.graph.node - assert nodes is not None + assert np.allclose( + model_output, rewrite_outputs, rtol=1e-03, atol=1e-05) + else: + assert rewrite_outputs is not None @backend_checker(Backend.NCNN)