mirror of https://github.com/open-mmlab/mmcv.git
Support fps and points_in_box ops for Ascend device (#3031)
parent 265531fa9f
commit 43ee50e948
@@ -22,8 +22,8 @@ MMCV provides operators commonly used in tasks such as detection and segmentation
 | Deformable RoIPool | | √ | √ | | √ |
 | DiffIoURotated | | √ | √ | | |
 | DynamicScatter | | √ | √ | | |
-| FurthestPointSample | | √ | | | |
-| FurthestPointSampleWithDist | | √ | | | |
+| FurthestPointSample | | √ | | | √ |
+| FurthestPointSampleWithDist | | √ | | | √ |
 | FusedBiasLeakyrelu | | √ | | | √ |
 | GatherPoints | | √ | | | √ |
 | GroupPoints | | √ | | | |
@@ -0,0 +1,23 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void furthest_point_sampling_forward_npu(Tensor points_tensor,
+                                         Tensor temp_tensor, Tensor idx_tensor,
+                                         int b, int n, int m) {
+  TORCH_CHECK(
+      (points_tensor.sizes()[1] >= m),
+      "the number of sampled points should not exceed the total number of points.");
+  at::Tensor points_xyz = points_tensor.transpose(1, 2).contiguous();
+  at::Tensor nearest_dist = temp_tensor.contiguous();
+  EXEC_NPU_CMD(aclnnFurthestPointSampling, points_xyz, nearest_dist, m,
+               idx_tensor);
+}
+
+void furthest_point_sampling_forward_impl(Tensor points_tensor,
+                                          Tensor temp_tensor, Tensor idx_tensor,
+                                          int b, int n, int m);
+
+REGISTER_NPU_IMPL(furthest_point_sampling_forward_impl,
+                  furthest_point_sampling_forward_npu);
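For context, this adapter is reached through the existing `mmcv.ops.furthest_point_sample` wrapper, so user code stays device-agnostic. A minimal usage sketch (not part of the diff; it assumes an Ascend build of mmcv with `torch_npu` installed, and the input values are illustrative):

```python
import torch
import torch_npu  # noqa: F401  # assumed: registers the 'npu' device with PyTorch
from mmcv.ops import furthest_point_sample

# (B, N, 3) point coordinates on the Ascend device
xyz = torch.rand(2, 1024, 3).to('npu')
# Returns a (B, 3) int32 tensor holding the indices of the sampled points.
idx = furthest_point_sample(xyz, 3)
```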
@@ -0,0 +1,22 @@
+#include "pytorch_npu_helper.hpp"
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void furthest_point_sampling_with_dist_npu(Tensor points_tensor,
+                                           Tensor temp_tensor,
+                                           Tensor idx_tensor, int b, int n,
+                                           int m) {
+  auto output_size = {b, m};
+  at::Tensor result =
+      at::empty(output_size, points_tensor.options().dtype(at::kInt));
+  EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor,
+               m, result);
+}
+
+void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor,
+                                                    Tensor temp_tensor,
+                                                    Tensor idx_tensor, int b,
+                                                    int n, int m);
+
+REGISTER_NPU_IMPL(furthest_point_sampling_with_dist_forward_impl,
+                  furthest_point_sampling_with_dist_npu);
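`furthest_point_sample_with_dist` takes a precomputed pairwise distance matrix rather than raw coordinates. A sketch mirroring the updated test (illustrative values; assumes the same NPU setup as above):

```python
import torch
from mmcv.ops import furthest_point_sample_with_dist

xyz = torch.rand(2, 6, 3).to('npu')  # (B, N, 3) coordinates
# Pairwise squared distances, shape (B, N, N), built the same way as in the test.
dist = ((xyz.unsqueeze(1) - xyz.unsqueeze(2))**2).sum(-1)
idx = furthest_point_sample_with_dist(dist, 3)  # (B, 3) sampled indices
```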
@@ -80,16 +80,14 @@ void ms_deform_attn_impl_backward(
     const Tensor &value, const Tensor &spatial_shapes,
     const Tensor &level_start_index, const Tensor &sampling_loc,
     const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
-    Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
-    const int im2col_step);
+    Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);

-void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shapes,
-                                 const Tensor &level_start_index,
-                                 const Tensor &sampling_loc,
-                                 const Tensor &attn_weight,
-                                 const Tensor &grad_output, Tensor &grad_value,
-                                 Tensor &grad_sampling_loc,
-                                 Tensor &grad_attn_weight, const int im2col_step) {
+void ms_deform_attn_backward_npu(
+    const Tensor &value, const Tensor &spatial_shapes,
+    const Tensor &level_start_index, const Tensor &sampling_loc,
+    const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
+    Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
+    const int im2col_step) {
   check_support(value, attn_weight);
   at::Tensor value_fp32 = value;
   at::Tensor spatial_shapes_int32 = spatial_shapes;
@@ -0,0 +1,25 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+
+void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep,
+                                    Tensor &keep_num,
+                                    float nms_overlap_thresh) {
+  int32_t box_num = boxes.size(0);
+  int32_t data_align = 16;
+  int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
+  at::Tensor mask =
+      at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort));
+  EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask);
+
+  keep = at::zeros({box_num}, mask.options());
+  keep_num = at::zeros(1, mask.options());
+  EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
+}
+
+void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep,
+                                     Tensor &keep_num,
+                                     float nms_overlap_thresh);
+
+REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl,
+                  iou3d_nms3d_normal_forward_npu);
@@ -0,0 +1,26 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+constexpr int32_t BOX_DIM = 7;
+
+void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &keep_num,
+                             float nms_overlap_thresh) {
+  TORCH_CHECK((boxes.sizes()[1] == BOX_DIM),
+              "Input boxes shape should be (N, 7)");
+  int32_t box_num = boxes.size(0);
+  int32_t data_align = 16;
+  int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
+  at::Tensor mask =
+      at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort));
+  EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask);
+  keep = at::zeros({box_num}, mask.options());
+  keep_num = at::zeros(1, mask.options());
+  EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
+}
+
+void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep,
+                              Tensor &keep_num, float nms_overlap_thresh);
+
+REGISTER_NPU_IMPL(iou3d_nms3d_forward_impl, iou3d_nms3d_forward_npu);
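Both NMS kernels are exposed through the existing Python wrappers. A minimal sketch (random inputs for illustration; box layout `(x, y, z, dx, dy, dz, heading)` as in the tests, NPU setup assumed as above):

```python
import torch
from mmcv.ops import nms3d, nms3d_normal

boxes = torch.rand(10, 7).to('npu')  # (N, 7) 3D boxes, illustrative values
scores = torch.rand(10).to('npu')    # (N,) confidence scores
keep = nms3d(boxes, scores, iou_threshold=0.3)                # rotated (BEV) overlap
keep_normal = nms3d_normal(boxes, scores, iou_threshold=0.3)  # heading ignored
```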
@@ -0,0 +1,19 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void points_in_boxes_part_forward_impl_npu(int batch_size, int boxes_num,
+                                           int pts_num, const Tensor boxes,
+                                           const Tensor pts,
+                                           Tensor box_idx_of_points) {
+  c10::SmallVector<int64_t, 8> output_size = {pts.size(0), pts.size(1)};
+  auto boxes_trans = boxes.transpose(1, 2).contiguous();
+  EXEC_NPU_CMD(aclnnPointsInBox, boxes_trans, pts, box_idx_of_points);
+}
+void points_in_boxes_part_forward_impl(int batch_size, int boxes_num,
+                                       int pts_num, const Tensor boxes,
+                                       const Tensor pts,
+                                       Tensor box_idx_of_points);
+REGISTER_NPU_IMPL(points_in_boxes_part_forward_impl,
+                  points_in_boxes_part_forward_impl_npu);
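The NPU path is driven by the same `points_in_boxes_part` wrapper used in the updated test below; a condensed sketch (illustrative values, NPU setup assumed as above):

```python
import torch
from mmcv.ops import points_in_boxes_part

# boxes: (B, T, 7) with bottom center in LiDAR coordinates; pts: (B, M, 3)
boxes = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3]]]).to('npu')
pts = torch.tensor([[[1.0, 2.0, 3.3], [30.0, 20.0, 10.0]]]).to('npu')
# (B, M) int32 tensor: index of the box containing each point, -1 if none.
idx = points_in_boxes_part(points=pts, boxes=boxes)
```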
@@ -12,7 +12,7 @@ void points_in_polygons_npu(const Tensor points, Tensor polygons, Tensor output,
       "The batch of polygons tensor must be less than MAX_POLYGONS_BATCH");
   at::Tensor trans_polygons = polygons.transpose(0, 1);
   OpCommand cmd;
-  at::Tensor new_trans_polygons = NpuUtils::format_contiguous(trans_polygons);
+  at::Tensor new_trans_polygons = trans_polygons.contiguous();
   cmd.Name("PointsInPolygons")
       .Input(points, (string) "points")
       .Input(new_trans_polygons, (string) "polygons")
@@ -27,8 +27,12 @@ class FurthestPointSampling(Function):
         assert points_xyz.is_contiguous()

         B, N = points_xyz.size()[:2]
-        output = torch.cuda.IntTensor(B, num_points)
-        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+        if points_xyz.device.type == 'npu':
+            output = torch.IntTensor(B, num_points).npu()
+            temp = torch.FloatTensor(B, N).fill_(1e10).npu()
+        else:
+            output = torch.cuda.IntTensor(B, num_points)
+            temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

         ext_module.furthest_point_sampling_forward(
             points_xyz,
@@ -3,11 +3,20 @@ import pytest
 import torch

 from mmcv.ops import furthest_point_sample, furthest_point_sample_with_dist
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE


-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_fps():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
+def test_fps(device):
     xyz = torch.tensor([[[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681],
                         [-0.8070, 2.4137,
                          -0.5845], [-1.0001, 2.1982, -0.5859],
@@ -15,16 +24,24 @@ def test_fps():
                         [[-1.0696, 3.0758,
                           -0.1899], [-0.2559, 3.5521, -0.1402],
                          [0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205],
-                         [-0.0518, 3.7251, -0.3950]]]).cuda()
+                         [-0.0518, 3.7251, -0.3950]]]).to(device)

     idx = furthest_point_sample(xyz, 3)
-    expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda()
+    expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).to(device)
     assert torch.all(idx == expected_idx)


-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_fps_with_dist():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
+def test_fps_with_dist(device):
     xyz = torch.tensor([[[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681],
                         [-0.8070, 2.4137,
                          -0.5845], [-1.0001, 2.1982, -0.5859],
@@ -32,9 +49,9 @@ def test_fps_with_dist():
                         [[-1.0696, 3.0758,
                           -0.1899], [-0.2559, 3.5521, -0.1402],
                          [0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205],
-                         [-0.0518, 3.7251, -0.3950]]]).cuda()
+                         [-0.0518, 3.7251, -0.3950]]]).to(device)

-    expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda()
+    expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).to(device)
     xyz_square_dist = ((xyz.unsqueeze(dim=1) -
                        xyz.unsqueeze(dim=2))**2).sum(-1)
     idx = furthest_point_sample_with_dist(xyz_square_dist, 3)
@@ -44,7 +61,7 @@ def test_fps_with_dist():
     fps_idx = np.load('tests/data/for_3d_ops/fps_idx.npy')
     features_for_fps_distance = np.load(
         'tests/data/for_3d_ops/features_for_fps_distance.npy')
-    expected_idx = torch.from_numpy(fps_idx).cuda()
+    expected_idx = torch.from_numpy(fps_idx).to(device)
     features_for_fps_distance = torch.from_numpy(
         features_for_fps_distance).cuda()
@@ -4,7 +4,7 @@ import pytest
 import torch

 from mmcv.ops import boxes_iou3d, boxes_overlap_bev, nms3d, nms3d_normal
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


 @pytest.mark.parametrize('device', [
@@ -77,7 +77,11 @@ def test_boxes_iou3d(device):
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
 ])
 def test_nms3d(device):
     # test for 5 boxes
@@ -116,7 +120,11 @@ def test_nms3d(device):
     pytest.param(
         'cuda',
         marks=pytest.mark.skipif(
-            not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
 ])
 def test_nms3d_normal(device):
     # test for 5 boxes
@@ -5,7 +5,7 @@ import torch

 from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
                       points_in_boxes_part)
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


 @pytest.mark.parametrize('device', [
@@ -56,38 +56,46 @@ def test_RoIAwarePool3d(device, dtype):
             torch.tensor(49.750, dtype=dtype).to(device), 1e-3)


-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_points_in_boxes_part():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
+def test_points_in_boxes_part(device):
     boxes = torch.tensor(
         [[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3]],
          [[-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]],
-        dtype=torch.float32).cuda(
-        )  # boxes (b, t, 7) with bottom center in lidar coordinate
+        dtype=torch.float32).to(
+            device)  # boxes (b, t, 7) with bottom center in lidar coordinate
     pts = torch.tensor(
         [[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6],
           [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3],
           [4.7, 3.5, -12.2]],
          [[3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], [-21.3, -52, -5],
           [0, 0, 0], [6, 7, 8], [-2, -3, -4], [6, 4, 9]]],
-        dtype=torch.float32).cuda()  # points (b, m, 3) in lidar coordinate
+        dtype=torch.float32).to(device)  # points (b, m, 3) in lidar coordinate

     point_indices = points_in_boxes_part(points=pts, boxes=boxes)
     expected_point_indices = torch.tensor(
         [[0, 0, 0, 0, 0, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1]],
-        dtype=torch.int32).cuda()
+        dtype=torch.int32).to(device)
     assert point_indices.shape == torch.Size([2, 8])
     assert (point_indices == expected_point_indices).all()

     boxes = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 20.0, 1.0, 0.523598]]],
-                         dtype=torch.float32).cuda()  # 30 degrees
+                         dtype=torch.float32).to(device)  # 30 degrees
     pts = torch.tensor(
         [[[4, 6.928, 0], [6.928, 4, 0], [4, -6.928, 0], [6.928, -4, 0],
           [-4, 6.928, 0], [-6.928, 4, 0], [-4, -6.928, 0], [-6.928, -4, 0]]],
-        dtype=torch.float32).cuda()
+        dtype=torch.float32).to(device)
     point_indices = points_in_boxes_part(points=pts, boxes=boxes)
     expected_point_indices = torch.tensor([[-1, -1, 0, -1, 0, -1, -1, -1]],
-                                          dtype=torch.int32).cuda()
+                                          dtype=torch.int32).to(device)
     assert (point_indices == expected_point_indices).all()