mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add support for Ascend devices with nms_rotated (#2550)
* [Feature]: add nms_rotated npu adaptater code * [BugFix]: modify param in nms_rotated_npu.cpp * [clean code]: nms_rotated_npu.cpp * [clean code]: nms_rotated_npu.cpp * [clean code]: nms_rotated_npu.cpp * [clean code]: nms_rotated.cpp * [Doc]: add nms_rotated op in supported op list at ops.md * [Test]: add nms_rotated unit_test * [Bug]: remove device parameter in test_batched_nms functionpull/2572/head
parent
0b005c52b4
commit
54ed0ed869
|
@ -35,7 +35,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| ModulatedDeformConv2d | √ | √ | √ | | √ |
|
||||
| MultiScaleDeformableAttn | | √ | √ | | |
|
||||
| NMS | √ | √ | √ | | √ |
|
||||
| NMSRotated | √ | √ | | | |
|
||||
| NMSRotated | √ | √ | | | √ |
|
||||
| NMSQuadri | √ | √ | | | |
|
||||
| PixelGroup | √ | | | | |
|
||||
| PointsInBoxes | √ | √ | | | |
|
||||
|
|
|
@ -35,7 +35,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| ModulatedDeformConv2d | √ | √ | √ | | √ |
|
||||
| MultiScaleDeformableAttn | | √ | √ | | |
|
||||
| NMS | √ | √ | √ | | √ |
|
||||
| NMSRotated | √ | √ | | | |
|
||||
| NMSRotated | √ | √ | | | √ |
|
||||
| NMSQuadri | √ | √ | | | |
|
||||
| PixelGroup | √ | | | | |
|
||||
| PointsInBoxes | √ | √ | | | |
|
||||
|
|
|
@ -12,12 +12,17 @@ Tensor nms_rotated_cuda(const Tensor dets, const Tensor scores,
|
|||
const float iou_threshold, const int multi_label);
|
||||
#endif
|
||||
|
||||
#ifdef MMCV_WITH_NPU
|
||||
Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
|
||||
const Tensor labels, const float iou_threshold);
|
||||
#endif
|
||||
|
||||
// Interface for Python
|
||||
// inline is needed to prevent multiple function definitions when this header is
|
||||
// included by different cpps
|
||||
Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
|
||||
const Tensor dets_sorted, const float iou_threshold,
|
||||
const int multi_label) {
|
||||
const Tensor dets_sorted, const Tensor labels,
|
||||
const float iou_threshold, const int multi_label) {
|
||||
assert(dets.device().is_cuda() == scores.device().is_cuda());
|
||||
if (dets.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
|
@ -25,6 +30,12 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
|
|||
multi_label);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
} else if (dets.device().type() == at::kXLA) {
|
||||
#ifdef MMCV_WITH_NPU
|
||||
return nms_rotated_npu(dets, scores, labels, iou_threshold);
|
||||
#else
|
||||
AT_ERROR("Not compiled with NPU support");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
#include "pytorch_npu_helper.hpp"
|
||||
|
||||
using namespace NPU_NAME_SPACE;
|
||||
|
||||
Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
|
||||
const Tensor labels, const float iou_threshold) {
|
||||
auto originDtype = dets.scalar_type();
|
||||
at::Tensor detsCast = dets;
|
||||
at::Tensor scoresCast = scores;
|
||||
if (originDtype != at::ScalarType::Float) {
|
||||
detsCast = NPUNativeFunctions::npu_dtype_cast(dets, at::kFloat);
|
||||
scoresCast = NPUNativeFunctions::npu_dtype_cast(scores, at::kFloat);
|
||||
}
|
||||
c10::SmallVector<int64_t, SIZE> selectedIndexSize = {dets.size(0)};
|
||||
at::Tensor selectedBox = OpPreparation::ApplyTensor(dets);
|
||||
at::Tensor selectedIndex = OpPreparation::ApplyTensor(
|
||||
selectedIndexSize, dets.options().dtype(at::kInt), dets);
|
||||
|
||||
c10::SmallVector<int64_t, N> output_sync_idx = {0, 1};
|
||||
OpCommand cmd;
|
||||
cmd.Sync(output_sync_idx)
|
||||
.Name("RotatedNMS")
|
||||
.Input(detsCast)
|
||||
.Input(scoresCast)
|
||||
.Input(labels)
|
||||
.Output(selectedBox)
|
||||
.Output(selectedIndex)
|
||||
.Attr("iou_threshold", (float)iou_threshold)
|
||||
.Run();
|
||||
selectedIndex = NPUNativeFunctions::npu_dtype_cast(selectedIndex, at::kLong);
|
||||
return selectedIndex;
|
||||
}
|
|
@ -309,8 +309,8 @@ void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
|
|||
const int mode_flag, const bool aligned);
|
||||
|
||||
Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
|
||||
const Tensor dets_sorted, const float iou_threshold,
|
||||
const int multi_label);
|
||||
const Tensor dets_sorted, const Tensor labels,
|
||||
const float iou_threshold, const int multi_label);
|
||||
|
||||
Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y,
|
||||
int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
|
||||
|
@ -748,7 +748,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
py::arg("mode_flag"), py::arg("aligned"));
|
||||
m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"),
|
||||
py::arg("scores"), py::arg("order"), py::arg("dets_sorted"),
|
||||
py::arg("iou_threshold"), py::arg("multi_label"));
|
||||
py::arg("labels"), py::arg("iou_threshold"), py::arg("multi_label"));
|
||||
m.def("ball_query_forward", &ball_query_forward, "ball_query_forward",
|
||||
py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"),
|
||||
|
|
|
@ -454,6 +454,19 @@ def nms_rotated(dets: Tensor,
|
|||
else:
|
||||
dets_cw = dets
|
||||
multi_label = labels is not None
|
||||
if labels is None:
|
||||
input_labels = scores.new_empty(0, dtype=torch.int)
|
||||
else:
|
||||
input_labels = labels
|
||||
if dets.device.type == 'npu':
|
||||
order = scores.new_empty(0, dtype=torch.long)
|
||||
keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
|
||||
input_labels, iou_threshold,
|
||||
multi_label)
|
||||
dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
|
||||
dim=1)
|
||||
return dets, keep_inds
|
||||
|
||||
if multi_label:
|
||||
dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1) # type: ignore
|
||||
else:
|
||||
|
@ -467,11 +480,13 @@ def nms_rotated(dets: Tensor,
|
|||
scores,
|
||||
order,
|
||||
dets_sorted,
|
||||
input_labels,
|
||||
iou_threshold=iou_threshold,
|
||||
multi_label=multi_label)
|
||||
else:
|
||||
keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
|
||||
iou_threshold, multi_label)
|
||||
input_labels, iou_threshold,
|
||||
multi_label)
|
||||
dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
|
||||
dim=1)
|
||||
return dets, keep_inds
|
||||
|
|
|
@ -3,13 +3,22 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason='GPU is required to test NMSRotated op')
|
||||
class TestNmsRotated:
|
||||
|
||||
def test_ml_nms_rotated(self):
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'npu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_NPU_AVAILABLE, reason='requires NPU support')),
|
||||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
|
||||
])
|
||||
def test_ml_nms_rotated(self, device):
|
||||
from mmcv.ops import nms_rotated
|
||||
np_boxes = np.array(
|
||||
[[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8],
|
||||
|
@ -24,8 +33,8 @@ class TestNmsRotated:
|
|||
dtype=np.float32)
|
||||
np_expect_keep_inds = np.array([3, 1, 0], dtype=np.int64)
|
||||
|
||||
boxes = torch.from_numpy(np_boxes).cuda()
|
||||
labels = torch.from_numpy(np_labels).cuda()
|
||||
boxes = torch.from_numpy(np_boxes).to(device)
|
||||
labels = torch.from_numpy(np_labels).to(device)
|
||||
|
||||
# test cw angle definition
|
||||
dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5, labels)
|
||||
|
@ -41,7 +50,17 @@ class TestNmsRotated:
|
|||
assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
|
||||
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
|
||||
|
||||
def test_nms_rotated(self):
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'npu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_NPU_AVAILABLE, reason='requires NPU support')),
|
||||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
|
||||
])
|
||||
def test_nms_rotated(self, device):
|
||||
from mmcv.ops import nms_rotated
|
||||
np_boxes = np.array(
|
||||
[[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8],
|
||||
|
@ -55,7 +74,7 @@ class TestNmsRotated:
|
|||
dtype=np.float32)
|
||||
np_expect_keep_inds = np.array([3, 1, 0], dtype=np.int64)
|
||||
|
||||
boxes = torch.from_numpy(np_boxes).cuda()
|
||||
boxes = torch.from_numpy(np_boxes).to(device)
|
||||
|
||||
# test cw angle definition
|
||||
dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5)
|
||||
|
|
Loading…
Reference in New Issue