mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support NmsRotated with cambricon MLU backend (#2643)
* [Feature] Support NmsRotated with cambricon MLU backend * [Feature] remove foolproofs in nms_rotated_mlu.cpp * [Feature] fix lint in test_nms_rotated.py * [Feature] fix kMLU not found in nms_rotated.cpp * [Feature] modify mlu support in nms.py * [Feature] modify nms_rotated support in ops.md * [Feature] modify ops/nms.pypull/2598/head
parent
e6f434c440
commit
0d1b224fb1
|
@ -35,7 +35,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| ModulatedDeformConv2d | √ | √ | √ | | √ |
|
||||
| MultiScaleDeformableAttn | | √ | √ | | |
|
||||
| NMS | √ | √ | √ | | √ |
|
||||
| NMSRotated | √ | √ | | | √ |
|
||||
| NMSRotated | √ | √ | √ | | √ |
|
||||
| NMSQuadri | √ | √ | | | |
|
||||
| PixelGroup | √ | | | | |
|
||||
| PointsInBoxes | √ | √ | | | |
|
||||
|
|
|
@ -35,7 +35,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| ModulatedDeformConv2d | √ | √ | √ | | √ |
|
||||
| MultiScaleDeformableAttn | | √ | √ | | |
|
||||
| NMS | √ | √ | √ | | √ |
|
||||
| NMSRotated | √ | √ | | | √ |
|
||||
| NMSRotated | √ | √ | √ | | √ |
|
||||
| NMSQuadri | √ | √ | | | |
|
||||
| PixelGroup | √ | | | | |
|
||||
| PointsInBoxes | √ | √ | | | |
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) 2021 Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "mlu_common_helper.h"
|
||||
|
||||
Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) {
|
||||
if (boxes.numel() == 0) {
|
||||
return at::empty({0}, boxes.options().dtype(at::kLong));
|
||||
}
|
||||
|
||||
int boxes_num = boxes.size(0);
|
||||
auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
|
||||
auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores);
|
||||
auto output = at::empty({boxes_num}, boxes.options().dtype(at::kInt));
|
||||
auto output_size = at::empty({1}, scores.options().dtype(at::kInt));
|
||||
|
||||
MluOpTensorDescriptor boxes_desc, scores_desc, output_desc;
|
||||
boxes_desc.set(boxes_);
|
||||
scores_desc.set(scores_);
|
||||
output_desc.set(output);
|
||||
|
||||
// workspace
|
||||
size_t workspace_size = 0;
|
||||
auto handle = mluOpGetCurrentHandle();
|
||||
mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(), &workspace_size);
|
||||
auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));
|
||||
|
||||
auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
|
||||
auto boxes_ptr = boxes_impl->cnnlMalloc();
|
||||
auto scores_impl = torch_mlu::getMluTensorImpl(scores_);
|
||||
auto scores_ptr = scores_impl->cnnlMalloc();
|
||||
auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
|
||||
auto workspace_ptr = workspace_impl->cnnlMalloc();
|
||||
auto output_impl = torch_mlu::getMluTensorImpl(output);
|
||||
auto output_ptr = output_impl->cnnlMalloc();
|
||||
auto output_size_impl = torch_mlu::getMluTensorImpl(output_size);
|
||||
auto output_size_ptr = output_size_impl->cnnlMalloc();
|
||||
|
||||
mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(), boxes_ptr,
|
||||
scores_desc.desc(), scores_ptr, workspace_ptr, workspace_size,
|
||||
output_desc.desc(), output_ptr, (int *)output_size_ptr);
|
||||
int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
|
||||
auto ret = output.to(boxes.options().dtype(at::kLong));
|
||||
return ret.slice(0, 0, output_num);
|
||||
}
|
|
@ -17,6 +17,11 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
|
|||
const Tensor labels, const float iou_threshold);
|
||||
#endif
|
||||
|
||||
#ifdef MMCV_WITH_MLU
|
||||
Tensor nms_rotated_mlu(const Tensor dets, const Tensor scores,
|
||||
const float iou_threshold);
|
||||
#endif
|
||||
|
||||
// Interface for Python
|
||||
// inline is needed to prevent multiple function definitions when this header is
|
||||
// included by different cpps
|
||||
|
@ -36,6 +41,10 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
|
|||
return nms_rotated_npu(dets, scores, labels, iou_threshold);
|
||||
#else
|
||||
AT_ERROR("Not compiled with NPU support");
|
||||
#endif
|
||||
#ifdef MMCV_WITH_MLU
|
||||
} else if (dets.device().type() == at::kMLU) {
|
||||
return nms_rotated_mlu(dets, scores, iou_threshold);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
@ -458,11 +458,12 @@ def nms_rotated(dets: Tensor,
|
|||
input_labels = scores.new_empty(0, dtype=torch.int)
|
||||
else:
|
||||
input_labels = labels
|
||||
if dets.device.type == 'npu':
|
||||
if dets.device.type in ('npu', 'mlu'):
|
||||
order = scores.new_empty(0, dtype=torch.long)
|
||||
coefficient = 57.29578 # 180 / PI
|
||||
for i in range(dets.size()[0]):
|
||||
dets_cw[i][4] *= coefficient # radians to angle
|
||||
if dets.device.type == 'npu':
|
||||
coefficient = 57.29578 # 180 / PI
|
||||
for i in range(dets.size()[0]):
|
||||
dets_cw[i][4] *= coefficient # radians to angle
|
||||
keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
|
||||
input_labels, iou_threshold,
|
||||
multi_label)
|
||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
class TestNmsRotated:
|
||||
|
@ -16,7 +16,11 @@ class TestNmsRotated:
|
|||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
|
||||
pytest.param(
|
||||
'mlu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support'))
|
||||
])
|
||||
def test_ml_nms_rotated(self, device):
|
||||
from mmcv.ops import nms_rotated
|
||||
|
@ -58,7 +62,11 @@ class TestNmsRotated:
|
|||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
|
||||
pytest.param(
|
||||
'mlu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support'))
|
||||
])
|
||||
def test_nms_rotated(self, device):
|
||||
from mmcv.ops import nms_rotated
|
||||
|
|
Loading…
Reference in New Issue