[Feature] Add the MLU support for box_iou_rotated op (#2703)

pull/2795/head
Zhang 2023-04-20 19:31:08 +08:00 committed by GitHub
parent 8d01bf6a6c
commit 595f8fdff3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 89 additions and 17 deletions

View File

@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc.
| BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | |
| BoxIouRotated | √ | √ | | | |
| BoxIouRotated | √ | √ | √ | | |
| BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | |
| ChamferDistance | | √ | | | |

View File

@ -9,7 +9,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | |
| BoxIouRotated | √ | √ | | | |
| BoxIouRotated | √ | √ | √ | | |
| BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | |
| ChamferDistance | | √ | | | |

View File

@ -133,7 +133,10 @@ def box_iou_rotated(bboxes1: torch.Tensor,
if aligned:
ious = bboxes1.new_zeros(rows)
else:
ious = bboxes1.new_zeros(rows * cols)
if bboxes1.device.type == 'mlu':
ious = bboxes1.new_zeros([rows, cols])
else:
ious = bboxes1.new_zeros(rows * cols)
if not clockwise:
flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
flip_mat[-1] = -1

View File

@ -0,0 +1,54 @@
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
void BoxIouRotatedMLUKernelLauncher(const Tensor boxes1, const Tensor boxes2,
Tensor ious, const int mode_flag,
const bool aligned) {
// get compute handle
auto handle = mluOpGetCurrentHandle();
auto boxes1_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
boxes1, boxes1.suggest_memory_format());
auto boxes2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
boxes2, boxes2.suggest_memory_format());
auto ious_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(ious, ious.suggest_memory_format());
MluOpTensorDescriptor boxes1_desc, boxes2_desc, ious_desc;
boxes1_desc.set(boxes1_contiguous);
boxes2_desc.set(boxes2_contiguous);
ious_desc.set(ious_contiguous);
auto boxes1_impl = torch_mlu::getMluTensorImpl(boxes1_contiguous);
auto boxes2_impl = torch_mlu::getMluTensorImpl(boxes2_contiguous);
auto ious_impl = torch_mlu::getMluTensorImpl(ious_contiguous);
auto boxes1_ptr = boxes1_impl->cnnlMalloc();
auto boxes2_ptr = boxes2_impl->cnnlMalloc();
auto ious_ptr = ious_impl->cnnlMalloc();
CNLOG(INFO) << "Call mluOpBoxIouRotated().";
mluOpBoxIouRotated(handle, mode_flag, aligned, boxes1_desc.desc(), boxes1_ptr,
boxes2_desc.desc(), boxes2_ptr, ious_desc.desc(),
ious_ptr);
}
// MLU entry point for box_iou_rotated: forwards every argument unchanged to
// BoxIouRotatedMLUKernelLauncher, which writes the result into `ious`.
void box_iou_rotated_mlu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
BoxIouRotatedMLUKernelLauncher(boxes1, boxes2, ious, mode_flag, aligned);
}
// Declaration only: box_iou_rotated_impl is defined outside this file.
void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
// Associates box_iou_rotated_mlu with box_iou_rotated_impl for the MLU device.
// NOTE(review): the dispatch mechanics live in the REGISTER_DEVICE_IMPL macro,
// which is not visible here.
REGISTER_DEVICE_IMPL(box_iou_rotated_impl, MLU, box_iou_rotated_mlu);

View File

@ -3,11 +3,13 @@ import numpy as np
import pytest
import torch
from mmcv.ops import box_iou_rotated
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
class TestBoxIoURotated:
def test_box_iou_rotated_cpu(self):
from mmcv.ops import box_iou_rotated
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
[7.0, 7.0, 8.0, 8.0, 0.4]],
@ -44,10 +46,17 @@ class TestBoxIoURotated:
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_box_iou_rotated_cuda(self):
from mmcv.ops import box_iou_rotated
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_box_iou_rotated(self, device):
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
[7.0, 7.0, 8.0, 8.0, 0.4]],
@ -63,8 +72,8 @@ class TestBoxIoURotated:
np_expect_ious_aligned = np.asarray([0.3708, 0.4487, 0.3622],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)
# test cw angle definition
ious = box_iou_rotated(boxes1, boxes2)
@ -85,7 +94,6 @@ class TestBoxIoURotated:
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
def test_box_iou_rotated_iof_cpu(self):
from mmcv.ops import box_iou_rotated
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
[7.0, 7.0, 8.0, 8.0, 0.4]],
@ -121,10 +129,17 @@ class TestBoxIoURotated:
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_box_iou_rotated_iof_cuda(self):
from mmcv.ops import box_iou_rotated
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_box_iou_rotated_iof(self, device):
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
[7.0, 7.0, 8.0, 8.0, 0.4]],
@ -140,8 +155,8 @@ class TestBoxIoURotated:
np_expect_ious_aligned = np.asarray([0.4959, 0.5420, 0.4404],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)
# test cw angle definition
ious = box_iou_rotated(boxes1, boxes2, mode='iof')