[Feature] Add the implementation of dynamic_scatter with mlu-ops (#2847)

Danielmic 2023-06-29 17:13:57 +08:00 committed by GitHub
parent af0aaddf89
commit 10c8b9e78b
7 changed files with 248 additions and 31 deletions


@@ -21,7 +21,7 @@ We implement common ops used in detection, segmentation, etc.
| Deformable Convolution v1/v2 | √ | √ | | | √ |
| Deformable RoIPool | | √ | √ | | √ |
| DiffIoURotated | | √ | | | |
| DynamicScatter | | √ | | | |
| DynamicScatter | | √ | √ | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |
| FusedBiasLeakyrelu | | √ | | | √ |


@@ -21,7 +21,7 @@ MMCV provides common operators used in detection, segmentation, and other tasks
| Deformable Convolution v1/v2 | √ | √ | | | √ |
| Deformable RoIPool | | √ | √ | | √ |
| DiffIoURotated | | √ | | | |
| DynamicScatter | | √ | | | |
| DynamicScatter | | √ | √ | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |
| FusedBiasLeakyrelu | | √ | | | √ |


@@ -56,6 +56,19 @@ mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input) {
  return layout;
}

mluOpReduceMode_t getMluOpReduceMode(const reduce_t reduce_type) {
  const std::map<reduce_t, mluOpReduceMode_t> mapping_type = {
      {reduce_t::MAX, MLUOP_REDUCE_DMAX},
      {reduce_t::SUM, MLUOP_REDUCE_DSUM},
      {reduce_t::MEAN, MLUOP_REDUCE_DMEAN}};
  auto iter = mapping_type.find(reduce_type);
  if (iter != mapping_type.end()) {
    return iter->second;
  } else {
    TORCH_CHECK(false, "Unsupported reduce type: ", to_string(reduce_type));
    return MLUOP_REDUCE_DSUM;
  }
}
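A hypothetical call site (not part of this diff) to illustrate the mapping:

// sketch: resolve mmcv's reduce_t to the corresponding mlu-ops enum
mluOpReduceMode_t mode = getMluOpReduceMode(reduce_t::MAX);  // MLUOP_REDUCE_DMAX
// an unmapped value trips the TORCH_CHECK above instead of returning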
void MluOpTensorDescriptor::set(Tensor t) {
  mluOpDataType_t data_type = getMluOpDataType(t.dtype());
  mluOpTensorLayout_t layout = getMluOpSuggestLayout(t);


@@ -18,11 +18,39 @@
#include "pytorch_device_registry.hpp"

#define MLUOP_MAJOR 0
#define MLUOP_MINOR 6
#define MLUOP_PATCHLEVEL 0
#define MLUOP_MINOR 7
#define MLUOP_PATCHLEVEL 1

/*************************************************************************
 * This macro converts a plain tensor into its MLU-side counterparts:
 * NAME##_contiguous, NAME##_desc, NAME##_impl and NAME##_ptr are
 * generated automatically from the tensor passed in.
 *************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME)                          \
  auto NAME##_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(    \
      NAME, NAME.suggest_memory_format());                           \
  MluOpTensorDescriptor NAME##_desc;                                 \
  NAME##_desc.set(NAME##_contiguous);                                \
  auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contiguous); \
  auto NAME##_ptr = NAME##_impl->cnnlMalloc();
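To make the generated names concrete, this is roughly what INITIAL_MLU_PARAM_WITH_TENSOR(feats) expands to for a tensor named feats (a sketch of the preprocessor output, not extra diff content):

// approximate expansion of INITIAL_MLU_PARAM_WITH_TENSOR(feats)
auto feats_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
    feats, feats.suggest_memory_format());  // contiguous MLU copy
MluOpTensorDescriptor feats_desc;           // mlu-ops tensor descriptor
feats_desc.set(feats_contiguous);           // record dtype/layout/shape
auto feats_impl = torch_mlu::getMluTensorImpl(feats_contiguous);
auto feats_ptr = feats_impl->cnnlMalloc();  // device pointer passed to kernels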
enum class reduce_t { SUM = 0, MEAN = 1, MAX = 2 };

inline std::string to_string(reduce_t reduce_type) {
  if (reduce_type == reduce_t::MAX) {
    return "max";
  } else if (reduce_type == reduce_t::MEAN) {
    return "mean";
  } else if (reduce_type == reduce_t::SUM) {
    return "sum";
  } else {
    return "unknown reduce type";
  }
}

mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
mluOpReduceMode_t getMluOpReduceMode(const reduce_t reduce_type);

class MluOpTensorDescriptor {
 public:


@@ -13,19 +13,6 @@
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

/*************************************************************************
 * This macro converts a plain tensor into its MLU-side counterparts:
 * NAME##_contiguous, NAME##_desc, NAME##_impl and NAME##_ptr are
 * generated automatically from the tensor passed in.
 *************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME)                          \
  auto NAME##_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(    \
      NAME, NAME.suggest_memory_format());                           \
  MluOpTensorDescriptor NAME##_desc;                                 \
  NAME##_desc.set(NAME##_contiguous);                                \
  auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contiguous); \
  auto NAME##_ptr = NAME##_impl->cnnlMalloc();

Tensor MsDeformAttnForwardLauncher(const Tensor& value,
                                   const Tensor& spatial_shapes,
                                   const Tensor& level_start_index,


@@ -0,0 +1,178 @@
/*************************************************************************
 * Copyright (C) 2023 Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "mlu_common_helper.h"
std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(const Tensor &feats,
                                                       const Tensor &coors,
                                                       const reduce_t reduce_type) {
  // params check
  TORCH_CHECK(feats.scalar_type() == at::kFloat,
              "feats type should be Float, got ", feats.scalar_type());
  TORCH_CHECK(coors.scalar_type() == at::kInt,
              "coors type should be Int32, got ", coors.scalar_type());
  TORCH_CHECK(feats.size(0) == coors.size(0),
              "feats.size(0) and coors.size(0) should be the same, got ",
              feats.size(0), " vs ", coors.size(0));
  const int num_input = feats.size(0);
  const int num_feats = feats.size(1);
  // zero-element check
  if (num_input == 0)
    return {feats.clone().detach(), coors.clone().detach(),
            coors.new_empty({0}, torch::kInt32),
            coors.new_empty({0}, torch::kInt32)};
  auto mlu_reduce_type = getMluOpReduceMode(reduce_type);
  auto reduced_feats = at::empty({num_input, num_feats}, feats.options());
  auto out_coors = at::empty({num_input, 3}, coors.options());
  auto coors_map = at::empty({num_input}, coors.options());
  auto reduce_count = at::empty({num_input}, coors.options());
  auto voxel_num = at::empty({1}, coors.options());

  INITIAL_MLU_PARAM_WITH_TENSOR(feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors);
  INITIAL_MLU_PARAM_WITH_TENSOR(reduced_feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(out_coors);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors_map);
  INITIAL_MLU_PARAM_WITH_TENSOR(reduce_count);
  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_num);

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
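  // ask mlu-ops how much scratch memory the kernel needs, then back it
  // with a plain byte tensor so the buffer lives on the device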
  size_t workspace_size;
  mluOpGetDynamicPointToVoxelForwardWorkspaceSize(
      handle, feats_desc.desc(), coors_desc.desc(), &workspace_size);
  auto workspace_tensor =
      at::empty(workspace_size, feats.options().dtype(at::kByte));
  INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

  // launch kernel
  mluOpDynamicPointToVoxelForward(
      handle, mlu_reduce_type, feats_desc.desc(), feats_ptr,
      coors_desc.desc(), coors_ptr, workspace_tensor_ptr, workspace_size,
      reduced_feats_desc.desc(), reduced_feats_ptr,
      out_coors_desc.desc(), out_coors_ptr,
      coors_map_desc.desc(), coors_map_ptr,
      reduce_count_desc.desc(), reduce_count_ptr,
      voxel_num_desc.desc(), voxel_num_ptr);
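  // copy the voxel count back to the host so the dense outputs can be
  // trimmed to the number of voxels actually produced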
  int voxel_num_value = *static_cast<int *>(voxel_num.cpu().data_ptr());
  TORCH_CHECK(voxel_num_value <= feats.size(0),
              "voxel_num should be less than or equal to feats_num, got ",
              voxel_num_value, " vs ", feats.size(0));
  return {reduced_feats.slice(0, 0, voxel_num_value),
          out_coors.slice(0, 0, voxel_num_value), coors_map,
          reduce_count.slice(0, 0, voxel_num_value)};
}
void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
                                         const Tensor &grad_reduced_feats,
                                         const Tensor &feats,
                                         const Tensor &reduced_feats,
                                         const Tensor &coors_idx,
                                         const Tensor &reduce_count,
                                         const reduce_t reduce_type) {
  // params check
  TORCH_CHECK(grad_reduced_feats.scalar_type() == at::kFloat,
              "grad_reduced_feats type should be Float, got ",
              grad_reduced_feats.scalar_type());
  TORCH_CHECK(feats.scalar_type() == at::kFloat,
              "feats type should be Float, got ", feats.scalar_type());
  TORCH_CHECK(reduced_feats.scalar_type() == at::kFloat,
              "reduced_feats type should be Float, got ",
              reduced_feats.scalar_type());
  TORCH_CHECK(coors_idx.scalar_type() == at::kInt,
              "coors_idx type should be Int32, got ", coors_idx.scalar_type());
  TORCH_CHECK(reduce_count.scalar_type() == at::kInt,
              "reduce_count type should be Int32, got ",
              reduce_count.scalar_type());
  const int num_input = feats.size(0);
  const int num_reduced = reduced_feats.size(0);
  const int num_feats = feats.size(1);
  grad_feats.fill_(0);
  // zero-element check
  if (num_input == 0 || num_reduced == 0) return;
  // TODO(miaochen): remove this after mlu-ops supports other reduce modes.
  TORCH_CHECK(reduce_type == reduce_t::MAX,
              "only supports max reduce in current version, got ",
              to_string(reduce_type));
  int voxel_num_value = reduced_feats.size(0);
  auto opts = torch::TensorOptions().dtype(torch::kInt32);
  auto voxel_num =
      torch::from_blob(&voxel_num_value, {1}, opts).clone().to(at::kMLU);
  auto mlu_reduce_type = getMluOpReduceMode(reduce_type);

  INITIAL_MLU_PARAM_WITH_TENSOR(grad_feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_reduced_feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(reduced_feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors_idx);
  INITIAL_MLU_PARAM_WITH_TENSOR(reduce_count);
  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_num);

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size;
  mluOpGetDynamicPointToVoxelBackwardWorkspaceSize(
      handle, mlu_reduce_type, grad_feats_desc.desc(), feats_desc.desc(),
      grad_reduced_feats_desc.desc(), coors_idx_desc.desc(),
      reduce_count_desc.desc(), voxel_num_desc.desc(), &workspace_size);
  auto workspace_tensor =
      at::empty(workspace_size, feats.options().dtype(at::kByte));
  INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

  // launch kernel
  mluOpDynamicPointToVoxelBackward(
      handle, mlu_reduce_type, grad_reduced_feats_desc.desc(),
      grad_reduced_feats_ptr, feats_desc.desc(), feats_ptr,
      reduced_feats_desc.desc(), reduced_feats_ptr, coors_idx_desc.desc(),
      coors_idx_ptr, reduce_count_desc.desc(), reduce_count_ptr,
      voxel_num_desc.desc(), voxel_num_ptr, workspace_tensor_ptr,
      workspace_size, grad_feats_desc.desc(), grad_feats_ptr);
}
std::vector<Tensor> dynamic_point_to_voxel_forward_impl(const Tensor &feats,
                                                        const Tensor &coors,
                                                        const reduce_t reduce_type);

void dynamic_point_to_voxel_backward_impl(Tensor &grad_feats,
                                          const Tensor &grad_reduced_feats,
                                          const Tensor &feats,
                                          const Tensor &reduced_feats,
                                          const Tensor &coors_idx,
                                          const Tensor &reduce_count,
                                          const reduce_t reduce_type);
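// hook the MLU kernels into mmcv's device dispatcher: *_impl calls made
// with MLU tensors are routed to the functions defined above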
REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MLU,
                     dynamic_point_to_voxel_forward_mlu);
REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_backward_impl, MLU,
                     dynamic_point_to_voxel_backward_mlu);


@@ -4,22 +4,31 @@ import torch
from torch.autograd import gradcheck

from mmcv.ops import DynamicScatter
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

if torch.__version__ == 'parrots':
    pytest.skip('not supported in parrots now', allow_module_level=True)


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_dynamic_scatter():
@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_dynamic_scatter(device):
    dsmean = DynamicScatter([0.32, 0.32, 6],
                            [-74.88, -74.88, -2, 74.88, 74.88, 4], True)
    dsmax = DynamicScatter([0.32, 0.32, 6],
                           [-74.88, -74.88, -2, 74.88, 74.88, 4], False)
    # test empty input
    empty_feats = torch.empty(size=(0, 3), dtype=torch.float32, device='cuda')
    empty_coors = torch.empty(size=(0, 3), dtype=torch.int32, device='cuda')
    empty_feats = torch.empty(size=(0, 3), dtype=torch.float32, device=device)
    empty_coors = torch.empty(size=(0, 3), dtype=torch.int32, device=device)

    empty_feats.requires_grad_()
    empty_feats_out_mean, empty_coors_out_mean = dsmean(
@@ -35,9 +44,9 @@ def test_dynamic_scatter():
    # test empty reduced output
    empty_o_feats = torch.rand(
        size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
        size=(200000, 3), dtype=torch.float32, device=device) * 100 - 50
    empty_o_coors = torch.randint(
        low=-1, high=0, size=(200000, 3), dtype=torch.int32, device='cuda')
        low=-1, high=0, size=(200000, 3), dtype=torch.int32, device=device)

    empty_o_feats.requires_grad_()
    empty_o_feats_out_mean, empty_o_coors_out_mean = dsmean(
@@ -52,9 +61,9 @@ def test_dynamic_scatter():
    # test non-empty input
    feats = torch.rand(
        size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
        size=(200000, 3), dtype=torch.float32, device=device) * 100 - 50
    coors = torch.randint(
        low=-1, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
        low=-1, high=20, size=(200000, 3), dtype=torch.int32, device=device)

    ref_voxel_coors = coors.unique(dim=0, sorted=True)
    ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
@@ -88,9 +97,9 @@ def test_dynamic_scatter():
    # test non-empty input without any point out of bound
    feats = torch.rand(
        size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
        size=(200000, 3), dtype=torch.float32, device=device) * 100 - 50
    coors = torch.randint(
        low=0, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
        low=0, high=20, size=(200000, 3), dtype=torch.int32, device=device)

    ref_voxel_coors = coors.unique(dim=0, sorted=True)
    ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
@@ -124,9 +133,11 @@ def test_dynamic_scatter():
    # test grad #
    feats = torch.rand(
        size=(100, 4), dtype=torch.float32, device='cuda') * 100 - 50
        size=(100, 4), dtype=torch.float32, device=device) * 100 - 50
    coors = torch.randint(
        low=-1, high=3, size=(100, 3), dtype=torch.int32, device='cuda')
        low=-1, high=3, size=(100, 3), dtype=torch.int32, device=device)

    feats.requires_grad_()
    # TODO(Cambricon): MLU only supports max reduce in the current version,
    # so gradcheck on the mean branch is skipped there.
    if not IS_MLU_AVAILABLE:
        gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
    gradcheck(dsmax, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)