mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add the MLU support for box_iou_rotated op (#2703)
parent
8d01bf6a6c
commit
595f8fdff3
|
@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| BallQuery | | √ | √ | | |
|
||||
| BBoxOverlaps | | √ | √ | √ | √ |
|
||||
| BorderAlign | | √ | | | |
|
||||
| BoxIouRotated | √ | √ | | | |
|
||||
| BoxIouRotated | √ | √ | √ | | |
|
||||
| BoxIouQuadri | √ | √ | | | |
|
||||
| CARAFE | | √ | √ | | |
|
||||
| ChamferDistance | | √ | | | |
|
||||
|
|
|
@ -9,7 +9,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| BallQuery | | √ | √ | | |
|
||||
| BBoxOverlaps | | √ | √ | √ | √ |
|
||||
| BorderAlign | | √ | | | |
|
||||
| BoxIouRotated | √ | √ | | | |
|
||||
| BoxIouRotated | √ | √ | √ | | |
|
||||
| BoxIouQuadri | √ | √ | | | |
|
||||
| CARAFE | | √ | √ | | |
|
||||
| ChamferDistance | | √ | | | |
|
||||
|
|
|
@ -133,7 +133,10 @@ def box_iou_rotated(bboxes1: torch.Tensor,
|
|||
if aligned:
|
||||
ious = bboxes1.new_zeros(rows)
|
||||
else:
|
||||
ious = bboxes1.new_zeros(rows * cols)
|
||||
if bboxes1.device.type == 'mlu':
|
||||
ious = bboxes1.new_zeros([rows, cols])
|
||||
else:
|
||||
ious = bboxes1.new_zeros(rows * cols)
|
||||
if not clockwise:
|
||||
flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
|
||||
flip_mat[-1] = -1
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) 2022 by Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "mlu_common_helper.h"
|
||||
|
||||
void BoxIouRotatedMLUKernelLauncher(const Tensor boxes1, const Tensor boxes2,
|
||||
Tensor ious, const int mode_flag,
|
||||
const bool aligned) {
|
||||
// get compute handle
|
||||
auto handle = mluOpGetCurrentHandle();
|
||||
|
||||
auto boxes1_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
boxes1, boxes1.suggest_memory_format());
|
||||
auto boxes2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
boxes2, boxes2.suggest_memory_format());
|
||||
auto ious_contiguous =
|
||||
torch_mlu::cnnl::ops::cnnl_contiguous(ious, ious.suggest_memory_format());
|
||||
|
||||
MluOpTensorDescriptor boxes1_desc, boxes2_desc, ious_desc;
|
||||
boxes1_desc.set(boxes1_contiguous);
|
||||
boxes2_desc.set(boxes2_contiguous);
|
||||
ious_desc.set(ious_contiguous);
|
||||
|
||||
auto boxes1_impl = torch_mlu::getMluTensorImpl(boxes1_contiguous);
|
||||
auto boxes2_impl = torch_mlu::getMluTensorImpl(boxes2_contiguous);
|
||||
auto ious_impl = torch_mlu::getMluTensorImpl(ious_contiguous);
|
||||
|
||||
auto boxes1_ptr = boxes1_impl->cnnlMalloc();
|
||||
auto boxes2_ptr = boxes2_impl->cnnlMalloc();
|
||||
auto ious_ptr = ious_impl->cnnlMalloc();
|
||||
|
||||
CNLOG(INFO) << "Call mluOpBoxIouRotated().";
|
||||
mluOpBoxIouRotated(handle, mode_flag, aligned, boxes1_desc.desc(), boxes1_ptr,
|
||||
boxes2_desc.desc(), boxes2_ptr, ious_desc.desc(),
|
||||
ious_ptr);
|
||||
}
|
||||
|
||||
void box_iou_rotated_mlu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
|
||||
const int mode_flag, const bool aligned) {
|
||||
BoxIouRotatedMLUKernelLauncher(boxes1, boxes2, ious, mode_flag, aligned);
|
||||
}
|
||||
|
||||
void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
|
||||
const int mode_flag, const bool aligned);
|
||||
|
||||
REGISTER_DEVICE_IMPL(box_iou_rotated_impl, MLU, box_iou_rotated_mlu);
|
|
@ -3,11 +3,13 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.ops import box_iou_rotated
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
|
||||
|
||||
class TestBoxIoURotated:
|
||||
|
||||
def test_box_iou_rotated_cpu(self):
|
||||
from mmcv.ops import box_iou_rotated
|
||||
np_boxes1 = np.asarray(
|
||||
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
|
||||
[7.0, 7.0, 8.0, 8.0, 0.4]],
|
||||
|
@ -44,10 +46,17 @@ class TestBoxIoURotated:
|
|||
assert np.allclose(
|
||||
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_box_iou_rotated_cuda(self):
|
||||
from mmcv.ops import box_iou_rotated
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
|
||||
pytest.param(
|
||||
'mlu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support'))
|
||||
])
|
||||
def test_box_iou_rotated(self, device):
|
||||
np_boxes1 = np.asarray(
|
||||
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
|
||||
[7.0, 7.0, 8.0, 8.0, 0.4]],
|
||||
|
@ -63,8 +72,8 @@ class TestBoxIoURotated:
|
|||
np_expect_ious_aligned = np.asarray([0.3708, 0.4487, 0.3622],
|
||||
dtype=np.float32)
|
||||
|
||||
boxes1 = torch.from_numpy(np_boxes1).cuda()
|
||||
boxes2 = torch.from_numpy(np_boxes2).cuda()
|
||||
boxes1 = torch.from_numpy(np_boxes1).to(device)
|
||||
boxes2 = torch.from_numpy(np_boxes2).to(device)
|
||||
|
||||
# test cw angle definition
|
||||
ious = box_iou_rotated(boxes1, boxes2)
|
||||
|
@ -85,7 +94,6 @@ class TestBoxIoURotated:
|
|||
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
|
||||
|
||||
def test_box_iou_rotated_iof_cpu(self):
|
||||
from mmcv.ops import box_iou_rotated
|
||||
np_boxes1 = np.asarray(
|
||||
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
|
||||
[7.0, 7.0, 8.0, 8.0, 0.4]],
|
||||
|
@ -121,10 +129,17 @@ class TestBoxIoURotated:
|
|||
assert np.allclose(
|
||||
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_box_iou_rotated_iof_cuda(self):
|
||||
from mmcv.ops import box_iou_rotated
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
|
||||
pytest.param(
|
||||
'mlu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support'))
|
||||
])
|
||||
def test_box_iou_rotated_iof(self, device):
|
||||
np_boxes1 = np.asarray(
|
||||
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
|
||||
[7.0, 7.0, 8.0, 8.0, 0.4]],
|
||||
|
@ -140,8 +155,8 @@ class TestBoxIoURotated:
|
|||
np_expect_ious_aligned = np.asarray([0.4959, 0.5420, 0.4404],
|
||||
dtype=np.float32)
|
||||
|
||||
boxes1 = torch.from_numpy(np_boxes1).cuda()
|
||||
boxes2 = torch.from_numpy(np_boxes2).cuda()
|
||||
boxes1 = torch.from_numpy(np_boxes1).to(device)
|
||||
boxes2 = torch.from_numpy(np_boxes2).to(device)
|
||||
|
||||
# test cw angle definition
|
||||
ious = box_iou_rotated(boxes1, boxes2, mode='iof')
|
||||
|
|
Loading…
Reference in New Issue