mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add support for Ascend devices with gather_points (#2555)
parent
4032b83270
commit
84f60c178c
|
@ -25,7 +25,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| FurthestPointSample | | √ | | | |
|
||||
| FurthestPointSampleWithDist | | √ | | | |
|
||||
| FusedBiasLeakyrelu | | √ | | | √ |
|
||||
| GatherPoints | | √ | | | |
|
||||
| GatherPoints | | √ | | | √ |
|
||||
| GroupPoints | | √ | | | |
|
||||
| Iou3d | | √ | √ | | |
|
||||
| KNN | | √ | | | |
|
||||
|
|
|
@ -25,7 +25,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| FurthestPointSample | | √ | | | |
|
||||
| FurthestPointSampleWithDist | | √ | | | |
|
||||
| FusedBiasLeakyrelu | | √ | | | √ |
|
||||
| GatherPoints | | √ | | | |
|
||||
| GatherPoints | | √ | | | √ |
|
||||
| GroupPoints | | √ | | | |
|
||||
| Iou3d | | √ | √ | | |
|
||||
| KNN | | √ | | | |
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
#include "pytorch_npu_helper.hpp"
|
||||
|
||||
using namespace NPU_NAME_SPACE;
|
||||
using namespace std;
|
||||
|
||||
void gather_points_forward_npu(int b, int c, int n, int npoints,
|
||||
const Tensor points, const Tensor idx,
|
||||
Tensor out) {
|
||||
// b, c, n, and npoints do not need to be passed into gatherv2,
|
||||
// b, c, n, and npoints are calculated inside the operator
|
||||
// gatherv2 operator in ascend needs to set axis to 2, batch_dims is 1
|
||||
c10::SmallVector<int64_t, N> axis = {2};
|
||||
int64_t batch_dims = 1;
|
||||
|
||||
OpCommand cmd;
|
||||
cmd.Name("GatherV2")
|
||||
.Input(points)
|
||||
.Input(idx)
|
||||
.Input(axis)
|
||||
.Output(out)
|
||||
.Attr("batch_dims", batch_dims)
|
||||
.Run();
|
||||
}
|
||||
|
||||
void gather_points_forward_impl(int b, int c, int n, int npoints,
|
||||
const Tensor points, const Tensor idx,
|
||||
Tensor out);
|
||||
|
||||
REGISTER_NPU_IMPL(gather_points_forward_impl, gather_points_forward_npu);
|
|
@ -3,49 +3,65 @@ import pytest
|
|||
import torch
|
||||
|
||||
from mmcv.ops import gather_points
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_gather_points():
|
||||
features = torch.tensor([[[
|
||||
-1.6095, -0.1029, -0.8876, -1.2447, -2.4031, 0.3708, -1.1586, -1.4967,
|
||||
-0.4800, 0.2252
|
||||
],
|
||||
[
|
||||
1.9138, 3.4979, 1.6854, 1.5631, 3.6776,
|
||||
3.1154, 2.1705, 2.5221, 2.0411, 3.1446
|
||||
],
|
||||
[
|
||||
-1.4173, 0.3073, -1.4339, -1.4340, -1.2770,
|
||||
-0.2867, -1.4162, -1.4044, -1.4245, -1.4074
|
||||
]],
|
||||
[[
|
||||
0.2160, 0.0842, 0.3661, -0.2749, -0.4909,
|
||||
-0.6066, -0.8773, -0.0745, -0.9496, 0.1434
|
||||
],
|
||||
[
|
||||
1.3644, 1.8087, 1.6855, 1.9563, 1.2746,
|
||||
1.9662, 0.9566, 1.8778, 1.1437, 1.3639
|
||||
],
|
||||
[
|
||||
-0.7172, 0.1692, 0.2241, 0.0721, -0.7540,
|
||||
0.0462, -0.6227, 0.3223, -0.6944, -0.5294
|
||||
]]]).cuda()
|
||||
class TestGatherPoints:
|
||||
|
||||
idx = torch.tensor([[0, 1, 4, 0, 0, 0], [0, 5, 6, 0, 0, 0]]).int().cuda()
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
|
||||
pytest.param(
|
||||
'npu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_NPU_AVAILABLE, reason='requires NPU support'))
|
||||
])
|
||||
def test_gather_points_all_close(self, device):
|
||||
features = torch.tensor(
|
||||
[[[
|
||||
-1.6095, -0.1029, -0.8876, -1.2447, -2.4031, 0.3708, -1.1586,
|
||||
-1.4967, -0.4800, 0.2252
|
||||
],
|
||||
[
|
||||
1.9138, 3.4979, 1.6854, 1.5631, 3.6776, 3.1154, 2.1705,
|
||||
2.5221, 2.0411, 3.1446
|
||||
],
|
||||
[
|
||||
-1.4173, 0.3073, -1.4339, -1.4340, -1.2770, -0.2867, -1.4162,
|
||||
-1.4044, -1.4245, -1.4074
|
||||
]],
|
||||
[[
|
||||
0.2160, 0.0842, 0.3661, -0.2749, -0.4909, -0.6066, -0.8773,
|
||||
-0.0745, -0.9496, 0.1434
|
||||
],
|
||||
[
|
||||
1.3644, 1.8087, 1.6855, 1.9563, 1.2746, 1.9662, 0.9566,
|
||||
1.8778, 1.1437, 1.3639
|
||||
],
|
||||
[
|
||||
-0.7172, 0.1692, 0.2241, 0.0721, -0.7540, 0.0462, -0.6227,
|
||||
0.3223, -0.6944, -0.5294
|
||||
]]],
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
idx = torch.tensor([[0, 1, 4, 0, 0, 0], [0, 5, 6, 0, 0, 0]],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
output = gather_points(features, idx)
|
||||
expected_output = torch.tensor(
|
||||
[[[-1.6095, -0.1029, -2.4031, -1.6095, -1.6095, -1.6095],
|
||||
[1.9138, 3.4979, 3.6776, 1.9138, 1.9138, 1.9138],
|
||||
[-1.4173, 0.3073, -1.2770, -1.4173, -1.4173, -1.4173]],
|
||||
[[0.2160, -0.6066, -0.8773, 0.2160, 0.2160, 0.2160],
|
||||
[1.3644, 1.9662, 0.9566, 1.3644, 1.3644, 1.3644],
|
||||
[-0.7172, 0.0462, -0.6227, -0.7172, -0.7172, -0.7172]]],
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
|
||||
output = gather_points(features, idx)
|
||||
expected_output = torch.tensor(
|
||||
[[[-1.6095, -0.1029, -2.4031, -1.6095, -1.6095, -1.6095],
|
||||
[1.9138, 3.4979, 3.6776, 1.9138, 1.9138, 1.9138],
|
||||
[-1.4173, 0.3073, -1.2770, -1.4173, -1.4173, -1.4173]],
|
||||
[[0.2160, -0.6066, -0.8773, 0.2160, 0.2160, 0.2160],
|
||||
[1.3644, 1.9662, 0.9566, 1.3644, 1.3644, 1.3644],
|
||||
[-0.7172, 0.0462, -0.6227, -0.7172, -0.7172, -0.7172]]]).cuda()
|
||||
assert torch.allclose(output, expected_output)
|
||||
|
||||
assert torch.allclose(output, expected_output)
|
||||
|
||||
# test fp16
|
||||
output_half = gather_points(features.half(), idx)
|
||||
assert torch.allclose(output_half, expected_output.half())
|
||||
# test fp16
|
||||
output_half = gather_points(features.half(), idx)
|
||||
assert torch.allclose(output_half, expected_output.half())
|
||||
|
|
Loading…
Reference in New Issue