[Feature] support MMRotate model with le135 (#788)
* support MMRotate model with le135

* cse before fuse select assign

* remove unused import
q.yao authored Jul 25, 2022
1 parent 5b31d7a commit 0e1a3aa
Showing 3 changed files with 137 additions and 2 deletions.
@@ -3,6 +3,7 @@
 #include <torch/csrc/jit/passes/dead_code_elimination.h>

 #include "../../ir/subgraph_matcher.h"
+#include "common_subgraph_elimination.h"
 #include "torch/csrc/jit/ir/irparser.h"

 namespace mmdeploy {
@@ -126,14 +127,16 @@ void FuseSelectAssign(Block* block, std::unordered_map<std::string, Tensor>& par

 void FuseSelectAssign(std::shared_ptr<Graph>& graph,
                       std::unordered_map<std::string, Tensor>& params) {
+  // cse before search
+  CommonSubgraphElimination(graph, params);
+
   std::string pattern_str = R"IR(
-    graph(%y, %z, %cmp_1, %cmp_2, %start, %axes):
+    graph(%y, %z, %cmp_1, %cmp_2, %start, %axes, %shape_2):
       %nz_1 = onnx::NonZero(%cmp_1)
       %trans_1 = onnx::Transpose(%nz_1)
       %gather_1 = onnx::GatherND(%z, %trans_1)
       %reshape_1_shape = onnx::Constant()
       %reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape)
-      %shape_2 = onnx::Shape(%y)
       %expand_2 = onnx::Expand(%cmp_2, %shape_2)
       %nz_2 = onnx::NonZero(%expand_2)
       %trans_2 = onnx::Transpose(%nz_2)
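For context, the pattern matched above presumably corresponds to a masked copy of the form y[mask] = z[mask], which is exported to ONNX as a NonZero/Transpose/GatherND/Expand subgraph; running CommonSubgraphElimination first merges duplicated nodes (such as repeated onnx::Shape of the same tensor) so the shared %shape_2 value can be passed into the pattern as a single input. A minimal PyTorch sketch of the equivalence such a fusion relies on, with illustrative names y, z and mask (not taken from the pass):

import torch

def select_assign(y, z, mask):
    # index form, as exported to ONNX: gather the masked values of z,
    # then write them into a copy of y
    out = y.clone()
    out[mask] = z[mask]
    return out

def fused(y, z, mask):
    # fused form: a single element-wise select
    return torch.where(mask, z, y)

y, z = torch.rand(4, 5), torch.rand(4, 5)
mask = y > 0.5
assert torch.equal(select_assign(y, z, mask), fused(y, z, mask))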
62 changes: 62 additions & 0 deletions mmdeploy/codebase/mmrotate/core/bbox/transforms.py
@@ -38,3 +38,65 @@ def poly2obb_le90__tensorrt(ctx, polys: torch.Tensor) -> torch.Tensor:
width, _ = torch.max(edges, 1)
height, _ = torch.min(edges, 1)
return torch.stack([x_ctr, y_ctr, width, height, angles], 1)


@FUNCTION_REWRITER.register_rewriter(
func_name='mmrotate.core.bbox.transforms.poly2obb_le135')
def poly2obb_le135__default(ctx, polys):
"""This is a rewrite for poly2obb to remove NonZero ops.
Args:
polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]
Returns:
obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
"""
polys = torch.reshape(polys, [-1, 8])
pt1, pt2, pt3, pt4 = polys[..., :8].chunk(4, 1)
edge1 = torch.sqrt(
torch.pow(pt1[..., 0] - pt2[..., 0], 2) +
torch.pow(pt1[..., 1] - pt2[..., 1], 2))
edge2 = torch.sqrt(
torch.pow(pt2[..., 0] - pt3[..., 0], 2) +
torch.pow(pt2[..., 1] - pt3[..., 1], 2))
angles1 = torch.atan2((pt2[..., 1] - pt1[..., 1]),
(pt2[..., 0] - pt1[..., 0]))
angles2 = torch.atan2((pt4[..., 1] - pt1[..., 1]),
(pt4[..., 0] - pt1[..., 0]))
angles = torch.where(edge1 > edge2, angles1, angles2)
angles = norm_angle(angles, 'le135')
x_ctr = (pt1[..., 0] + pt3[..., 0]) / 2.0
y_ctr = (pt1[..., 1] + pt3[..., 1]) / 2.0
edges = torch.stack([edge1, edge2], dim=1)
width, _ = torch.max(edges, 1)
height, _ = torch.min(edges, 1)
return torch.stack([x_ctr, y_ctr, width, height, angles], 1)
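For reference, a minimal usage sketch of the le135 convention this rewrite targets, assuming mmrotate is installed (outside an export context the original mmrotate implementation runs and should yield the same values for this input):

import torch
from mmrotate.core.bbox.transforms import poly2obb_le135

# axis-aligned 4x2 rectangle as [x0, y0, x1, y1, x2, y2, x3, y3]
polys = torch.tensor([[0., 0., 4., 0., 4., 2., 0., 2.]])
print(poly2obb_le135(polys))
# tensor([[2., 1., 4., 2., 0.]]): x_ctr, y_ctr, width (long side), height,
# and an angle normalized to [-pi/4, 3*pi/4) under the le135 definition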


@FUNCTION_REWRITER.register_rewriter(
func_name='mmrotate.core.bbox.transforms.obb2poly_le135')
def obb2poly_le135__default(ctx, rboxes):
"""Support batched input.
Args:
ctx : context of rewriter
obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
Returns:
polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]
"""
B, N = rboxes.shape[:2]
x_ctr, y_ctr, width, height, angle = rboxes[..., 0], rboxes[
..., 1], rboxes[..., 2], rboxes[..., 3], rboxes[..., 4]
tl_x, tl_y, br_x, br_y = \
-width * 0.5, -height * 0.5, \
width * 0.5, height * 0.5
rects = torch.stack([tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y],
dim=-1).reshape(B, N, 2, 4)
sin, cos = torch.sin(angle), torch.cos(angle)
M = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(B, N, 2, 2)
polys = M.matmul(rects).permute(0, 1, 3, 2)
xy_ctr = torch.stack([x_ctr, y_ctr], dim=-1).unsqueeze(-2)
polys += xy_ctr
polys = polys.reshape(B, N, 8)
return polys.contiguous()
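A self-contained sketch of the batched geometry above, reproducing the rewrite's math directly on a hypothetical 4x2 box rather than going through the mmdeploy rewriter:

import torch

rboxes = torch.tensor([[[2., 1., 4., 2., 0.]]])  # (B=1, N=1, 5)
B, N = rboxes.shape[:2]
x, y, w, h, a = rboxes.unbind(-1)
# corner offsets from the centre, stacked as the rewrite does: x row, then y row
corners = torch.stack([-w / 2, w / 2, w / 2, -w / 2,
                       -h / 2, -h / 2, h / 2, h / 2], -1).reshape(B, N, 2, 4)
M = torch.stack([torch.cos(a), -torch.sin(a),
                 torch.sin(a), torch.cos(a)], -1).reshape(B, N, 2, 2)
polys = (M @ corners).permute(0, 1, 3, 2) + torch.stack([x, y], -1).unsqueeze(-2)
print(polys.reshape(B, N, 8))
# tensor([[[0., 0., 4., 0., 4., 2., 0., 2.]]])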
70 changes: 70 additions & 0 deletions tests/test_codebase/test_mmrotate/test_mmrotate_core.py
@@ -312,6 +312,76 @@ def poly2obb_le90(*args, **kwargs):
assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_poly2obb_le135(backend_type: Backend):
check_backend(backend_type)
polys = torch.rand(1, 10, 8)
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(output_names=None, input_shape=None),
backend_config=dict(
type=backend_type.value,
model_inputs=[
dict(
input_shapes=dict(
polys=dict(
min_shape=polys.shape,
opt_shape=polys.shape,
max_shape=polys.shape)))
]),
codebase_config=dict(type='mmrotate', task='RotatedDetection')))

# wrap function to enable rewrite
def poly2obb_le135(*args, **kwargs):
import mmrotate
return mmrotate.core.bbox.transforms.poly2obb_le135(*args, **kwargs)

# wrap function to nn.Module, enable torch.onnx.export
wrapped_func = WrapFunction(poly2obb_le135)
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_func,
model_inputs={'polys': polys},
deploy_cfg=deploy_cfg,
run_with_backend=False)

assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_obb2poly_le135(backend_type: Backend):
check_backend(backend_type)
rboxes = torch.rand(1, 10, 5)
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(output_names=None, input_shape=None),
backend_config=dict(
type=backend_type.value,
model_inputs=[
dict(
input_shapes=dict(
rboxes=dict(
min_shape=rboxes.shape,
opt_shape=rboxes.shape,
max_shape=rboxes.shape)))
]),
codebase_config=dict(type='mmrotate', task='RotatedDetection')))

# wrap function to enable rewrite
def obb2poly_le135(*args, **kwargs):
import mmrotate
return mmrotate.core.bbox.transforms.obb2poly_le135(*args, **kwargs)

# wrap function to nn.Module, enable torch.onnx.export
wrapped_func = WrapFunction(obb2poly_le135)
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_func,
model_inputs={'rboxes': rboxes},
deploy_cfg=deploy_cfg,
run_with_backend=False)

assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_gvfixcoder__decode(backend_type: Backend):
check_backend(backend_type)
