Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch] add multiclass_nms_with_indices layer #27

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ jobs:
python-version: ${{matrix.py_ver}}
- name: Install dependencies
run: |
pip install tensorflow==${{matrix.tf_ver}}.*
if [ ${{matrix.tf_ver}} == 2.10 ] || [ ${{matrix.tf_ver}} == 2.11 ];then
extra_req='numpy<2'
fi
pip install tensorflow==${{matrix.tf_ver}}.* $extra_req
pip install -r requirements_test.txt
pip list
- name: Run pytest
Expand Down
6 changes: 3 additions & 3 deletions docs/index.html

Large diffs are not rendered by default.

24 changes: 15 additions & 9 deletions docs/sony_custom_layers/keras.html

Large diffs are not rendered by default.

754 changes: 552 additions & 202 deletions docs/sony_custom_layers/pytorch.html

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion sony_custom_layers/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
if TYPE_CHECKING:
import onnxruntime as ort

__all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops']
__all__ = ['multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'load_custom_ops']

validate_installed_libraries(required_libraries['torch'])

from .object_detection import multiclass_nms, NMSResults # noqa: E402
from .object_detection import multiclass_nms_with_indices, NMSWithIndicesResults # noqa: E402


def load_custom_ops(load_ort: bool = False,
Expand Down
53 changes: 53 additions & 0 deletions sony_custom_layers/pytorch/custom_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -----------------------------------------------------------------------------
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
from typing import Callable

import torch

from sony_custom_layers.util.import_util import is_compatible

# Namespace of the custom torch library. Ops are qualified as '<CUSTOM_LIB_NAME>::<op_name>' (see get_op_qualname).
CUSTOM_LIB_NAME = 'sony'
# Library object through which op schemas are defined (see register_op).
# NOTE(review): "DEF" presumably creates a new library rather than extending an existing one - confirm in torch docs.
custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF")


def get_op_qualname(torch_op_name):
    """ Build the qualified name '<library>::<op>' of an op in the custom library """
    return '::'.join((CUSTOM_LIB_NAME, torch_op_name))


def register_op(torch_op_name: str, schema: str, impl: Callable):
    """
    Define a custom torch op in the custom library and attach its implementation.

    Args:
        torch_op_name: name of the op to register (without the library prefix).
        schema: schema string of the custom op.
        impl: callable implementing the custom op.

    Returns:
        Qualified name of the registered op.
    """
    qualname = get_op_qualname(torch_op_name)

    custom_lib.define(schema)

    # The registration call differs between torch versions: from 2.2 the op is addressed by its qualified name,
    # older versions are given the library object and the bare op name.
    impl_args = (qualname, 'default') if is_compatible('torch>=2.2') else (custom_lib, torch_op_name)
    torch.library.impl(*impl_args)(impl)

    return qualname
9 changes: 8 additions & 1 deletion sony_custom_layers/pytorch/object_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
# -----------------------------------------------------------------------------

from .nms import multiclass_nms, NMSResults
from .nms_with_indices import multiclass_nms_with_indices, NMSWithIndicesResults

# trigger onnx op registration
from . import nms_onnx

__all__ = ['multiclass_nms', 'NMSResults']
__all__ = [
'multiclass_nms',
'multiclass_nms_with_indices',
'NMSResults',
'NMSWithIndicesResults',
]
168 changes: 33 additions & 135 deletions sony_custom_layers/pytorch/object_detection/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
from typing import Tuple, NamedTuple, Union, Callable
from typing import NamedTuple, Callable

import numpy as np
import torch
from torch import Tensor
import torchvision # noqa: F401 # needed for torch.ops.torchvision

from sony_custom_layers.pytorch.custom_lib import register_op
from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS
from sony_custom_layers.util.import_util import is_compatible

CUSTOM_LIB_NAME = 'sony'
MULTICLASS_NMS_TORCH_OP = 'multiclass_nms'
MULTICLASS_NMS_TORCH_OP_QUALNAME = CUSTOM_LIB_NAME + '::' + MULTICLASS_NMS_TORCH_OP

__all__ = ['multiclass_nms', 'NMSResults']

Expand All @@ -36,17 +35,19 @@ class NMSResults(NamedTuple):
labels: Tensor
n_valid: Tensor

# Note: convenience methods below are replicated in each Results container, since NamedTuple supports neither adding
# new fields in derived classes nor multiple inheritance, and we want it to behave like a tuple, so no dataclasses.
def detach(self) -> 'NMSResults':
""" Detach all tensors and return a new NMSResults object """
""" Detach all tensors and return a new object """
return self.apply(lambda t: t.detach())

def cpu(self) -> 'NMSResults':
""" Move all tensors to cpu and return a new NMSResults object """
""" Move all tensors to cpu and return a new object """
return self.apply(lambda t: t.cpu())

def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults':
""" Apply any function to all tensors and return a NMSResults new object """
return NMSResults(*[f(t) for t in self])
""" Apply any function to all tensors and return a new object """
return self.__class__(*[f(t) for t in self])


def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults:
Expand All @@ -56,6 +57,8 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float,
The output tensors always contain a fixed number of detections, as defined by 'max_detections'.
If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.

If you also require the input indices of the selected boxes, see `multiclass_nms_with_indices`.

Args:
boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
(x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
Expand Down Expand Up @@ -92,32 +95,35 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float,
return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections))


custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF")
schema = (MULTICLASS_NMS_TORCH_OP +
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
"-> (Tensor, Tensor, Tensor, Tensor)")
op_name = custom_lib.define(schema)
######################
# Register custom op #
######################

if is_compatible('torch>=2.2'):
register_impl = torch.library.impl(MULTICLASS_NMS_TORCH_OP_QUALNAME, 'default')
else:
register_impl = torch.library.impl(custom_lib, MULTICLASS_NMS_TORCH_OP)

def _multiclass_nms_impl(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
                         max_detections: int) -> NMSResults:
    """ Implementation meant only for registration as a custom torch and onnxruntime op.
        The result is built as a NamedTuple for clarity; it is not preserved when run through torch / onnxruntime op. """
    detections, n_valid = _batch_multiclass_nms(boxes,
                                                scores,
                                                score_threshold=score_threshold,
                                                iou_threshold=iou_threshold,
                                                max_detections=max_detections)
    # Split the packed detections: the first 4 channels are the boxes, SCORES / LABELS index the remaining channels.
    out_boxes = detections[..., :4]
    out_scores = detections[..., SCORES]
    out_labels = detections[..., LABELS].to(torch.int64)
    return NMSResults(boxes=out_boxes, scores=out_scores, labels=out_labels, n_valid=n_valid.to(torch.int64))

@register_impl
def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
                       max_detections: int) -> NMSResults:
    """ Torch op entry point, registered as torch.ops.sony.multiclass_nms. Delegates to _multiclass_nms_impl. """
    nms_kwargs = dict(score_threshold=score_threshold, iou_threshold=iou_threshold, max_detections=max_detections)
    return _multiclass_nms_impl(boxes, scores, **nms_kwargs)

