From 595f8fdff3522866d2d2280f0bf6e5f1902c5c92 Mon Sep 17 00:00:00 2001 From: Zhang <33895156+ZhangLearning@users.noreply.github.com> Date: Thu, 20 Apr 2023 19:31:08 +0800 Subject: [PATCH] [Feature] Add the MLU support for box_iou_rotated op (#2703) --- docs/en/understand_mmcv/ops.md | 2 +- docs/zh_cn/understand_mmcv/ops.md | 2 +- mmcv/ops/box_iou_rotated.py | 5 +- mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp | 54 +++++++++++++++++++ tests/test_ops/test_box_iou_rotated.py | 43 ++++++++++----- 5 files changed, 89 insertions(+), 17 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index e7212bbdd..d1d5fe11b 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc. | BallQuery | | √ | √ | | | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | -| BoxIouRotated | √ | √ | | | | +| BoxIouRotated | √ | √ | √ | | | | BoxIouQuadri | √ | √ | | | | | CARAFE | | √ | √ | | | | ChamferDistance | | √ | | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 81092144a..c687dc913 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -9,7 +9,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | BallQuery | | √ | √ | | | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | -| BoxIouRotated | √ | √ | | | | +| BoxIouRotated | √ | √ | √ | | | | BoxIouQuadri | √ | √ | | | | | CARAFE | | √ | √ | | | | ChamferDistance | | √ | | | | diff --git a/mmcv/ops/box_iou_rotated.py b/mmcv/ops/box_iou_rotated.py index 2443af27c..8e199d9ac 100644 --- a/mmcv/ops/box_iou_rotated.py +++ b/mmcv/ops/box_iou_rotated.py @@ -133,7 +133,10 @@ def box_iou_rotated(bboxes1: torch.Tensor, if aligned: ious = bboxes1.new_zeros(rows) else: - ious = bboxes1.new_zeros(rows * cols) + if bboxes1.device.type == 'mlu': + ious = bboxes1.new_zeros([rows, cols]) + else: + ious = bboxes1.new_zeros(rows * cols) if not clockwise: flip_mat = bboxes1.new_ones(bboxes1.shape[-1]) flip_mat[-1] = -1 diff --git a/mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp b/mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp new file mode 100644 index 000000000..6a903973d --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (C) 2022 by Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "mlu_common_helper.h" + +void BoxIouRotatedMLUKernelLauncher(const Tensor boxes1, const Tensor boxes2, + Tensor ious, const int mode_flag, + const bool aligned) { + // get compute handle + auto handle = mluOpGetCurrentHandle(); + + auto boxes1_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + boxes1, boxes1.suggest_memory_format()); + auto boxes2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + boxes2, boxes2.suggest_memory_format()); + auto ious_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(ious, ious.suggest_memory_format()); + + MluOpTensorDescriptor boxes1_desc, boxes2_desc, ious_desc; + boxes1_desc.set(boxes1_contiguous); + boxes2_desc.set(boxes2_contiguous); + ious_desc.set(ious_contiguous); + + auto boxes1_impl = torch_mlu::getMluTensorImpl(boxes1_contiguous); + auto boxes2_impl = torch_mlu::getMluTensorImpl(boxes2_contiguous); + auto ious_impl = torch_mlu::getMluTensorImpl(ious_contiguous); + + auto boxes1_ptr = boxes1_impl->cnnlMalloc(); + auto boxes2_ptr = boxes2_impl->cnnlMalloc(); + auto ious_ptr = ious_impl->cnnlMalloc(); + + CNLOG(INFO) << "Call mluOpBoxIouRotated()."; + mluOpBoxIouRotated(handle, mode_flag, aligned, boxes1_desc.desc(), boxes1_ptr, + boxes2_desc.desc(), boxes2_ptr, ious_desc.desc(), + ious_ptr); +} + +void box_iou_rotated_mlu(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned) { + BoxIouRotatedMLUKernelLauncher(boxes1, boxes2, ious, mode_flag, aligned); +} + +void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious, + const int mode_flag, const bool aligned); + +REGISTER_DEVICE_IMPL(box_iou_rotated_impl, MLU, box_iou_rotated_mlu); diff --git a/tests/test_ops/test_box_iou_rotated.py b/tests/test_ops/test_box_iou_rotated.py index 9f5e0dfa3..f57e54c1e 100644 --- a/tests/test_ops/test_box_iou_rotated.py +++ b/tests/test_ops/test_box_iou_rotated.py @@ -3,11 +3,13 @@ import numpy as np import pytest import torch +from mmcv.ops import box_iou_rotated +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE + class TestBoxIoURotated: def test_box_iou_rotated_cpu(self): - from mmcv.ops import box_iou_rotated np_boxes1 = np.asarray( [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6], [7.0, 7.0, 8.0, 8.0, 0.4]], @@ -44,10 +46,17 @@ class TestBoxIoURotated: assert np.allclose( ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4) - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') - def test_box_iou_rotated_cuda(self): - from mmcv.ops import box_iou_rotated + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) + ]) + def test_box_iou_rotated(self, device): np_boxes1 = np.asarray( [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6], [7.0, 7.0, 8.0, 8.0, 0.4]], @@ -63,8 +72,8 @@ class TestBoxIoURotated: np_expect_ious_aligned = np.asarray([0.3708, 0.4487, 0.3622], dtype=np.float32) - boxes1 = torch.from_numpy(np_boxes1).cuda() - boxes2 = torch.from_numpy(np_boxes2).cuda() + boxes1 = torch.from_numpy(np_boxes1).to(device) + boxes2 = torch.from_numpy(np_boxes2).to(device) # test cw angle definition ious = box_iou_rotated(boxes1, boxes2) @@ -85,7 +94,6 @@ class TestBoxIoURotated: ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4) def test_box_iou_rotated_iof_cpu(self): - from mmcv.ops import box_iou_rotated np_boxes1 = np.asarray( [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6], [7.0, 7.0, 8.0, 8.0, 0.4]], @@ -121,10 +129,17 @@ class TestBoxIoURotated: assert np.allclose( ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4) - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') - def test_box_iou_rotated_iof_cuda(self): - from mmcv.ops import box_iou_rotated + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) + ]) + def test_box_iou_rotated_iof(self, device): np_boxes1 = np.asarray( [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6], [7.0, 7.0, 8.0, 8.0, 0.4]], @@ -140,8 +155,8 @@ class TestBoxIoURotated: np_expect_ious_aligned = np.asarray([0.4959, 0.5420, 0.4404], dtype=np.float32) - boxes1 = torch.from_numpy(np_boxes1).cuda() - boxes2 = torch.from_numpy(np_boxes2).cuda() + boxes1 = torch.from_numpy(np_boxes1).to(device) + boxes2 = torch.from_numpy(np_boxes2).to(device) # test cw angle definition ious = box_iou_rotated(boxes1, boxes2, mode='iof')