Skip to content

Commit

Permalink
修改: src/tvm_book/transforms/yolo.py
Browse files Browse the repository at this point in the history
  • Loading branch information
liuxinwei committed Jul 8, 2024
1 parent 1fa2863 commit cb7230a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 32 deletions.
10 changes: 4 additions & 6 deletions doc/topics/ultralytics/intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1457,17 +1457,15 @@
" yield {ENV[\"input_name\"]: preprocessing(path, **ENV)[1]}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": []
"source": [
"from tvm.relay.dataflow_pattern import rewrite\n",
"from tvm_book.transforms.yolo import Dist2xywhSimplify"
]
},
{
"cell_type": "code",
Expand Down
26 changes: 0 additions & 26 deletions src/tvm_book/transforms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,6 @@
@tvm.relay.transform.function_pass(opt_level=1)
class FuseTransform:
"""替换融合函数为全局函数
.. code-block:: python
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
"""Transform distance(ltrb) to box(xywh or xyxy)."""
lt, rb = distance.chunk(2, dim)
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
if xywh:
c_xy = (x1y1 + x2y2) / 2
wh = x2y2 - x1y1
return torch.cat((c_xy, wh), dim) # xywh bbox
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
``dist2bbox(distance, anchor_points, xywh=True, dim=-1)`` 等价于:
.. code-block:: python
def dist2bbox2(distance, anchor_points, xywh=True, dim=-1):
"""Transform distance(ltrb) to box(xywh or xyxy)."""
lt, rb = distance.chunk(2, dim)
if xywh:
wh = rb - lt
c_xy = wh * 0.5 + anchor_points
return torch.cat((c_xy, wh), dim) # xywh bbox
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
"""
def __init__(self):
self.reset()
Expand Down
27 changes: 27 additions & 0 deletions src/tvm_book/transforms/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,33 @@
from tvm.relay import transform as _transform

class Dist2xywhSimplify(DFPatternCallback):
"""
.. code-block:: python
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
'Transform distance(ltrb) to box(xywh or xyxy).'
lt, rb = distance.chunk(2, dim)
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
if xywh:
c_xy = (x1y1 + x2y2) / 2
wh = x2y2 - x1y1
return torch.cat((c_xy, wh), dim) # xywh bbox
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
等价于:
.. code-block:: python
def dist2bbox2(distance, anchor_points, xywh=True, dim=-1):
'Transform distance(ltrb) to box(xywh or xyxy).'
lt, rb = distance.chunk(2, dim)
if xywh:
wh = rb - lt
c_xy = wh * 0.5 + anchor_points
return torch.cat((c_xy, wh), dim) # xywh bbox
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
"""
def __init__(self):
super().__init__()
self.x = wildcard()
Expand Down

0 comments on commit cb7230a

Please sign in to comment.