schema = (MULTICLASS_NMS_TORCH_OP +
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
"-> (Tensor, Tensor, Tensor, Tensor)")

op_qualname = register_op(MULTICLASS_NMS_TORCH_OP, schema, _multiclass_nms_impl)

if is_compatible('torch>=2.2'):

@torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP_QUALNAME)
@torch.library.impl_abstract(op_qualname)
def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers torch op's abstract implementation. It specifies the properties of the output tensors.
Expand All @@ -130,111 +136,3 @@ def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_thresh
torch.empty((batch, max_detections), dtype=torch.int64),
torch.empty((batch, 1), dtype=torch.int64)
) # yapf: disable


def _multiclass_nms_impl(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float,
                         iou_threshold: float, max_detections: int) -> NMSResults:
    """
    Validate the inputs and run multiclass nms image by image. See multiclass_nms for the full description.

    Raises:
        ValueError: If a threshold is outside [0, 1], max_detections is not positive, or the input shapes are invalid.
    """
    # The onnxruntime implementation needs this: anything that is not already a Tensor is converted.
    boxes = boxes if isinstance(boxes, Tensor) else Tensor(boxes)
    scores = scores if isinstance(scores, Tensor) else Tensor(scores)

    # Scalar arguments.
    if not (0 <= score_threshold <= 1):
        raise ValueError(f'Invalid score_threshold {score_threshold} not in range [0, 1]')
    if not (0 <= iou_threshold <= 1):
        raise ValueError(f'Invalid iou_threshold {iou_threshold} not in range [0, 1]')
    if max_detections <= 0:
        raise ValueError(f'Invalid non-positive max_detections {max_detections}')

    # Tensor shapes.
    if boxes.ndim != 3 or boxes.shape[-1] != 4:
        raise ValueError(f'Invalid input boxes shape {boxes.shape}. Expected shape (batch, n_boxes, 4).')
    if scores.ndim != 3:
        raise ValueError(f'Invalid input scores shape {scores.shape}. Expected shape (batch, n_boxes, n_classes).')
    if boxes.shape[-2] != scores.shape[-2]:
        raise ValueError(f'Mismatch in the number of boxes between input boxes ({boxes.shape[-2]}) '
                         f'and scores ({scores.shape[-2]})')

    n_images = boxes.shape[0]
    device = boxes.device
    # Per image: max_detections rows of (4 box coordinates, score, label), plus the count of valid rows.
    packed = torch.zeros((n_images, max_detections, 6), device=device)
    n_valid = torch.zeros((n_images, 1), device=device)
    for img in range(n_images):
        packed[img], n_valid[img] = _image_multiclass_nms(boxes[img],
                                                          scores[img],
                                                          score_threshold=score_threshold,
                                                          iou_threshold=iou_threshold,
                                                          max_detections=max_detections)

    out_boxes = packed[..., :4]
    out_scores = packed[..., 4]
    out_labels = packed[..., 5].to(torch.int64)
    return NMSResults(boxes=out_boxes, scores=out_scores, labels=out_labels, n_valid=n_valid.to(torch.int64))


