From 6e4e70261347fe73ee9f3fb38f3bfab8834c7391 Mon Sep 17 00:00:00 2001 From: irenab Date: Sun, 8 Sep 2024 18:34:37 +0300 Subject: [PATCH 1/3] [torch] add multiclass_nms_with_indices op --- sony_custom_layers/pytorch/__init__.py | 3 +- sony_custom_layers/pytorch/custom_lib.py | 53 +++ .../pytorch/object_detection/__init__.py | 9 +- .../pytorch/object_detection/nms.py | 166 ++-------- .../pytorch/object_detection/nms_common.py | 153 +++++++++ .../pytorch/object_detection/nms_onnx.py | 33 +- .../pytorch/object_detection/nms_ort.py | 18 +- .../object_detection/nms_with_indices.py | 141 ++++++++ .../object_detection/test_multiclass_nms.py | 304 +++++++----------- .../tests/object_detection/test_nms_common.py | 178 ++++++++++ 10 files changed, 732 insertions(+), 326 deletions(-) create mode 100644 sony_custom_layers/pytorch/custom_lib.py create mode 100644 sony_custom_layers/pytorch/object_detection/nms_common.py create mode 100644 sony_custom_layers/pytorch/object_detection/nms_with_indices.py create mode 100644 sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py diff --git a/sony_custom_layers/pytorch/__init__.py b/sony_custom_layers/pytorch/__init__.py index 671cd6c..2450fe7 100644 --- a/sony_custom_layers/pytorch/__init__.py +++ b/sony_custom_layers/pytorch/__init__.py @@ -21,11 +21,12 @@ if TYPE_CHECKING: import onnxruntime as ort -__all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops'] +__all__ = ['multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'load_custom_ops'] validate_installed_libraries(required_libraries['torch']) from .object_detection import multiclass_nms, NMSResults # noqa: E402 +from .object_detection import multiclass_nms_with_indices, NMSWithIndicesResults # noqa: E402 def load_custom_ops(load_ort: bool = False, diff --git a/sony_custom_layers/pytorch/custom_lib.py b/sony_custom_layers/pytorch/custom_lib.py new file mode 100644 index 0000000..9d1ef37 --- /dev/null +++ 
b/sony_custom_layers/pytorch/custom_lib.py @@ -0,0 +1,53 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------- +from typing import Callable + +import torch + +from sony_custom_layers.util.import_util import is_compatible + +CUSTOM_LIB_NAME = 'sony' +custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF") + + +def get_op_qualname(torch_op_name): + """ Op qualified name """ + return CUSTOM_LIB_NAME + '::' + torch_op_name + + +def register_op(torch_op_name: str, schema: str, impl: Callable): + """ + Register torch custom op under the custom library. + + Args: + torch_op_name: op name to register. + schema: schema for the custom op. + impl: implementation of the custom op. + + Returns: + Custom op qualified name. 
+ """ + torch_op_qualname = get_op_qualname(torch_op_name) + + custom_lib.define(schema) + + if is_compatible('torch>=2.2'): + register_impl = torch.library.impl(torch_op_qualname, 'default') + else: + register_impl = torch.library.impl(custom_lib, torch_op_name) + register_impl(impl) + + return torch_op_qualname diff --git a/sony_custom_layers/pytorch/object_detection/__init__.py b/sony_custom_layers/pytorch/object_detection/__init__.py index df24e21..f7af0c5 100644 --- a/sony_custom_layers/pytorch/object_detection/__init__.py +++ b/sony_custom_layers/pytorch/object_detection/__init__.py @@ -15,7 +15,14 @@ # ----------------------------------------------------------------------------- from .nms import multiclass_nms, NMSResults +from .nms_with_indices import multiclass_nms_with_indices, NMSWithIndicesResults + # trigger onnx op registration from . import nms_onnx -__all__ = ['multiclass_nms', 'NMSResults'] +__all__ = [ + 'multiclass_nms', + 'multiclass_nms_with_indices', + 'NMSResults', + 'NMSWithIndicesResults', +] diff --git a/sony_custom_layers/pytorch/object_detection/nms.py b/sony_custom_layers/pytorch/object_detection/nms.py index e68885b..ab42587 100644 --- a/sony_custom_layers/pytorch/object_detection/nms.py +++ b/sony_custom_layers/pytorch/object_detection/nms.py @@ -13,18 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ----------------------------------------------------------------------------- -from typing import Tuple, NamedTuple, Union, Callable +from typing import NamedTuple, Callable -import numpy as np import torch from torch import Tensor import torchvision # noqa: F401 # needed for torch.ops.torchvision +from sony_custom_layers.pytorch.custom_lib import register_op +from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS from sony_custom_layers.util.import_util import is_compatible -CUSTOM_LIB_NAME = 'sony' MULTICLASS_NMS_TORCH_OP = 'multiclass_nms' -MULTICLASS_NMS_TORCH_OP_QUALNAME = CUSTOM_LIB_NAME + '::' + MULTICLASS_NMS_TORCH_OP __all__ = ['multiclass_nms', 'NMSResults'] @@ -36,17 +35,19 @@ class NMSResults(NamedTuple): labels: Tensor n_valid: Tensor + # Note: convenience methods below are replicated in each Results container, since NamedTuple supports neither adding + # new fields in derived classes nor multiple inheritance, and we want it to behave like a tuple, so no dataclasses. 
def detach(self) -> 'NMSResults': - """ Detach all tensors and return a new NMSResults object """ + """ Detach all tensors and return a new object """ return self.apply(lambda t: t.detach()) def cpu(self) -> 'NMSResults': - """ Move all tensors to cpu and return a new NMSResults object """ + """ Move all tensors to cpu and return a new object """ return self.apply(lambda t: t.cpu()) def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults': - """ Apply any function to all tensors and return a NMSResults new object """ - return NMSResults(*[f(t) for t in self]) + """ Apply any function to all tensors and return a new object """ + return self.__class__(*[f(t) for t in self]) def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults: @@ -92,32 +93,35 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections)) -custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF") -schema = (MULTICLASS_NMS_TORCH_OP + - "(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) " - "-> (Tensor, Tensor, Tensor, Tensor)") -op_name = custom_lib.define(schema) +###################### +# Register custom op # +###################### -if is_compatible('torch>=2.2'): - register_impl = torch.library.impl(MULTICLASS_NMS_TORCH_OP_QUALNAME, 'default') -else: - register_impl = torch.library.impl(custom_lib, MULTICLASS_NMS_TORCH_OP) + +def _multiclass_nms_impl(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float, + max_detections: int) -> NMSResults: + """ This implementation is intended only to be registered as custom torch and onnxruntime op. + NamedTuple is used for clarity, it is not preserved when run through torch / onnxruntime op. 
""" + res, valid_dets = _batch_multiclass_nms(boxes, + scores, + score_threshold=score_threshold, + iou_threshold=iou_threshold, + max_detections=max_detections) + return NMSResults(boxes=res[..., :4], + scores=res[..., SCORES], + labels=res[..., LABELS].to(torch.int64), + n_valid=valid_dets.to(torch.int64)) -@register_impl -def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float, - max_detections: int) -> NMSResults: - """ Registers the torch op as torch.ops.sony.multiclass_nms """ - return _multiclass_nms_impl(boxes, - scores, - score_threshold=score_threshold, - iou_threshold=iou_threshold, - max_detections=max_detections) +schema = (MULTICLASS_NMS_TORCH_OP + + "(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) " + "-> (Tensor, Tensor, Tensor, Tensor)") +op_qualname = register_op(MULTICLASS_NMS_TORCH_OP, schema, _multiclass_nms_impl) if is_compatible('torch>=2.2'): - @torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP_QUALNAME) + @torch.library.impl_abstract(op_qualname) def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults: """ Registers torch op's abstract implementation. It specifies the properties of the output tensors. 
@@ -130,111 +134,3 @@ def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_thresh torch.empty((batch, max_detections), dtype=torch.int64), torch.empty((batch, 1), dtype=torch.int64) ) # yapf: disable - - -def _multiclass_nms_impl(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float, - iou_threshold: float, max_detections: int) -> NMSResults: - """ See multiclass_nms """ - # this is needed for onnxruntime implementation - if not isinstance(boxes, Tensor): - boxes = Tensor(boxes) - if not isinstance(scores, Tensor): - scores = Tensor(scores) - - if not 0 <= score_threshold <= 1: - raise ValueError(f'Invalid score_threshold {score_threshold} not in range [0, 1]') - if not 0 <= iou_threshold <= 1: - raise ValueError(f'Invalid iou_threshold {iou_threshold} not in range [0, 1]') - if max_detections <= 0: - raise ValueError(f'Invalid non-positive max_detections {max_detections}') - - if boxes.ndim != 3 or boxes.shape[-1] != 4: - raise ValueError(f'Invalid input boxes shape {boxes.shape}. Expected shape (batch, n_boxes, 4).') - if scores.ndim != 3: - raise ValueError(f'Invalid input scores shape {scores.shape}. 
Expected shape (batch, n_boxes, n_classes).') - if boxes.shape[-2] != scores.shape[-2]: - raise ValueError(f'Mismatch in the number of boxes between input boxes ({boxes.shape[-2]}) ' - f'and scores ({scores.shape[-2]})') - - batch = boxes.shape[0] - res = torch.zeros((batch, max_detections, 6), device=boxes.device) - valid_dets = torch.zeros((batch, 1), device=boxes.device) - for i in range(batch): - res[i], valid_dets[i] = _image_multiclass_nms(boxes[i], - scores[i], - score_threshold=score_threshold, - iou_threshold=iou_threshold, - max_detections=max_detections) - - return NMSResults(boxes=res[..., :4], - scores=res[..., 4], - labels=res[..., 5].to(torch.int64), - n_valid=valid_dets.to(torch.int64)) - - -def _image_multiclass_nms(boxes: Tensor, scores: Tensor, score_threshold: float, iou_threshold: float, - max_detections: int) -> Tuple[Tensor, int]: - """ - Performs multi-class non-maximum suppression on a single image - Args: - boxes: input boxes of shape [n_boxes, 4] - scores: input scores of shape [n_boxes, n_classes] - score_threshold: score threshold - iou_threshold: intersection over union threshold - max_detections: fixed number of detections to return - - Returns: - A tensor of shape [max_detections, 6] and the number of valid detections. - out[:, :4] contains the selected boxes - out[:, 4] and out[:, 5] contain the scores and labels for the selected boxes - - """ - x = _convert_inputs(boxes, scores, score_threshold) - out = torch.zeros(max_detections, 6, device=boxes.device) - if x.size(0) == 0: - return out, 0 - idxs = _nms_with_class_offsets(x, iou_threshold=iou_threshold) - idxs = idxs[:max_detections] - valid_dets = idxs.numel() - out[:valid_dets] = x[idxs] - return out, valid_dets - - -def _convert_inputs(boxes: Tensor, scores: Tensor, score_threshold: float) -> Tensor: - """ - Converts inputs and filters out boxes with score below the threshold. 
- Args: - boxes: input boxes of shape [n_boxes, 4] - scores: input scores of shape [n_boxes, n_classes] - score_threshold: score threshold for nms candidates - - Returns: - A tensor of shape [m, 6] containing m nms candidates above the score threshold. - x[:, :4] contains the boxes with replication for different labels - x[:, 4] contains the scores - x[:, 5] contains the labels indices (label i corresponds to input scores[:, i]) - """ - n_boxes, n_classes = scores.shape - scores_mask = scores > score_threshold - box_indices = torch.arange(n_boxes, device=boxes.device).unsqueeze(1).expand(-1, n_classes)[scores_mask] - x = torch.empty((box_indices.numel(), 6), device=boxes.device) - x[:, :4] = boxes[box_indices] - x[:, 4] = scores[scores_mask] - x[:, 5] = torch.arange(n_classes, device=boxes.device).unsqueeze(0).expand(n_boxes, -1)[scores_mask] - return x - - -def _nms_with_class_offsets(x: Tensor, iou_threshold: float) -> Tensor: - """ - Args: - x: nms candidates of shape [n, 6] ([:,:4] boxes, [:, 4] scores, [:, 5] labels) - iou_threshold: intersection over union threshold - - Returns: - Indices of the selected candidates - """ - # shift boxes of each class to prevent intersection between boxes of different classes, and use single-class nms - # (similar to torchvision batched_nms trick) - offsets = x[:, 5:] * (x[:, :4].max() + 1) - shifted_boxes = x[:, :4] + offsets - return torch.ops.torchvision.nms(shifted_boxes, x[:, 4], iou_threshold) diff --git a/sony_custom_layers/pytorch/object_detection/nms_common.py b/sony_custom_layers/pytorch/object_detection/nms_common.py new file mode 100644 index 0000000..ef45496 --- /dev/null +++ b/sony_custom_layers/pytorch/object_detection/nms_common.py @@ -0,0 +1,153 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------- +from typing import Union, Tuple + +import numpy as np +import torch +from torch import Tensor + +SCORES = 4 +LABELS = 5 +INDICES = 6 + + +def _batch_multiclass_nms(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float, + iou_threshold: float, max_detections: int) -> Tuple[Tensor, Tensor]: + """ + Performs multi-class non-maximum suppression on a batch of images + + Args: + boxes: input boxes of shape [batch, n_boxes, 4] + scores: input scores of shape [batch, n_boxes, n_classes] + score_threshold: score threshold + iou_threshold: intersection over union threshold + max_detections: fixed number of detections to return + + Returns: + A tuple of two tensors: + - results: A tensor of shape [batch, max_detections, 7] containing the results of multiclass nms. + - valid_dets: A tensor of shape [batch, 1] containing the number of valid detections. 
+ + """ + # this is needed for onnxruntime implementation + if not isinstance(boxes, Tensor): + boxes = Tensor(boxes) + if not isinstance(scores, Tensor): + scores = Tensor(scores) + + if not 0 <= score_threshold <= 1: + raise ValueError(f'Invalid score_threshold {score_threshold} not in range [0, 1]') + if not 0 <= iou_threshold <= 1: + raise ValueError(f'Invalid iou_threshold {iou_threshold} not in range [0, 1]') + if max_detections <= 0: + raise ValueError(f'Invalid non-positive max_detections {max_detections}') + + if boxes.ndim != 3 or boxes.shape[-1] != 4: + raise ValueError(f'Invalid input boxes shape {boxes.shape}. Expected shape (batch, n_boxes, 4).') + if scores.ndim != 3: + raise ValueError(f'Invalid input scores shape {scores.shape}. Expected shape (batch, n_boxes, n_classes).') + if boxes.shape[-2] != scores.shape[-2]: + raise ValueError(f'Mismatch in the number of boxes between input boxes ({boxes.shape[-2]}) ' + f'and scores ({scores.shape[-2]})') + + batch = boxes.shape[0] + results = torch.zeros((batch, max_detections, 7), device=boxes.device) + valid_dets = torch.zeros((batch, 1), device=boxes.device) + for i in range(batch): + results[i], valid_dets[i] = _image_multiclass_nms(boxes[i], + scores[i], + score_threshold=score_threshold, + iou_threshold=iou_threshold, + max_detections=max_detections) + + return results, valid_dets + + +def _image_multiclass_nms(boxes: Tensor, scores: Tensor, score_threshold: float, iou_threshold: float, + max_detections: int) -> Tuple[Tensor, int]: + """ + Performs multi-class non-maximum suppression on a single image + + Args: + boxes: input boxes of shape [n_boxes, 4] + scores: input scores of shape [n_boxes, n_classes] + score_threshold: score threshold + iou_threshold: intersection over union threshold + max_detections: fixed number of detections to return + + Returns: + A tensor 'out' of shape [max_detections, 7] and the number of valid detections. + out[:, :4] contains the selected boxes. 
+ out[:, 4] contains the scores for the selected boxes. + out[:, 5] contains the labels for the selected boxes. + out[:, 6] contains indices of input boxes that have been selected. + + """ + x = _convert_inputs(boxes, scores, score_threshold) + out = torch.zeros(max_detections, 7, device=boxes.device) + if x.size(0) == 0: + return out, 0 + idxs = _nms_with_class_offsets(x[:, :6], iou_threshold=iou_threshold) + idxs = idxs[:max_detections] + valid_dets = idxs.numel() + out[:valid_dets] = x[idxs] + return out, valid_dets + + +def _convert_inputs(boxes: Tensor, scores: Tensor, score_threshold: float) -> Tensor: + """ + Converts inputs into a tensor of candidates and filters out boxes with score below the threshold. + + Args: + boxes: input boxes of shape [n_boxes, 4] + scores: input scores of shape [n_boxes, n_classes] + score_threshold: score threshold for nms candidates + + Returns: + A tensor of shape [m, 7] containing m nms candidates above the score threshold. + x[:, :4] contains the boxes with replication for different labels + x[:, 4] contains the scores + x[:, 5] contains the labels indices (label i corresponds to input scores[:, i]) + x[:, 6] contains the input boxes indices (candidate x[i, :] corresponds to input box boxes[x[i, 6]]). + """ + n_boxes, n_classes = scores.shape + scores_mask = scores > score_threshold + box_indices = torch.arange(n_boxes, device=boxes.device).unsqueeze(1).expand(-1, n_classes)[scores_mask] + x = torch.empty((box_indices.numel(), 7), device=boxes.device) + x[:, :4] = boxes[box_indices] + x[:, SCORES] = scores[scores_mask] + x[:, LABELS] = torch.arange(n_classes, device=boxes.device).unsqueeze(0).expand(n_boxes, -1)[scores_mask] + x[:, INDICES] = box_indices + return x + + +def _nms_with_class_offsets(x: Tensor, iou_threshold: float) -> Tensor: + """ + Multiclass NMS implementation using the single class torchvision op. 
+ Boxes of each class are shifted so that there is no intersection between boxes of different classes + (similarly to torchvision batched_nms trick). + + Args: + x: nms candidates of shape [n, 6] ([:,:4] boxes, [:, 4] scores, [:, 5] labels) + iou_threshold: intersection over union threshold + + Returns: + Indices of the selected candidates + """ + assert x.shape[1] == 6 + offsets = x[:, LABELS:] * (x[:, :4].max() + 1) + shifted_boxes = x[:, :4] + offsets + return torch.ops.torchvision.nms(shifted_boxes, x[:, SCORES], iou_threshold) diff --git a/sony_custom_layers/pytorch/object_detection/nms_onnx.py b/sony_custom_layers/pytorch/object_detection/nms_onnx.py index 83e70a8..25ad590 100644 --- a/sony_custom_layers/pytorch/object_detection/nms_onnx.py +++ b/sony_custom_layers/pytorch/object_detection/nms_onnx.py @@ -15,9 +15,12 @@ # ----------------------------------------------------------------------------- import torch -from .nms import MULTICLASS_NMS_TORCH_OP_QUALNAME +from .nms import MULTICLASS_NMS_TORCH_OP +from .nms_with_indices import MULTICLASS_NMS_WITH_INDICES_TORCH_OP +from ..custom_lib import get_op_qualname MULTICLASS_NMS_ONNX_OP = "Sony::MultiClassNMS" +MULTICLASS_NMS_WITH_INDICES_ONNX_OP = "Sony::MultiClassNMSWithIndices" @torch.onnx.symbolic_helper.parse_args('v', 'v', 'f', 'f', 'i') @@ -42,4 +45,30 @@ def multiclass_nms_onnx(g, boxes, scores, score_threshold, iou_threshold, max_de return outputs -torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP_QUALNAME, multiclass_nms_onnx, opset_version=1) +@torch.onnx.symbolic_helper.parse_args('v', 'v', 'f', 'f', 'i') +def multiclass_nms_with_indices_onnx(g, boxes, scores, score_threshold, iou_threshold, max_detections): + outputs = g.op(MULTICLASS_NMS_WITH_INDICES_ONNX_OP, + boxes, + scores, + score_threshold_f=score_threshold, + iou_threshold_f=iou_threshold, + max_detections_i=max_detections, + outputs=5) + # Set output tensors shape and dtype + # Based on examples in 
https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/ + # training/ortmodule/_custom_op_symbolic_registry.py (see cross_entropy_loss) + # This is a hack to set output type that is different from input type. Apparently it cannot be set directly + output_int_type = g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.INT32).type() + batch = torch.onnx.symbolic_helper._get_tensor_dim_size(boxes, 0) + outputs[0].setType(boxes.type().with_sizes([batch, max_detections, 4])) + outputs[1].setType(scores.type().with_sizes([batch, max_detections])) + outputs[2].setType(output_int_type.with_sizes([batch, max_detections])) + outputs[3].setType(output_int_type.with_sizes([batch, max_detections])) + outputs[4].setType(output_int_type.with_sizes([batch, 1])) + return outputs + + +torch.onnx.register_custom_op_symbolic(get_op_qualname(MULTICLASS_NMS_TORCH_OP), multiclass_nms_onnx, opset_version=1) +torch.onnx.register_custom_op_symbolic(get_op_qualname(MULTICLASS_NMS_WITH_INDICES_TORCH_OP), + multiclass_nms_with_indices_onnx, + opset_version=1) diff --git a/sony_custom_layers/pytorch/object_detection/nms_ort.py b/sony_custom_layers/pytorch/object_detection/nms_ort.py index a139068..d1f2ff6 100644 --- a/sony_custom_layers/pytorch/object_detection/nms_ort.py +++ b/sony_custom_layers/pytorch/object_detection/nms_ort.py @@ -16,7 +16,8 @@ from onnxruntime_extensions import onnx_op, PyCustomOpDef from .nms import _multiclass_nms_impl -from .nms_onnx import MULTICLASS_NMS_ONNX_OP +from .nms_with_indices import _multiclass_nms_with_indices_impl +from .nms_onnx import MULTICLASS_NMS_ONNX_OP, MULTICLASS_NMS_WITH_INDICES_ONNX_OP @onnx_op(op_type=MULTICLASS_NMS_ONNX_OP, @@ -29,3 +30,18 @@ }) def multiclass_nms_ort(boxes, scores, score_threshold, iou_threshold, max_detections): return _multiclass_nms_impl(boxes, scores, score_threshold, iou_threshold, max_detections) + + +@onnx_op(op_type=MULTICLASS_NMS_WITH_INDICES_ONNX_OP, + inputs=[PyCustomOpDef.dt_float, 
PyCustomOpDef.dt_float], + outputs=[ + PyCustomOpDef.dt_float, PyCustomOpDef.dt_float, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_int32, + PyCustomOpDef.dt_int32 + ], + attrs={ + "score_threshold": PyCustomOpDef.dt_float, + "iou_threshold": PyCustomOpDef.dt_float, + "max_detections": PyCustomOpDef.dt_int64, + }) +def multiclass_nms_with_indices_ort(boxes, scores, score_threshold, iou_threshold, max_detections): + return _multiclass_nms_with_indices_impl(boxes, scores, score_threshold, iou_threshold, max_detections) diff --git a/sony_custom_layers/pytorch/object_detection/nms_with_indices.py b/sony_custom_layers/pytorch/object_detection/nms_with_indices.py new file mode 100644 index 0000000..2f92eb3 --- /dev/null +++ b/sony_custom_layers/pytorch/object_detection/nms_with_indices.py @@ -0,0 +1,141 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ----------------------------------------------------------------------------- +from typing import Callable, NamedTuple + +import torch +from torch import Tensor + +from sony_custom_layers.util.import_util import is_compatible +from sony_custom_layers.pytorch.custom_lib import register_op +from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS, INDICES + +__all__ = ['multiclass_nms_with_indices', 'NMSWithIndicesResults'] + +MULTICLASS_NMS_WITH_INDICES_TORCH_OP = 'multiclass_nms_with_indices' + + +class NMSWithIndicesResults(NamedTuple): + """ Container for non-maximum suppression with indices results """ + boxes: Tensor + scores: Tensor + labels: Tensor + indices: Tensor + n_valid: Tensor + + # Note: convenience methods below are replicated in each Results container, since NamedTuple supports neither adding + # new fields in derived classes nor multiple inheritance, and we want it to behave like a tuple, so no dataclasses. + def detach(self) -> 'NMSWithIndicesResults': + """ Detach all tensors and return a new object """ + return self.apply(lambda t: t.detach()) + + def cpu(self) -> 'NMSWithIndicesResults': + """ Move all tensors to cpu and return a new object """ + return self.apply(lambda t: t.cpu()) + + def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSWithIndicesResults': + """ Apply any function to all tensors and return a new object """ + return self.__class__(*[f(t) for t in self]) + + +def multiclass_nms_with_indices(boxes, scores, score_threshold: float, iou_threshold: float, + max_detections: int) -> NMSWithIndicesResults: + """ + Multi-class non-maximum suppression with indices. + Detections are returned in descending order of their scores. + The output tensors always contain a fixed number of detections, as defined by 'max_detections'. + If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'. 
+ + Args: + boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates + (x_min, y_min, x_max, y_max). Agnostic to the x-y axes order. + scores (Tensor): Input scores with shape [batch, n_boxes, n_classes]. + score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded. + iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap. + max_detections (int): The number of detections to return. + + Returns: + 'NMSWithIndicesResults' named tuple: + - boxes: The selected boxes with shape [batch, max_detections, 4]. + - scores: The corresponding scores in descending order with shape [batch, max_detections]. + - labels: The labels for each box with shape [batch, max_detections]. + - indices: Indices of the input boxes that have been selected. + - n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1] + + Raises: + ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes. + + Example: + ``` + from sony_custom_layers.pytorch import multiclass_nms_with_indices + + # batch size=1, 1000 boxes, 50 classes + boxes = torch.rand(1, 1000, 4) + scores = torch.rand(1, 1000, 50) + res = multiclass_nms_with_indices(boxes, + scores, + score_threshold=0.1, + iou_threshold=0.6, + max_detections=300) + # res.boxes, res.scores, res.labels, res.indices, res.n_valid + ``` + """ + return NMSWithIndicesResults( + *torch.ops.sony.multiclass_nms_with_indices(boxes, scores, score_threshold, iou_threshold, max_detections)) + + +###################### +# Register custom op # +###################### + + +def _multiclass_nms_with_indices_impl(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, + iou_threshold: float, max_detections: int) -> NMSWithIndicesResults: + """ This implementation is intended only to be registered as custom torch and onnxruntime op. 
+ NamedTuple is used for clarity, it is not preserved when run through torch / onnxruntime op. """ + res, valid_dets = _batch_multiclass_nms(boxes, + scores, + score_threshold=score_threshold, + iou_threshold=iou_threshold, + max_detections=max_detections) + return NMSWithIndicesResults(boxes=res[..., :4], + scores=res[..., SCORES], + labels=res[..., LABELS].to(torch.int64), + indices=res[..., INDICES].to(torch.int64), + n_valid=valid_dets.to(torch.int64)) + + +schema = (MULTICLASS_NMS_WITH_INDICES_TORCH_OP + + "(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) " + "-> (Tensor, Tensor, Tensor, Tensor, Tensor)") + +op_qualname = register_op(MULTICLASS_NMS_WITH_INDICES_TORCH_OP, schema, _multiclass_nms_with_indices_impl) + +if is_compatible('torch>=2.2'): + + @torch.library.impl_abstract(op_qualname) + def _multiclass_nms_with_indices_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, + iou_threshold: float, max_detections: int) -> NMSWithIndicesResults: + """ Registers torch op's abstract implementation. It specifies the properties of the output tensors. 
+ Needed for torch.export """ + ctx = torch.library.get_ctx() + batch = ctx.new_dynamic_size() + return NMSWithIndicesResults( + torch.empty((batch, max_detections, 4)), + torch.empty((batch, max_detections)), + torch.empty((batch, max_detections), dtype=torch.int64), + torch.empty((batch, max_detections), dtype=torch.int64), + torch.empty((batch, 1), dtype=torch.int64) + ) # yapf: disable diff --git a/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py b/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py index d3b6925..b58e188 100644 --- a/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py +++ b/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py @@ -13,199 +13,130 @@ # See the License for the specific language governing permissions and # limitations under the License. # ----------------------------------------------------------------------------- -from typing import Optional from unittest.mock import Mock import pytest import numpy as np import torch -from torch import Tensor import onnx import onnxruntime as ort -from sony_custom_layers.pytorch.object_detection import nms +from sony_custom_layers.pytorch import multiclass_nms, multiclass_nms_with_indices, NMSResults, NMSWithIndicesResults from sony_custom_layers.pytorch import load_custom_ops +from sony_custom_layers.pytorch.object_detection.nms_common import LABELS, INDICES, SCORES +from sony_custom_layers.pytorch.tests.object_detection.test_nms_common import generate_random_inputs from sony_custom_layers.util.import_util import is_compatible from sony_custom_layers.util.test_util import exec_in_clean_process class NMS(torch.nn.Module): - def __init__(self, score_threshold, iou_threshold, max_detections): + def __init__(self, score_threshold, iou_threshold, max_detections, with_indices: bool): super().__init__() self.score_threshold = score_threshold self.iou_threshold = iou_threshold self.max_detections = max_detections + 
self.op = multiclass_nms_with_indices if with_indices else multiclass_nms def forward(self, boxes, scores): - return nms.multiclass_nms(boxes, - scores, - score_threshold=self.score_threshold, - iou_threshold=self.iou_threshold, - max_detections=self.max_detections) + return self.op(boxes, + scores, + score_threshold=self.score_threshold, + iou_threshold=self.iou_threshold, + max_detections=self.max_detections) class TestMultiClassNMS: - def test_flatten_image_inputs(self): - boxes = Tensor([[0.1, 0.2, 0.3, 0.4], - [0.11, 0.21, 0.31, 0.41], - [0.12, 0.22, 0.32, 0.42]]) # yapf: disable - scores = Tensor([[0.15, 0.25, 0.35, 0.45], - [0.16, 0.26, 0.11, 0.46], - [0.1, 0.27, 0.37, 0.47]]) # yapf: disable - x = nms._convert_inputs(boxes, scores, score_threshold=0.11) - flat_boxes, flat_scores, labels = x[:, :4], x[:, 4], x[:, 5] - assert flat_boxes.shape == (10, 4) - assert flat_scores.shape == labels.shape == (10, ) - assert torch.equal(labels, Tensor([0, 1, 2, 3, 0, 1, 3, 1, 2, 3])) - for i in range(4): - assert torch.equal(flat_boxes[i], boxes[0]), i - for i in range(4, 7): - assert torch.equal(flat_boxes[i], boxes[1]), i - for i in range(7, 10): - assert torch.equal(flat_boxes[i], boxes[2]), i - assert torch.equal(flat_scores, Tensor([0.15, 0.25, 0.35, 0.45, 0.16, 0.26, 0.46, 0.27, 0.37, 0.47])) - - def test_nms_with_class_offsets(self): - boxes = Tensor([[0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.5, 0.6, 0.7, 0.8], - [0.5, 0.6, 0.7, 0.8], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4]]) # yapf: disable - scores = Tensor([0.25, 0.15, 0.3, 0.45, 0.5, 0.4]) - labels = Tensor([1, 0, 1, 2, 2, 1]) - x = torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) - iou_threshold = 0.5 - ret_idxs = nms._nms_with_class_offsets(x, iou_threshold) - assert torch.equal(ret_idxs, Tensor([4, 3, 5, 2, 1])) - - @pytest.mark.parametrize('max_detections', [3, 6, 10]) - # mock is to test our logic, and no mock is for integration sanity - 
@pytest.mark.parametrize('mock_tv_op', [True, False]) - def test_image_multiclass_nms(self, mocker, max_detections, mock_tv_op): - boxes = Tensor([[0.1, 0.2, 0.3, 0.4], - [0.5, 0.6, 0.7, 0.8]]) # yapf: disable - scores = Tensor([[0.2, 0.109, 0.3, 0.12], - [0.111, 0.5, 0.05, 0.4]]) # yapf: disable - score_threshold = 0.11 - iou_threshold = 0.61 - if mock_tv_op: - nms_mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms._nms_with_class_offsets', - Mock(return_value=Tensor([4, 5, 1, 0, 2, 3]).to(torch.int64))) - ret, ret_valid_dets = nms._image_multiclass_nms(boxes, - scores, - score_threshold=score_threshold, - iou_threshold=iou_threshold, - max_detections=max_detections) - if mock_tv_op: - assert torch.equal(nms_mock.call_args.args[0][:, :4], - Tensor([[0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.5, 0.6, 0.7, 0.8], - [0.5, 0.6, 0.7, 0.8], - [0.5, 0.6, 0.7, 0.8]])) # yapf: disable - assert torch.equal(nms_mock.call_args.args[0][:, 4], Tensor([0.2, 0.3, 0.12, 0.111, 0.5, 0.4])) - assert torch.equal(nms_mock.call_args.args[0][:, 5], Tensor([0, 2, 3, 0, 1, 3])) - assert nms_mock.call_args.kwargs == {'iou_threshold': iou_threshold} - - assert ret.shape == (max_detections, 6) - exp_valid_dets = min(6, max_detections) - assert torch.equal(ret[:, :4][:exp_valid_dets], - Tensor([[0.5, 0.6, 0.7, 0.8], - [0.5, 0.6, 0.7, 0.8], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.5, 0.6, 0.7, 0.8]])[:exp_valid_dets]) # yapf: disable - assert torch.all(ret[:, :4][exp_valid_dets:] == 0) - assert torch.equal(ret[:, 4][:exp_valid_dets], Tensor([0.5, 0.4, 0.3, 0.2, 0.12, 0.111])[:exp_valid_dets]) - assert torch.all(ret[:, 4][exp_valid_dets:] == 0) - assert torch.equal(ret[:, 5][:exp_valid_dets], Tensor([1, 3, 2, 0, 3, 0])[:exp_valid_dets]) - assert torch.all(ret[:, 5][exp_valid_dets:] == 0) - assert ret_valid_dets == exp_valid_dets - - def test_empty_tensors(self): - # empty inputs - ret = 
nms.multiclass_nms(torch.rand(1, 0, 4), torch.rand(1, 0, 10), 0.55, 0.6, 50) - assert ret.n_valid[0] == 0 and ret.boxes.size(1) == 50 - # no valid scores - ret = nms.multiclass_nms(torch.rand(1, 100, 4), torch.rand(1, 100, 20) / 2, 0.55, 0.6, 50) - assert ret.n_valid[0] == 0 and ret.boxes.size(1) == 50 - - def test_batch_multiclass_nms(self, mocker): - input_boxes, input_scores = self._generate_random_inputs(batch=3, n_boxes=20, n_classes=10) - max_dets = 5 - - # these numbers don't really make sense as nms outputs, but we don't really care, we only want to test - # that outputs are combined correctly - img_nms_ret = torch.rand(3, max_dets, 6) - img_nms_ret[..., 5] = torch.randint(0, 10, (3, max_dets), dtype=torch.float32) - ret_valid_dets = Tensor([[5], [4], [3]]) - # each time the function is called, next value in the list returned - images_ret = [(img_nms_ret[i], ret_valid_dets[i]) for i in range(3)] - mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms._image_multiclass_nms', - Mock(side_effect=lambda *args, **kwargs: images_ret.pop(0))) - - ret = nms._multiclass_nms_impl(input_boxes, - input_scores, - score_threshold=0.1, - iou_threshold=0.6, - max_detections=5) - - # check each invocation - for i, call_args in enumerate(mock.call_args_list): - assert torch.equal(call_args.args[0], input_boxes[i]), i - assert torch.equal(call_args.args[1], input_scores[i]), i - assert call_args.kwargs == dict(score_threshold=0.1, iou_threshold=0.6, max_detections=5), i - - assert torch.equal(ret.boxes, img_nms_ret[:, :, :4]) - assert torch.equal(ret.scores, img_nms_ret[:, :, 4]) - assert torch.equal(ret.labels, img_nms_ret[:, :, 5]) - assert ret.labels.dtype == torch.int64 - assert torch.equal(ret.n_valid, ret_valid_dets) - assert ret.n_valid.dtype == torch.int64 - - def test_torch_op(self, mocker): - mock = mocker.patch( - 'sony_custom_layers.pytorch.object_detection.nms._multiclass_nms_impl', - Mock(return_value=(torch.rand(3, 5, 4), torch.rand(3, 5), 
torch.rand(3, 5), torch.rand(3, 1)))) - boxes, scores = self._generate_random_inputs(batch=3, n_boxes=10, n_classes=5) - ret = torch.ops.sony.multiclass_nms(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5) + def _batch_multiclass_nms_mock(self, batch, n_dets, n_classes=20): + ret = torch.rand(batch, n_dets, 7) + ret[..., LABELS] = torch.randint(n_classes, size=(batch, n_dets), dtype=torch.float32) # labels + ret[..., INDICES] = torch.randint(n_dets * n_classes, size=(batch, n_dets), + dtype=torch.float32) # input box indices + n_valid = torch.randint(n_dets + 1, size=(3, 1), dtype=torch.float32) + return Mock(return_value=(ret, n_valid)) + + @pytest.mark.parametrize('op, patch_pkg', [(torch.ops.sony.multiclass_nms, 'nms'), + (torch.ops.sony.multiclass_nms_with_indices, 'nms_with_indices')]) + def test_torch_op(self, mocker, op, patch_pkg): + mock = mocker.patch(f'sony_custom_layers.pytorch.object_detection.{patch_pkg}._batch_multiclass_nms', + self._batch_multiclass_nms_mock(batch=3, n_dets=5)) + boxes, scores = generate_random_inputs(batch=3, n_boxes=10, n_classes=5) + ret = op(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5) assert torch.equal(mock.call_args.args[0], boxes) assert torch.equal(mock.call_args.args[1], scores) assert mock.call_args.kwargs == dict(score_threshold=0.1, iou_threshold=0.6, max_detections=5.) 
- assert ret == mock.return_value - - def test_torch_op_wrapper(self, mocker): - mock = mocker.patch( - 'sony_custom_layers.pytorch.object_detection.nms._multiclass_nms_impl', - Mock(return_value=(torch.rand(3, 5, 4), torch.rand(3, 5), torch.rand(3, 5), torch.rand(3, 1)))) - boxes, scores = self._generate_random_inputs(batch=3, n_boxes=20, n_classes=10) - ret = nms.multiclass_nms(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5) + assert torch.equal(ret[0], mock.return_value[0][:, :, :4]) + assert ret[0].dtype == torch.float32 + assert torch.equal(ret[1], mock.return_value[0][:, :, SCORES]) + assert ret[1].dtype == torch.float32 + assert torch.equal(ret[2], mock.return_value[0][:, :, LABELS]) + assert ret[2].dtype == torch.int64 + if op == torch.ops.sony.multiclass_nms_with_indices: + assert torch.equal(ret[3], mock.return_value[0][:, :, INDICES]) + assert ret[3].dtype == torch.int64 + assert torch.equal(ret[4], mock.return_value[1]) + assert ret[4].dtype == torch.int64 + assert len(ret) == 5 + elif op == torch.ops.sony.multiclass_nms: + assert torch.equal(ret[3], mock.return_value[1]) + assert ret[3].dtype == torch.int64 + assert len(ret) == 4 + else: + raise ValueError(op) + + @pytest.mark.parametrize('op, res_cls, torch_op, patch_pkg', + [(multiclass_nms, NMSResults, torch.ops.sony.multiclass_nms, 'nms'), + (multiclass_nms_with_indices, NMSWithIndicesResults, + torch.ops.sony.multiclass_nms_with_indices, 'nms_with_indices')]) + def test_torch_op_wrapper(self, mocker, op, res_cls, torch_op, patch_pkg): + mock = mocker.patch(f'sony_custom_layers.pytorch.object_detection.{patch_pkg}._batch_multiclass_nms', + self._batch_multiclass_nms_mock(batch=3, n_dets=5)) + boxes, scores = generate_random_inputs(batch=3, n_boxes=20, n_classes=10) + ret = op(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5) assert torch.equal(mock.call_args.args[0], boxes) assert torch.equal(mock.call_args.args[1], scores) assert 
mock.call_args.kwargs == dict(score_threshold=0.1, iou_threshold=0.6, max_detections=5) - assert isinstance(ret, nms.NMSResults) - assert torch.equal(ret.boxes, mock.return_value[0]) - assert torch.equal(ret.scores, mock.return_value[1]) - assert torch.equal(ret.labels, mock.return_value[2]) - assert torch.equal(ret.n_valid, mock.return_value[3]) + + ref_ret = torch_op(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5) + assert isinstance(ret, res_cls) + assert torch.equal(ret.boxes, ref_ret[0]) + assert ret.boxes.dtype == torch.float32 + assert torch.equal(ret.scores, ref_ret[1]) + assert ret.scores.dtype == torch.float32 + assert torch.equal(ret.labels, ref_ret[2]) + assert ret.labels.dtype == torch.int64 + if op == multiclass_nms: + assert torch.equal(ret.n_valid, ref_ret[3]) + assert ret.n_valid.dtype == torch.int64 + elif op == multiclass_nms_with_indices: + assert torch.equal(ret.indices, ref_ret[3]) + assert ret.indices.dtype == torch.int64 + assert torch.equal(ret.n_valid, ref_ret[4]) + assert ret.n_valid.dtype == torch.int64 + + @pytest.mark.parametrize('op', [multiclass_nms, multiclass_nms_with_indices]) + def test_empty_tensors(self, op): + # empty inputs + ret = op(torch.rand(1, 0, 4), torch.rand(1, 0, 10), 0.55, 0.6, 50) + assert ret.n_valid[0] == 0 and ret.boxes.size(1) == 50 + # no valid scores + ret = op(torch.rand(1, 100, 4), torch.rand(1, 100, 20) / 2, 0.55, 0.6, 50) + assert ret.n_valid[0] == 0 and ret.boxes.size(1) == 50 @pytest.mark.parametrize('dynamic_batch', [True, False]) - def test_onnx_export(self, dynamic_batch, tmpdir_factory): + @pytest.mark.parametrize('with_indices', [True, False]) + def test_onnx_export(self, dynamic_batch, tmpdir_factory, with_indices): score_thresh = 0.1 iou_thresh = 0.6 n_boxes = 10 n_classes = 5 max_dets = 7 - onnx_model = NMS(score_thresh, iou_thresh, max_dets) + onnx_model = NMS(score_thresh, iou_thresh, max_dets, with_indices=with_indices) - path = 
str(tmpdir_factory.mktemp('nms').join('nms.onnx')) - self._export_onnx(onnx_model, n_boxes, n_classes, path, dynamic_batch=dynamic_batch) + path = str(tmpdir_factory.mktemp('nms').join(f'nms{with_indices}.onnx')) + self._export_onnx(onnx_model, n_boxes, n_classes, path, dynamic_batch=dynamic_batch, with_indices=with_indices) onnx_model = onnx.load(path) onnx.checker.check_model(onnx_model, full_check=True) @@ -214,7 +145,7 @@ def test_onnx_export(self, dynamic_batch, tmpdir_factory): nms_node = list(onnx_model.graph.node)[0] assert nms_node.domain == 'Sony' - assert nms_node.op_type == 'MultiClassNMS' + assert nms_node.op_type == ('MultiClassNMSWithIndices' if with_indices else 'MultiClassNMS') attrs = sorted(nms_node.attribute, key=lambda a: a.name) assert attrs[0].name == 'iou_threshold' np.isclose(attrs[0].f, iou_thresh) @@ -223,7 +154,7 @@ def test_onnx_export(self, dynamic_batch, tmpdir_factory): assert attrs[2].name == 'score_threshold' np.isclose(attrs[2].f, score_thresh) assert len(nms_node.input) == 2 - assert len(nms_node.output) == 4 + assert len(nms_node.output) == 4 + int(with_indices) def check_tensor(onnx_tensor, exp_shape, exp_type): tensor_type = onnx_tensor.type.tensor_type @@ -238,18 +169,23 @@ def check_tensor(onnx_tensor, exp_shape, exp_type): check_tensor(onnx_model.graph.output[0], [max_dets, 4], torch.onnx.TensorProtoDataType.FLOAT) check_tensor(onnx_model.graph.output[1], [max_dets], torch.onnx.TensorProtoDataType.FLOAT) check_tensor(onnx_model.graph.output[2], [max_dets], torch.onnx.TensorProtoDataType.INT32) - check_tensor(onnx_model.graph.output[3], [1], torch.onnx.TensorProtoDataType.INT32) + if with_indices: + check_tensor(onnx_model.graph.output[3], [max_dets], torch.onnx.TensorProtoDataType.INT32) + check_tensor(onnx_model.graph.output[4], [1], torch.onnx.TensorProtoDataType.INT32) + else: + check_tensor(onnx_model.graph.output[3], [1], torch.onnx.TensorProtoDataType.INT32) @pytest.mark.parametrize('dynamic_batch', [True, False]) - 
def test_ort(self, dynamic_batch, tmpdir_factory): - model = NMS(0.5, 0.3, 1000) + @pytest.mark.parametrize('with_indices', [True, False]) + def test_ort(self, dynamic_batch, tmpdir_factory, with_indices): + model = NMS(0.5, 0.3, 1000, with_indices=with_indices) n_boxes = 500 n_classes = 20 - path = str(tmpdir_factory.mktemp('nms').join('nms.onnx')) - self._export_onnx(model, n_boxes, n_classes, path, dynamic_batch) + path = str(tmpdir_factory.mktemp('nms').join(f'nms{with_indices}.onnx')) + self._export_onnx(model, n_boxes, n_classes, path, dynamic_batch, with_indices=with_indices) batch = 5 if dynamic_batch else 1 - boxes, scores = self._generate_random_inputs(batch=batch, n_boxes=n_boxes, n_classes=n_classes, seed=42) + boxes, scores = generate_random_inputs(batch=batch, n_boxes=n_boxes, n_classes=n_classes, seed=42) torch_res = model(boxes, scores) so = load_custom_ops(load_ort=True) session = ort.InferenceSession(path, sess_options=so) @@ -271,28 +207,34 @@ def test_ort(self, dynamic_batch, tmpdir_factory): exec_in_clean_process(code, check=True) @pytest.mark.skipif(not is_compatible('torch>=2.2'), reason='unsupported') - def test_pt2_export(self, tmpdir_factory): + @pytest.mark.parametrize('with_indices', [True, False]) + def test_pt2_export(self, tmpdir_factory, with_indices): - def f(boxes, scores): - return nms.multiclass_nms(boxes, scores, 0.5, 0.3, 100) - - prog = torch.export.export(f, args=(torch.rand(1, 10, 4), torch.rand(1, 10, 5))) + model = NMS(score_threshold=0.5, iou_threshold=0.3, max_detections=100, with_indices=with_indices) + prog = torch.export.export(model, args=(torch.rand(1, 10, 4), torch.rand(1, 10, 5))) nms_node = list(prog.graph.nodes)[2] - assert nms_node.target == torch.ops.sony.multiclass_nms.default + exp_target = torch.ops.sony.multiclass_nms_with_indices if with_indices else torch.ops.sony.multiclass_nms + assert nms_node.target == exp_target.default val = nms_node.meta['val'] assert val[0].shape[1:] == (100, 4) assert 
val[1].shape[1:] == val[2].shape[1:] == (100, ) assert val[2].dtype == torch.int64 - assert val[3].shape[1:] == (1, ) - assert val[3].dtype == torch.int64 - - boxes, scores = self._generate_random_inputs(1, 10, 5) - torch_out = f(boxes, scores) + if with_indices: + assert val[3].shape[1:] == (100, ) + assert val[3].dtype == torch.int64 + assert val[4].shape[1:] == (1, ) + assert val[4].dtype == torch.int64 + else: + assert val[3].shape[1:] == (1, ) + assert val[3].dtype == torch.int64 + + boxes, scores = generate_random_inputs(1, 10, 5) + torch_out = model(boxes, scores) prog_out = prog.module()(boxes, scores) for i in range(len(torch_out)): assert torch.allclose(torch_out[i], prog_out[i]), i - path = str(tmpdir_factory.mktemp('nms').join('nms.pt2')) + path = str(tmpdir_factory.mktemp('nms').join(f'nms{with_indices}.pt2')) torch.export.save(prog, path) # check that exported program can be loaded in a clean env code = f""" @@ -303,21 +245,11 @@ def f(boxes, scores): """ exec_in_clean_process(code, check=True) - @staticmethod - def _generate_random_inputs(batch: Optional[int], n_boxes, n_classes, seed=None): - boxes_shape = (batch, n_boxes, 4) if batch else (n_boxes, 4) - scores_shape = (batch, n_boxes, n_classes) if batch else (n_boxes, n_classes) - if seed: - torch.random.manual_seed(seed) - boxes = torch.rand(*boxes_shape) - boxes[..., 0], boxes[..., 2] = torch.aminmax(boxes[..., (0, 2)], dim=-1) - boxes[..., 1], boxes[..., 3] = torch.aminmax(boxes[..., (1, 3)], dim=-1) - scores = torch.rand(*scores_shape) - return boxes, scores - - def _export_onnx(self, nms_model, n_boxes, n_classes, path, dynamic_batch: bool): + def _export_onnx(self, nms_model, n_boxes, n_classes, path, dynamic_batch: bool, with_indices: bool): input_names = ['boxes', 'scores'] output_names = ['det_boxes', 'det_scores', 'det_labels', 'valid_dets'] + if with_indices: + output_names.insert(3, 'indices') kwargs = {'dynamic_axes': {k: {0: 'batch'} for k in input_names + output_names}} if 
dynamic_batch else {} torch.onnx.export(nms_model, args=(torch.ones(1, n_boxes, 4), torch.ones(1, n_boxes, n_classes)), diff --git a/sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py b/sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py new file mode 100644 index 0000000..bf02da4 --- /dev/null +++ b/sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py @@ -0,0 +1,178 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ----------------------------------------------------------------------------- +from typing import Optional +from unittest.mock import Mock + +import pytest +import torch +from torch import Tensor + +from sony_custom_layers.pytorch.object_detection import nms_common + + +def generate_random_inputs(batch: Optional[int], n_boxes, n_classes, seed=None): + boxes_shape = (batch, n_boxes, 4) if batch else (n_boxes, 4) + scores_shape = (batch, n_boxes, n_classes) if batch else (n_boxes, n_classes) + if seed: + torch.random.manual_seed(seed) + boxes = torch.rand(*boxes_shape) + boxes[..., 0], boxes[..., 2] = torch.aminmax(boxes[..., (0, 2)], dim=-1) + boxes[..., 1], boxes[..., 3] = torch.aminmax(boxes[..., (1, 3)], dim=-1) + scores = torch.rand(*scores_shape) + return boxes, scores + + +class TestNMSCommon: + + def test_flatten_image_inputs(self): + boxes = Tensor([[0.1, 0.2, 0.3, 0.4], + [0.11, 0.21, 0.31, 0.41], + [0.12, 0.22, 0.32, 0.42]]) # yapf: disable + scores = Tensor([[0.15, 0.25, 0.35, 0.45], + [0.16, 0.26, 0.11, 0.46], + [0.1, 0.27, 0.37, 0.47]]) # yapf: disable + x = nms_common._convert_inputs(boxes, scores, score_threshold=0.11) + assert x.shape == (10, 7) + flat_boxes, flat_scores, labels, input_box_indices = x[:, :4], x[:, 4], x[:, 5], x[:, 6] + assert flat_boxes.shape == (10, 4) + assert flat_scores.shape == labels.shape == input_box_indices.shape == (10, ) + assert torch.equal(labels, Tensor([0, 1, 2, 3, 0, 1, 3, 1, 2, 3])) + for i in range(4): + assert torch.equal(flat_boxes[i], boxes[0]), i + for i in range(4, 7): + assert torch.equal(flat_boxes[i], boxes[1]), i + for i in range(7, 10): + assert torch.equal(flat_boxes[i], boxes[2]), i + assert torch.equal(flat_scores, Tensor([0.15, 0.25, 0.35, 0.45, 0.16, 0.26, 0.46, 0.27, 0.37, 0.47])) + assert torch.equal(input_box_indices, Tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])) + + def test_nms_with_class_offsets(self): + boxes = Tensor([[0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.5, 
0.6, 0.7, 0.8], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4]]) # yapf: disable + scores = Tensor([0.25, 0.15, 0.3, 0.45, 0.5, 0.4]) + labels = Tensor([1, 0, 1, 2, 2, 1]) + x = torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) + iou_threshold = 0.5 + ret_idxs = nms_common._nms_with_class_offsets(x, iou_threshold) + assert torch.equal(ret_idxs, Tensor([4, 3, 5, 2, 1])) + + @pytest.mark.parametrize('max_detections', [3, 6, 10]) + # mock is to test our logic, and no mock is for integration sanity + @pytest.mark.parametrize('mock_tv_op', [True, False]) + def test_image_multiclass_nms(self, mocker, max_detections, mock_tv_op): + boxes = Tensor([[0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8]]) # yapf: disable + scores = Tensor([[0.2, 0.109, 0.3, 0.12], + [0.111, 0.5, 0.05, 0.4]]) # yapf: disable + score_threshold = 0.11 + iou_threshold = 0.61 + if mock_tv_op: + nms_mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms_common._nms_with_class_offsets', + Mock(return_value=Tensor([4, 5, 1, 0, 2, 3]).to(torch.int64))) + ret, ret_valid_dets = nms_common._image_multiclass_nms(boxes, + scores, + score_threshold=score_threshold, + iou_threshold=iou_threshold, + max_detections=max_detections) + if mock_tv_op: + assert nms_mock.call_args.args[0].shape == (6, 6) + assert torch.equal(nms_mock.call_args.args[0][:, :4], + Tensor([[0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.5, 0.6, 0.7, 0.8], + [0.5, 0.6, 0.7, 0.8]])) # yapf: disable + assert torch.equal(nms_mock.call_args.args[0][:, 4], Tensor([0.2, 0.3, 0.12, 0.111, 0.5, 0.4])) + assert torch.equal(nms_mock.call_args.args[0][:, 5], Tensor([0, 2, 3, 0, 1, 3])) + assert nms_mock.call_args.kwargs == {'iou_threshold': iou_threshold} + + assert ret.shape == (max_detections, 7) + exp_valid_dets = min(6, max_detections) + assert torch.equal(ret[:, :4][:exp_valid_dets], + Tensor([[0.5, 0.6, 0.7, 0.8], + [0.5, 0.6, 0.7, 0.8], + [0.1, 0.2, 0.3, 0.4], + 
[0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8]])[:exp_valid_dets]) # yapf: disable + assert torch.all(ret[:, :4][exp_valid_dets:] == 0) + assert torch.equal(ret[:, 4][:exp_valid_dets], Tensor([0.5, 0.4, 0.3, 0.2, 0.12, 0.111])[:exp_valid_dets]) + assert torch.all(ret[:, 4][exp_valid_dets:] == 0) + assert torch.equal(ret[:, 5][:exp_valid_dets], Tensor([1, 3, 2, 0, 3, 0])[:exp_valid_dets]) + assert torch.all(ret[:, 5][exp_valid_dets:] == 0) + assert torch.equal(ret[:, 6][:exp_valid_dets], Tensor([1, 1, 0, 0, 0, 1])[:exp_valid_dets]) + assert torch.all(ret[:, 6][exp_valid_dets:] == 0) + assert ret_valid_dets == exp_valid_dets + + def test_image_multiclass_nms_no_valid_boxes(self): + boxes, scores = generate_random_inputs(None, 100, 20) + scores = 0.5 * scores + score_threshold = 0.51 + res, n_valid_dets = nms_common._image_multiclass_nms(boxes, + scores, + score_threshold=score_threshold, + iou_threshold=0.1, + max_detections=200) + assert torch.equal(res, torch.zeros(200, 7)) + assert n_valid_dets == 0 + + def test_image_multiclass_nms_single_class(self): + boxes, scores = generate_random_inputs(None, 100, 1) + res, n_valid_dets = nms_common._image_multiclass_nms(boxes, + scores, + score_threshold=0.1, + iou_threshold=0.1, + max_detections=50) + assert res.shape == (50, 7) + assert n_valid_dets > 0 + assert torch.equal(res[:n_valid_dets, 5], torch.zeros((n_valid_dets, ))) + + def test_batch_multiclass_nms(self, mocker): + input_boxes, input_scores = generate_random_inputs(batch=3, n_boxes=20, n_classes=10) + max_dets = 5 + + # these numbers don't really make sense as nms outputs, but we don't really care, we only want to test + # that outputs are combined correctly + img_nms_ret = torch.rand(3, max_dets, 7) + # labels + img_nms_ret[..., 5] = torch.randint(0, 20, (3, max_dets), dtype=torch.float32) + # input box indices + img_nms_ret[..., 6] = torch.randint(0, 200, (3, max_dets), dtype=torch.float32) + ret_valid_dets = Tensor([[5], [4], [3]]) + #
each time the function is called, next value in the list returned + images_ret = [(img_nms_ret[i], ret_valid_dets[i]) for i in range(3)] + mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms_common._image_multiclass_nms', + Mock(side_effect=lambda *args, **kwargs: images_ret.pop(0))) + + res, n_valid = nms_common._batch_multiclass_nms(input_boxes, + input_scores, + score_threshold=0.1, + iou_threshold=0.6, + max_detections=5) + + # check each invocation + for i, call_args in enumerate(mock.call_args_list): + assert torch.equal(call_args.args[0], input_boxes[i]), i + assert torch.equal(call_args.args[1], input_scores[i]), i + assert call_args.kwargs == dict(score_threshold=0.1, iou_threshold=0.6, max_detections=5), i + + assert torch.equal(res, img_nms_ret) + assert torch.equal(n_valid, ret_valid_dets) From 50196dbe2bc55f27dc4e53c98e2fa9431f375436 Mon Sep 17 00:00:00 2001 From: irenab Date: Sun, 8 Sep 2024 18:48:30 +0300 Subject: [PATCH 2/3] update docs --- docs/index.html | 6 +- docs/sony_custom_layers/keras.html | 24 +- docs/sony_custom_layers/pytorch.html | 754 +++++++++++++----- .../pytorch/object_detection/nms.py | 2 + .../pytorch/object_detection/nms_common.py | 6 +- .../object_detection/nms_with_indices.py | 10 +- 6 files changed, 581 insertions(+), 221 deletions(-) diff --git a/docs/index.html b/docs/index.html index 3a5531d..327a69f 100644 --- a/docs/index.html +++ b/docs/index.html @@ -3,14 +3,14 @@ - - Module List – pdoc 14.4.0 + + Module List – pdoc 14.6.1 - + diff --git a/docs/sony_custom_layers/keras.html b/docs/sony_custom_layers/keras.html index 1d373d3..20e53bd 100644 --- a/docs/sony_custom_layers/keras.html +++ b/docs/sony_custom_layers/keras.html @@ -3,14 +3,14 @@ - + sony_custom_layers.keras API documentation - +