mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add the implementation of diff_iou_rotated with mlu-ops (#2840)
parent 10c8b9e78b
commit d28aa8a9cc
@@ -20,7 +20,7 @@ We implement common ops used in detection, segmentation, etc.
 | Correlation | | √ | | | |
 | Deformable Convolution v1/v2 | √ | √ | | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
-| DiffIoURotated | | √ | | | |
+| DiffIoURotated | | √ | √ | | |
 | DynamicScatter | | √ | √ | | |
 | FurthestPointSample | | √ | | | |
 | FurthestPointSampleWithDist | | √ | | | |
@@ -20,7 +20,7 @@ MMCV provides commonly used ops for detection, segmentation, and other tasks
 | Correlation | | √ | | | |
 | Deformable Convolution v1/v2 | √ | √ | | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
-| DiffIoURotated | | √ | | | |
+| DiffIoURotated | | √ | √ | | |
 | DynamicScatter | | √ | √ | | |
 | FurthestPointSample | | √ | | | |
 | FurthestPointSampleWithDist | | √ | | | |
@@ -0,0 +1,55 @@
+/*************************************************************************
+ * Copyright (C) 2023 Cambricon.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include "mlu_common_helper.h"
+
+Tensor diff_iou_rotated_sort_vertices_forward_mlu(Tensor vertices, Tensor mask,
+                                                  Tensor num_valid) {
+  // params check
+  TORCH_CHECK(vertices.scalar_type() == at::kFloat,
+              "vertices type should be Float, got ", vertices.scalar_type());
+  TORCH_CHECK(mask.scalar_type() == at::kBool, "mask should be Bool, got ",
+              mask.scalar_type());
+  TORCH_CHECK(num_valid.scalar_type() == at::kInt,
+              "num_valid type should be Int32, got ", num_valid.scalar_type());
+  TORCH_CHECK(vertices.size(2) == 24, "vertices.dim(2) should be 24, got ",
+              vertices.size(2));
+  TORCH_CHECK(mask.size(2) == 24, "mask.dim(2) should be 24, got ",
+              mask.size(2));
+
+  // zero-element check
+  if (vertices.numel() == 0) {
+    return at::empty({0}, num_valid.options().dtype(at::kInt));
+  }
+
+  auto idx = at::empty({vertices.size(0), vertices.size(1), 9},
+                       num_valid.options().dtype(at::kInt));
+
+  INITIAL_MLU_PARAM_WITH_TENSOR(vertices);
+  INITIAL_MLU_PARAM_WITH_TENSOR(mask);
+  INITIAL_MLU_PARAM_WITH_TENSOR(num_valid);
+  INITIAL_MLU_PARAM_WITH_TENSOR(idx);
+
+  // get compute handle
+  auto handle = mluOpGetCurrentHandle();
+
+  // launch kernel
+  mluOpDiffIouRotatedSortVerticesForward(
+      handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
+      num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr);
+  return idx;
+}
+
+Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
+                                                   Tensor num_valid);
+
+REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, MLU,
+                     diff_iou_rotated_sort_vertices_forward_mlu);
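Note: the file above implements only the vertex-sorting kernel that
diff_iou_rotated_2d and diff_iou_rotated_3d dispatch to through
diff_iou_rotated_sort_vertices_forward_impl. A minimal end-to-end sketch of
the user-facing op (hypothetical usage, assuming an MLU device and a build of
this branch; not part of the commit):

import torch
from mmcv.ops import diff_iou_rotated_2d

# Two batches of rotated boxes in (x, y, w, h, angle) format, shape (B, N, 5).
boxes1 = torch.tensor([[[0.5, 0.5, 1., 1., 0.]]], device='mlu')
boxes2 = torch.tensor([[[0.5, 0.5, 1., 1., 0.]]], device='mlu')
ious = diff_iou_rotated_2d(boxes1, boxes2)  # identical boxes -> IoU of 1.0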
@@ -34,7 +34,7 @@
   auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous);  \
   auto NAME##_ptr = NAME##_impl->cnnlMalloc();
 
-enum class reduce_t{ SUM = 0, MEAN = 1, MAX = 2 };
+enum class reduce_t { SUM = 0, MEAN = 1, MAX = 2 };
 
 inline std::string to_string(reduce_t reduce_type) {
   if (reduce_type == reduce_t::MAX) {
@@ -11,16 +11,16 @@
  *************************************************************************/
 #include "mlu_common_helper.h"
 
-std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(const Tensor &feats,
-                                                       const Tensor &coors,
-                                                       const reduce_t reduce_type) {
+std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(
+    const Tensor &feats, const Tensor &coors, const reduce_t reduce_type) {
   // params check
   TORCH_CHECK(feats.scalar_type() == at::kFloat,
-    "feats type should be Float, got ", feats.scalar_type());
+              "feats type should be Float, got ", feats.scalar_type());
   TORCH_CHECK(coors.scalar_type() == at::kInt,
-    "coors type should be Int32, got ", coors.scalar_type());
+              "coors type should be Int32, got ", coors.scalar_type());
   TORCH_CHECK(feats.size(0) == coors.size(0),
-    "feats.dim(0) and coors.dim(0) should be same, got ", feats.size(0), " vs ", coors.size(0));
+              "feats.dim(0) and coors.dim(0) should be same, got ",
+              feats.size(0), " vs ", coors.size(0));
 
   const int num_input = feats.size(0);
   const int num_feats = feats.size(1);
@@ -49,59 +49,48 @@ std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(const Tensor &feats,
   auto handle = mluOpGetCurrentHandle();
 
   size_t workspace_size;
-  mluOpGetDynamicPointToVoxelForwardWorkspaceSize(handle,
-                                                  feats_desc.desc(),
-                                                  coors_desc.desc(),
-                                                  &workspace_size);
+  mluOpGetDynamicPointToVoxelForwardWorkspaceSize(
+      handle, feats_desc.desc(), coors_desc.desc(), &workspace_size);
   auto workspace_tensor =
       at::empty(workspace_size, feats.options().dtype(at::kByte));
   INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);
 
   // launch kernel
-  mluOpDynamicPointToVoxelForward(handle,
-                                  mlu_reduce_type,
-                                  feats_desc.desc(),
-                                  feats_ptr,
-                                  coors_desc.desc(),
-                                  coors_ptr,
-                                  workspace_tensor_ptr,
-                                  workspace_size,
-                                  reduced_feats_desc.desc(),
-                                  reduced_feats_ptr,
-                                  out_coors_desc.desc(),
-                                  out_coors_ptr,
-                                  coors_map_desc.desc(),
-                                  coors_map_ptr,
-                                  reduce_count_desc.desc(),
-                                  reduce_count_ptr,
-                                  voxel_num_desc.desc(),
-                                  voxel_num_ptr);
+  mluOpDynamicPointToVoxelForward(
+      handle, mlu_reduce_type, feats_desc.desc(), feats_ptr, coors_desc.desc(),
+      coors_ptr, workspace_tensor_ptr, workspace_size,
+      reduced_feats_desc.desc(), reduced_feats_ptr, out_coors_desc.desc(),
+      out_coors_ptr, coors_map_desc.desc(), coors_map_ptr,
+      reduce_count_desc.desc(), reduce_count_ptr, voxel_num_desc.desc(),
+      voxel_num_ptr);
 
   int voxel_num_value = *static_cast<int *>(voxel_num.cpu().data_ptr());
   TORCH_CHECK(voxel_num_value <= feats.size(0),
-    "voxel_num should be less than or equal to feats_num, got ", voxel_num_value, " vs ", feats.size(0));
-  return {reduced_feats.slice(0, 0, voxel_num_value), out_coors.slice(0, 0, voxel_num_value),
-    coors_map, reduce_count.slice(0, 0, voxel_num_value)};
+              "voxel_num should be less than or equal to feats_num, got ",
+              voxel_num_value, " vs ", feats.size(0));
+  return {reduced_feats.slice(0, 0, voxel_num_value),
+          out_coors.slice(0, 0, voxel_num_value), coors_map,
+          reduce_count.slice(0, 0, voxel_num_value)};
 }
 
-void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
-                                         const Tensor &grad_reduced_feats,
-                                         const Tensor &feats,
-                                         const Tensor &reduced_feats,
-                                         const Tensor &coors_idx,
-                                         const Tensor &reduce_count,
-                                         const reduce_t reduce_type) {
+void dynamic_point_to_voxel_backward_mlu(
+    Tensor &grad_feats, const Tensor &grad_reduced_feats, const Tensor &feats,
+    const Tensor &reduced_feats, const Tensor &coors_idx,
+    const Tensor &reduce_count, const reduce_t reduce_type) {
   // params check
   TORCH_CHECK(grad_reduced_feats.scalar_type() == at::kFloat,
-    "grad_reduced_feats type should be Float, got ", grad_reduced_feats.scalar_type());
+              "grad_reduced_feats type should be Float, got ",
+              grad_reduced_feats.scalar_type());
   TORCH_CHECK(feats.scalar_type() == at::kFloat,
-    "feats type should be Float, got ", feats.scalar_type());
+              "feats type should be Float, got ", feats.scalar_type());
   TORCH_CHECK(reduced_feats.scalar_type() == at::kFloat,
-    "reduced_feats type should be Float, got ", reduced_feats.scalar_type());
+              "reduced_feats type should be Float, got ",
+              reduced_feats.scalar_type());
   TORCH_CHECK(coors_idx.scalar_type() == at::kInt,
-    "coors_idx type should be Int32, got ", coors_idx.scalar_type());
+              "coors_idx type should be Int32, got ", coors_idx.scalar_type());
   TORCH_CHECK(reduce_count.scalar_type() == at::kInt,
-    "reduce_count type should be Int32, got ", reduce_count.scalar_type());
+              "reduce_count type should be Int32, got ",
+              reduce_count.scalar_type());
 
   const int num_input = feats.size(0);
   const int num_reduced = reduced_feats.size(0);
@@ -114,11 +103,13 @@ void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
 
   // TODO(miaochen): remove this after mlu-ops supports other mode of reduce.
   TORCH_CHECK(reduce_type == reduce_t::MAX,
-    "only supports max reduce in current version, got ", to_string(reduce_type));
+              "only supports max reduce in current version, got ",
+              to_string(reduce_type));
 
   int voxel_num_value = reduced_feats.size(0);
   auto opts = torch::TensorOptions().dtype(torch::kInt32);
-  auto voxel_num = torch::from_blob(&voxel_num_value, {1}, opts).clone().to(at::kMLU);
+  auto voxel_num =
+      torch::from_blob(&voxel_num_value, {1}, opts).clone().to(at::kMLU);
   auto mlu_reduce_type = getMluOpReduceMode(reduce_type);
 
   INITIAL_MLU_PARAM_WITH_TENSOR(grad_feats);
@@ -134,43 +125,30 @@ void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
 
   size_t workspace_size;
   mluOpGetDynamicPointToVoxelBackwardWorkspaceSize(
-      handle, mlu_reduce_type,
-      grad_feats_desc.desc(),
-      feats_desc.desc(),
-      grad_reduced_feats_desc.desc(),
-      coors_idx_desc.desc(),
-      reduce_count_desc.desc(),
-      voxel_num_desc.desc(),
-      &workspace_size);
+      handle, mlu_reduce_type, grad_feats_desc.desc(), feats_desc.desc(),
+      grad_reduced_feats_desc.desc(), coors_idx_desc.desc(),
+      reduce_count_desc.desc(), voxel_num_desc.desc(), &workspace_size);
   auto workspace_tensor =
       at::empty(workspace_size, feats.options().dtype(at::kByte));
   INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);
 
   // launch kernel
   mluOpDynamicPointToVoxelBackward(
-      handle, mlu_reduce_type,
-      grad_reduced_feats_desc.desc(),
-      grad_reduced_feats_ptr,
-      feats_desc.desc(), feats_ptr,
-      reduced_feats_desc.desc(), reduced_feats_ptr,
-      coors_idx_desc.desc(), coors_idx_ptr,
-      reduce_count_desc.desc(), reduce_count_ptr,
-      voxel_num_desc.desc(), voxel_num_ptr,
-      workspace_tensor_ptr, workspace_size,
-      grad_feats_desc.desc(), grad_feats_ptr);
+      handle, mlu_reduce_type, grad_reduced_feats_desc.desc(),
+      grad_reduced_feats_ptr, feats_desc.desc(), feats_ptr,
+      reduced_feats_desc.desc(), reduced_feats_ptr, coors_idx_desc.desc(),
+      coors_idx_ptr, reduce_count_desc.desc(), reduce_count_ptr,
+      voxel_num_desc.desc(), voxel_num_ptr, workspace_tensor_ptr,
+      workspace_size, grad_feats_desc.desc(), grad_feats_ptr);
 }
 
-std::vector<Tensor> dynamic_point_to_voxel_forward_impl(const Tensor &feats,
-                                                        const Tensor &coors,
-                                                        const reduce_t reduce_type);
+std::vector<Tensor> dynamic_point_to_voxel_forward_impl(
+    const Tensor &feats, const Tensor &coors, const reduce_t reduce_type);
 
-void dynamic_point_to_voxel_backward_impl(Tensor &grad_feats,
-                                          const Tensor &grad_reduced_feats,
-                                          const Tensor &feats,
-                                          const Tensor &reduced_feats,
-                                          const Tensor &coors_idx,
-                                          const Tensor &reduce_count,
-                                          const reduce_t reduce_type);
+void dynamic_point_to_voxel_backward_impl(
+    Tensor &grad_feats, const Tensor &grad_reduced_feats, const Tensor &feats,
+    const Tensor &reduced_feats, const Tensor &coors_idx,
+    const Tensor &reduce_count, const reduce_t reduce_type);
 
 REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MLU,
                      dynamic_point_to_voxel_forward_mlu);
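For context, a hypothetical sketch of the Python op these kernels back
(dynamic_scatter from mmcv.ops; 'max' is used because the backward pass above
explicitly rejects other reduce modes on MLU for now):

import torch
from mmcv.ops import dynamic_scatter

# 16 points with 4 features each; integer voxel coordinates per point.
feats = torch.rand(16, 4, device='mlu')
coors = torch.randint(0, 8, (16, 3), device='mlu').int()  # kernels expect Int32
voxel_feats, voxel_coors = dynamic_scatter(feats, coors, 'max')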
@@ -4,11 +4,23 @@ import pytest
 import torch
 
 from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
 
+if IS_MLU_AVAILABLE:
+    torch.backends.mlu.matmul.allow_tf32 = False
+
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_diff_iou_rotated_2d():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+])
+def test_diff_iou_rotated_2d(device):
     np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
                              [0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
                              [0.5, 0.5, 1., 1., .0]]],
@@ -19,17 +31,25 @@ def test_diff_iou_rotated_2d():
                              [1.5, 1.5, 1., 1., .0]]],
                            dtype=np.float32)
 
-    boxes1 = torch.from_numpy(np_boxes1).cuda()
-    boxes2 = torch.from_numpy(np_boxes2).cuda()
+    boxes1 = torch.from_numpy(np_boxes1).to(device)
+    boxes2 = torch.from_numpy(np_boxes2).to(device)
 
     np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]])
     ious = diff_iou_rotated_2d(boxes1, boxes2)
     assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_diff_iou_rotated_3d():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+])
+def test_diff_iou_rotated_3d(device):
     np_boxes1 = np.asarray(
         [[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
           [.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
@@ -41,8 +61,8 @@ def test_diff_iou_rotated_3d():
           [-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]],
         dtype=np.float32)
 
-    boxes1 = torch.from_numpy(np_boxes1).cuda()
-    boxes2 = torch.from_numpy(np_boxes2).cuda()
+    boxes1 = torch.from_numpy(np_boxes1).to(device)
+    boxes2 = torch.from_numpy(np_boxes2).to(device)
 
     np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]])
     ious = diff_iou_rotated_3d(boxes1, boxes2)
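To exercise both device parametrizations above, the module can be run directly
(a sketch; the tests/test_ops/ path is assumed from the repository layout and
is not shown in this commit view):

# run_diff_iou_tests.py -- hypothetical helper, not part of this commit
import pytest

# Each test collects once per device ('cuda' and/or 'mlu'); unavailable
# devices are skipped by the pytest.mark.skipif marks above.
raise SystemExit(pytest.main(['-v', 'tests/test_ops/test_diff_iou_rotated.py']))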