def _image_multiclass_nms(boxes: Tensor, scores: Tensor, score_threshold: float, iou_threshold: float,
                          max_detections: int) -> Tuple[Tensor, int]:
    """
    Multi-class non-maximum suppression for a single image.

    Args:
        boxes: input boxes of shape [n_boxes, 4]
        scores: input scores of shape [n_boxes, n_classes]
        score_threshold: score threshold
        iou_threshold: intersection over union threshold
        max_detections: fixed number of detections to return

    Returns:
        A zero-padded tensor of shape [max_detections, 6] and the number of valid detections.
        out[:, :4] holds the selected boxes
        out[:, 4] and out[:, 5] hold the scores and labels of the selected boxes
    """
    candidates = _convert_inputs(boxes, scores, score_threshold)
    detections = torch.zeros(max_detections, 6, device=boxes.device)
    # No candidate passed the score threshold, so the output stays all zeros.
    if candidates.size(0) == 0:
        return detections, 0
    # Keep at most max_detections of the candidates selected by nms.
    selected = _nms_with_class_offsets(candidates, iou_threshold=iou_threshold)[:max_detections]
    n_valid = selected.numel()
    detections[:n_valid] = candidates[selected]
    return detections, n_valid


def _convert_inputs(boxes: Tensor, scores: Tensor, score_threshold: float) -> Tensor:
"""
Converts inputs and filters out boxes with score below the threshold.
Args:
boxes: input boxes of shape [n_boxes, 4]
scores: input scores of shape [n_boxes, n_classes]
score_threshold: score threshold for nms candidates

Returns:
A tensor of shape [m, 6] containing m nms candidates above the score threshold.
x[:, :4] contains the boxes with replication for different labels
x[:, 4] contains the scores
x[:, 5] contains the labels indices (label i corresponds to input scores[:, i])
"""
n_boxes, n_classes = scores.shape
scores_mask = scores > score_threshold
box_indices = torch.arange(n_boxes, device=boxes.device).unsqueeze(1).expand(-1, n_classes)[scores_mask]
x = torch.empty((box_indices.numel(), 6), device=boxes.device)
x[:, :4] = boxes[box_indices]
x[:, 4] = scores[scores_mask]
x[:, 5] = torch.arange(n_classes, device=boxes.device).unsqueeze(0).expand(n_boxes, -1)[scores_mask]
return x


def _nms_with_class_offsets(x: Tensor, iou_threshold: float) -> Tensor:
    """
    Run multi-class nms through a single call to single-class nms.

    Args:
        x: nms candidates of shape [n, 6] ([:,:4] boxes, [:, 4] scores, [:, 5] labels)
        iou_threshold: intersection over union threshold

    Returns:
        Indices of the selected candidates
    """
    coords, cand_scores, labels = x[:, :4], x[:, 4], x[:, 5:]
    # Shift the boxes of each class by a class-dependent offset so that boxes of different classes do not
    # intersect; single-class nms can then be applied to all of them at once (similar to the torchvision
    # batched_nms trick).
    class_stride = coords.max() + 1
    shifted_boxes = coords + labels * class_stride
    return torch.ops.torchvision.nms(shifted_boxes, cand_scores, iou_threshold)
Loading