[Feature] Add the implementation of diff_iou_rotated with mlu-ops (#2840)

Danielmic 2023-06-30 15:18:39 +08:00 committed by GitHub
parent 10c8b9e78b
commit d28aa8a9cc
6 changed files with 139 additions and 86 deletions


@ -20,7 +20,7 @@ We implement common ops used in detection, segmentation, etc.
| Correlation | | √ | | | |
| Deformable Convolution v1/v2 | √ | √ | | | √ |
| Deformable RoIPool | | √ | √ | | √ |
| DiffIoURotated | | √ | | | |
| DiffIoURotated | | √ | √ | | |
| DynamicScatter | | √ | √ | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |


@ -20,7 +20,7 @@ MMCV provides operators commonly used in tasks such as detection and segmentation
| Correlation | | √ | | | |
| Deformable Convolution v1/v2 | √ | √ | | | √ |
| Deformable RoIPool | | √ | √ | | √ |
| DiffIoURotated | | √ | | | |
| DiffIoURotated | | √ | √ | | |
| DynamicScatter | | √ | √ | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |


@ -0,0 +1,55 @@
/*************************************************************************
* Copyright (C) 2023 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
Tensor diff_iou_rotated_sort_vertices_forward_mlu(Tensor vertices, Tensor mask,
Tensor num_valid) {
// params check
TORCH_CHECK(vertices.scalar_type() == at::kFloat,
"vertices type should be Float, got ", vertices.scalar_type());
TORCH_CHECK(mask.scalar_type() == at::kBool, "mask should be Bool, got ",
mask.scalar_type());
TORCH_CHECK(num_valid.scalar_type() == at::kInt,
"num_valid type should be Int32, got ", num_valid.scalar_type());
TORCH_CHECK(vertices.size(2) == 24, "vertices.dim(2) should be 24, got ",
vertices.size(2));
TORCH_CHECK(mask.size(2) == 24, "mask.dim(2) should be 24, got ",
mask.size(2));
// zero-element check
if (vertices.numel() == 0) {
return at::empty({0}, num_valid.options().dtype(at::kInt));
}
auto idx = at::empty({vertices.size(0), vertices.size(1), 9},
num_valid.options().dtype(at::kInt));
INITIAL_MLU_PARAM_WITH_TENSOR(vertices);
INITIAL_MLU_PARAM_WITH_TENSOR(mask);
INITIAL_MLU_PARAM_WITH_TENSOR(num_valid);
INITIAL_MLU_PARAM_WITH_TENSOR(idx);
// get compute handle
auto handle = mluOpGetCurrentHandle();
// launch kernel
mluOpDiffIouRotatedSortVerticesForward(
handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr);
return idx;
}
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid);
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, MLU,
diff_iou_rotated_sort_vertices_forward_mlu);
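
For context, a minimal usage sketch of the op this binding serves, mirroring the updated tests further below; it assumes MMCV built with MLU support and an MLU-enabled PyTorch build:

```python
import numpy as np
import torch

from mmcv.ops import diff_iou_rotated_2d

# Rotated boxes in (x, y, w, h, angle) format. Moving the tensors to the
# 'mlu' device makes diff_iou_rotated_2d dispatch its sort_vertices step
# to the MLU kernel registered above.
np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0]]], dtype=np.float32)
np_boxes2 = np.asarray([[[1., 1., 1., 1., .0]]], dtype=np.float32)

boxes1 = torch.from_numpy(np_boxes1).to('mlu')
boxes2 = torch.from_numpy(np_boxes2).to('mlu')

# Unit boxes offset by (0.5, 0.5): intersection 0.25, union 1.75, IoU = 1/7.
ious = diff_iou_rotated_2d(boxes1, boxes2)
print(ious.cpu().numpy())  # approximately [[0.1429]]
```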


@ -11,16 +11,16 @@
*************************************************************************/
#include "mlu_common_helper.h"
std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(const Tensor &feats,
const Tensor &coors,
const reduce_t reduce_type) {
std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(
const Tensor &feats, const Tensor &coors, const reduce_t reduce_type) {
// params check
TORCH_CHECK(feats.scalar_type() == at::kFloat,
"feats type should be Float, got ", feats.scalar_type());
TORCH_CHECK(coors.scalar_type() == at::kInt,
"coors type should be Int32, got ", coors.scalar_type());
TORCH_CHECK(feats.size(0) == coors.size(0),
"feats.dim(0) and coors.dim(0) should be same, got ", feats.size(0), " vs ", coors.size(0));
"feats.dim(0) and coors.dim(0) should be same, got ",
feats.size(0), " vs ", coors.size(0));
const int num_input = feats.size(0);
const int num_feats = feats.size(1);
@ -49,59 +49,48 @@ std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(const Tensor &feats,
auto handle = mluOpGetCurrentHandle();
size_t workspace_size;
mluOpGetDynamicPointToVoxelForwardWorkspaceSize(handle,
feats_desc.desc(),
coors_desc.desc(),
&workspace_size);
mluOpGetDynamicPointToVoxelForwardWorkspaceSize(
handle, feats_desc.desc(), coors_desc.desc(), &workspace_size);
auto workspace_tensor =
at::empty(workspace_size, feats.options().dtype(at::kByte));
INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);
// launch kernel
mluOpDynamicPointToVoxelForward(handle,
mlu_reduce_type,
feats_desc.desc(),
feats_ptr,
coors_desc.desc(),
coors_ptr,
workspace_tensor_ptr,
workspace_size,
reduced_feats_desc.desc(),
reduced_feats_ptr,
out_coors_desc.desc(),
out_coors_ptr,
coors_map_desc.desc(),
coors_map_ptr,
reduce_count_desc.desc(),
reduce_count_ptr,
voxel_num_desc.desc(),
mluOpDynamicPointToVoxelForward(
handle, mlu_reduce_type, feats_desc.desc(), feats_ptr, coors_desc.desc(),
coors_ptr, workspace_tensor_ptr, workspace_size,
reduced_feats_desc.desc(), reduced_feats_ptr, out_coors_desc.desc(),
out_coors_ptr, coors_map_desc.desc(), coors_map_ptr,
reduce_count_desc.desc(), reduce_count_ptr, voxel_num_desc.desc(),
voxel_num_ptr);
int voxel_num_value = *static_cast<int *>(voxel_num.cpu().data_ptr());
TORCH_CHECK(voxel_num_value <= feats.size(0),
"voxel_num should be less than or equal to feats_num, got ", voxel_num_value, " vs ", feats.size(0));
return {reduced_feats.slice(0, 0, voxel_num_value), out_coors.slice(0, 0, voxel_num_value),
coors_map, reduce_count.slice(0, 0, voxel_num_value)};
"voxel_num should be less than or equal to feats_num, got ",
voxel_num_value, " vs ", feats.size(0));
return {reduced_feats.slice(0, 0, voxel_num_value),
out_coors.slice(0, 0, voxel_num_value), coors_map,
reduce_count.slice(0, 0, voxel_num_value)};
}
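
As a side note on the contract exercised here: the kernel writes worst-case-sized outputs (one voxel per input point) and reports the true voxel count via `voxel_num`, so the host trims the buffers before returning. A small illustrative sketch of that host-side step (the helper name is hypothetical, not part of the mlu-ops API):

```python
import torch

def trim_voxel_outputs(reduced_feats, out_coors, reduce_count, voxel_num):
    # Hypothetical helper mirroring the slicing above: keep only the rows
    # the kernel actually filled; coors_map stays per-point and un-sliced.
    n = int(voxel_num.item())
    assert n <= reduced_feats.size(0)  # same sanity check as the TORCH_CHECK
    return reduced_feats[:n], out_coors[:n], reduce_count[:n]
```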
void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
const Tensor &grad_reduced_feats,
const Tensor &feats,
const Tensor &reduced_feats,
const Tensor &coors_idx,
const Tensor &reduce_count,
const reduce_t reduce_type) {
void dynamic_point_to_voxel_backward_mlu(
Tensor &grad_feats, const Tensor &grad_reduced_feats, const Tensor &feats,
const Tensor &reduced_feats, const Tensor &coors_idx,
const Tensor &reduce_count, const reduce_t reduce_type) {
// params check
TORCH_CHECK(grad_reduced_feats.scalar_type() == at::kFloat,
"grad_reduced_feats type should be Float, got ", grad_reduced_feats.scalar_type());
"grad_reduced_feats type should be Float, got ",
grad_reduced_feats.scalar_type());
TORCH_CHECK(feats.scalar_type() == at::kFloat,
"feats type should be Float, got ", feats.scalar_type());
TORCH_CHECK(reduced_feats.scalar_type() == at::kFloat,
"reduced_feats type should be Float, got ", reduced_feats.scalar_type());
"reduced_feats type should be Float, got ",
reduced_feats.scalar_type());
TORCH_CHECK(coors_idx.scalar_type() == at::kInt,
"coors_idx type should be Int32, got ", coors_idx.scalar_type());
TORCH_CHECK(reduce_count.scalar_type() == at::kInt,
"reduce_count type should be Int32, got ", reduce_count.scalar_type());
"reduce_count type should be Int32, got ",
reduce_count.scalar_type());
const int num_input = feats.size(0);
const int num_reduced = reduced_feats.size(0);
@ -114,11 +103,13 @@ void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
// TODO(miaochen): remove this after mlu-ops supports other mode of reduce.
TORCH_CHECK(reduce_type == reduce_t::MAX,
"only supports max reduce in current version, got ", to_string(reduce_type));
"only supports max reduce in current version, got ",
to_string(reduce_type));
int voxel_num_value = reduced_feats.size(0);
auto opts = torch::TensorOptions().dtype(torch::kInt32);
auto voxel_num = torch::from_blob(&voxel_num_value, {1}, opts).clone().to(at::kMLU);
auto voxel_num =
torch::from_blob(&voxel_num_value, {1}, opts).clone().to(at::kMLU);
auto mlu_reduce_type = getMluOpReduceMode(reduce_type);
INITIAL_MLU_PARAM_WITH_TENSOR(grad_feats);
@ -134,43 +125,30 @@ void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
size_t workspace_size;
mluOpGetDynamicPointToVoxelBackwardWorkspaceSize(
handle, mlu_reduce_type,
grad_feats_desc.desc(),
feats_desc.desc(),
grad_reduced_feats_desc.desc(),
coors_idx_desc.desc(),
reduce_count_desc.desc(),
voxel_num_desc.desc(),
&workspace_size);
handle, mlu_reduce_type, grad_feats_desc.desc(), feats_desc.desc(),
grad_reduced_feats_desc.desc(), coors_idx_desc.desc(),
reduce_count_desc.desc(), voxel_num_desc.desc(), &workspace_size);
auto workspace_tensor =
at::empty(workspace_size, feats.options().dtype(at::kByte));
INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);
// launch kernel
mluOpDynamicPointToVoxelBackward(
handle, mlu_reduce_type,
grad_reduced_feats_desc.desc(),
grad_reduced_feats_ptr,
feats_desc.desc(), feats_ptr,
reduced_feats_desc.desc(), reduced_feats_ptr,
coors_idx_desc.desc(), coors_idx_ptr,
reduce_count_desc.desc(), reduce_count_ptr,
voxel_num_desc.desc(), voxel_num_ptr,
workspace_tensor_ptr, workspace_size,
grad_feats_desc.desc(), grad_feats_ptr);
handle, mlu_reduce_type, grad_reduced_feats_desc.desc(),
grad_reduced_feats_ptr, feats_desc.desc(), feats_ptr,
reduced_feats_desc.desc(), reduced_feats_ptr, coors_idx_desc.desc(),
coors_idx_ptr, reduce_count_desc.desc(), reduce_count_ptr,
voxel_num_desc.desc(), voxel_num_ptr, workspace_tensor_ptr,
workspace_size, grad_feats_desc.desc(), grad_feats_ptr);
}
std::vector<Tensor> dynamic_point_to_voxel_forward_impl(const Tensor &feats,
const Tensor &coors,
const reduce_t reduce_type);
std::vector<Tensor> dynamic_point_to_voxel_forward_impl(
const Tensor &feats, const Tensor &coors, const reduce_t reduce_type);
void dynamic_point_to_voxel_backward_impl(Tensor &grad_feats,
const Tensor &grad_reduced_feats,
const Tensor &feats,
const Tensor &reduced_feats,
const Tensor &coors_idx,
const Tensor &reduce_count,
const reduce_t reduce_type);
void dynamic_point_to_voxel_backward_impl(
Tensor &grad_feats, const Tensor &grad_reduced_feats, const Tensor &feats,
const Tensor &reduced_feats, const Tensor &coors_idx,
const Tensor &reduce_count, const reduce_t reduce_type);
REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MLU,
dynamic_point_to_voxel_forward_mlu);
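
For reference, a hedged end-to-end sketch of how these bindings are reached from Python through `mmcv.ops.DynamicScatter` (assuming an MLU build; the backward currently supports only the max reduction, hence `average_points=False`):

```python
import torch

from mmcv.ops import DynamicScatter

# DynamicScatter dispatches to dynamic_point_to_voxel_forward/backward.
# average_points=False selects the 'max' reduction, the only mode the MLU
# backward supports in this version.
scatter = DynamicScatter(
    voxel_size=[0.1, 0.1, 0.1],
    point_cloud_range=[0., 0., 0., 1., 1., 1.],
    average_points=False)

points = torch.rand(100, 4, device='mlu', requires_grad=True)  # float32 feats
coors = torch.randint(0, 10, (100, 3), dtype=torch.int32, device='mlu')

voxel_feats, voxel_coors = scatter(points, coors)
voxel_feats.sum().backward()  # exercises the max-reduce backward kernel
```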


@ -4,11 +4,23 @@ import pytest
import torch
from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
if IS_MLU_AVAILABLE:
    # Keep matmul in full FP32 so the 1e-4 IoU tolerances below hold on MLU.
    torch.backends.mlu.matmul.allow_tf32 = False
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_diff_iou_rotated_2d():
@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_diff_iou_rotated_2d(device):
    np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
                             [0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
                             [0.5, 0.5, 1., 1., .0]]],
@ -19,17 +31,25 @@ def test_diff_iou_rotated_2d():
                             [1.5, 1.5, 1., 1., .0]]],
                            dtype=np.float32)
    boxes1 = torch.from_numpy(np_boxes1).cuda()
    boxes2 = torch.from_numpy(np_boxes2).cuda()
    boxes1 = torch.from_numpy(np_boxes1).to(device)
    boxes2 = torch.from_numpy(np_boxes2).to(device)
    np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]])
    ious = diff_iou_rotated_2d(boxes1, boxes2)
    assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_diff_iou_rotated_3d():
@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_diff_iou_rotated_3d(device):
    np_boxes1 = np.asarray(
        [[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
          [.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
@ -41,8 +61,8 @@ def test_diff_iou_rotated_3d():
          [-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]],
        dtype=np.float32)
    boxes1 = torch.from_numpy(np_boxes1).cuda()
    boxes2 = torch.from_numpy(np_boxes2).cuda()
    boxes1 = torch.from_numpy(np_boxes1).to(device)
    boxes2 = torch.from_numpy(np_boxes2).to(device)
    np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]])
    ious = diff_iou_rotated_3d(boxes1, boxes2)
    assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)