mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add MLU support for Sparse Convolution op (#2589)
* [Feature] Add sparse convolution MLU API
* [Feature] update cpp code style
* end-of-file
* delete libext.a
* code style
* update ops.md

Co-authored-by: budefei <budefei@cambricon.com>

parent 01a0f53ea4
commit 06fa32853b
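Before the diff, a minimal usage sketch (not part of the commit) of the path this PR enables. It assumes an MLU-enabled build of mmcv with torch_mlu installed and a 'mlu' device string; the shapes mirror the updated test further down.

    import torch
    from mmcv.ops import SparseConvTensor, SubMConv3d

    # Four active voxels with four features each, moved to the MLU device.
    features = torch.rand(4, 4).to('mlu')
    # Per-voxel (batch_idx, z, y, x) integer coordinates.
    coordinates = torch.tensor(
        [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
         [1, 35, 930, 469]],
        dtype=torch.int32).to('mlu')
    x = SparseConvTensor(features, coordinates, [41, 1600, 1408], batch_size=2)
    conv = SubMConv3d(4, 16, kernel_size=3, padding=1).to('mlu')
    out = conv(x)  # SparseConvTensor with 16-channel features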
@@ -52,7 +52,7 @@ We implement common ops used in detection, segmentation, etc.
 | SigmoidFocalLoss | | √ | √ | | √ |
 | SoftmaxFocalLoss | | √ | | | √ |
 | SoftNMS | | √ | | | |
-| Sparse Convolution | | √ | | | |
+| Sparse Convolution | | √ | √ | | |
 | Synchronized BatchNorm | | √ | | | |
 | ThreeInterpolate | | √ | | | |
 | ThreeNN | | √ | √ | | |
@@ -52,7 +52,7 @@ MMCV provides the ops commonly used in detection, segmentation, and other tasks
 | SigmoidFocalLoss | | √ | √ | | √ |
 | SoftmaxFocalLoss | | √ | | | √ |
 | SoftNMS | | √ | | | |
-| Sparse Convolution | | √ | | | |
+| Sparse Convolution | | √ | √ | | |
 | Synchronized BatchNorm | | √ | | | |
 | ThreeInterpolate | | √ | | | |
 | ThreeNN | | √ | √ | | |
@@ -18,8 +18,8 @@
 #include "pytorch_device_registry.hpp"
 
 #define MLUOP_MAJOR 0
-#define MLUOP_MINOR 4
-#define MLUOP_PATCHLEVEL 2
+#define MLUOP_MINOR 5
+#define MLUOP_PATCHLEVEL 302
 
 mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
 mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
@@ -0,0 +1,446 @@
/*************************************************************************
 * Copyright (C) 2022 Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <torch/script.h>

#include <vector>

#include "mlu_common_helper.h"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
    torch::Tensor indices, int64_t batchSize,
    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
    std::vector<int64_t> padding, std::vector<int64_t> dilation,
    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
  // The following code is copied from
  // mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu to ensure the output is
  // available for network train. The outputs of this function have correct
  // shape but wrong value.
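  // indicePairs[k] stores (input index, output index) pairs for kernel
  // offset k, and indiceNum[k] the count of valid pairs; both are filled on
  // device by mluOpGetIndicePairs below.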
  auto numAct = indices.size(0);
  auto kernelVolume = kernelSize[0];
  int sub_m = (int)_subM;
  int transpose = (int)_transpose;
  int batch = (int)batchSize;
  auto coorDim = indices.size(1) - 1;

  for (int i = 1; i < kernelSize.size(); ++i) {
    kernelVolume *= kernelSize[i];
  }

  auto outputVolume = outSpatialShape[0];
  for (int i = 1; i < outSpatialShape.size(); ++i) {
    outputVolume *= outSpatialShape[i];
  }
  torch::Tensor indicePairs = at::full({kernelVolume, 2, numAct}, -1,
                                       indices.options().dtype(at::kInt));
  torch::Tensor indiceNum =
      at::zeros({kernelVolume}, indices.options().dtype(at::kInt));
  int out_size = sub_m == 1
                     ? numAct
                     : std::min(numAct * kernelVolume, batch * outputVolume);
  torch::Tensor out_indices =
      at::zeros({out_size, coorDim + 1}, indices.options().dtype(at::kInt));
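  // mluOp kernels require contiguous buffers, so every tensor is routed
  // through cnnl_contiguous before its device pointer is taken.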
  auto indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indices, at::MemoryFormat::Contiguous);
  auto indicePairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indicePairs, at::MemoryFormat::Contiguous);
  auto indiceNum_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indiceNum, at::MemoryFormat::Contiguous);
  auto out_indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      out_indices, at::MemoryFormat::Contiguous);

  std::vector<int> input_space;
  std::vector<int> filter_space;
  std::vector<int> output_space;
  std::vector<int> padding32;
  std::vector<int> stride32;
  std::vector<int> dilation32;
  for (int i = 0; i < NDim; i++) {
    input_space.push_back(spatialShape[i]);
    filter_space.push_back(kernelSize[i]);
    output_space.push_back(outSpatialShape[i]);
    padding32.push_back(padding[i]);
    stride32.push_back(stride[i]);
    dilation32.push_back(dilation[i]);
  }
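  // Descriptors tell mluOp each tensor's layout, dtype and shape; set()
  // derives them from the tensor, then the block below overwrites them with
  // explicit int32 / ARRAY-layout shapes.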
  MluOpTensorDescriptor indices_desc, out_indices_desc, indicePairs_desc,
      indiceNum_desc;
  indices_desc.set(indices_contiguous);
  indicePairs_desc.set(indicePairs_contiguous);
  indiceNum_desc.set(indiceNum_contiguous);
  out_indices_desc.set(out_indices_contiguous);
  {
    mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
    mluOpDataType_t dtype = MLUOP_DTYPE_INT32;
    std::vector<int> dims;
    dims = {numAct, coorDim + 1};
    mluOpSetTensorDescriptor(indices_desc.desc(), layout, dtype, dims.size(),
                             dims.data());
    dims = {kernelVolume, 2, numAct};
    mluOpSetTensorDescriptor(indicePairs_desc.desc(), layout, dtype,
                             dims.size(), dims.data());
    dims = {kernelVolume};
    mluOpSetTensorDescriptor(indiceNum_desc.desc(), layout, dtype, dims.size(),
                             dims.data());
    dims = {out_size, coorDim + 1};
    mluOpSetTensorDescriptor(out_indices_desc.desc(), layout, dtype,
                             dims.size(), dims.data());
  }
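  // The sparse convolution descriptor carries the full geometry; the rank
  // passed is NDim + 2 (batch and channel axes on top of the NDim spatial
  // ones).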
  mluOpSparseConvolutionDescriptor_t sparse_conv_desc;
  mluOpCreateSparseConvolutionDescriptor(&sparse_conv_desc);
  mluOpSetSparseConvolutionDescriptor(
      sparse_conv_desc, NDim + 2, batch, padding32.data(), stride32.data(),
      dilation32.data(), input_space.data(), filter_space.data(),
      output_space.data(), sub_m, transpose, 0);

  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size = 0;
  mluOpGetIndicePairsWorkspaceSize(
      handle, sparse_conv_desc, indices_desc.desc(), indicePairs_desc.desc(),
      out_indices_desc.desc(), indiceNum_desc.desc(), &workspace_size);
  auto indice_workspace_size =
      at::empty(workspace_size, indices.options().dtype(at::kByte));
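  // getMluTensorImpl(...)->cnnlMalloc() exposes the raw device pointers that
  // the mluOp C API consumes.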
  auto indices_impl = torch_mlu::getMluTensorImpl(indices_contiguous);
  auto out_indices_impl = torch_mlu::getMluTensorImpl(out_indices_contiguous);
  auto indicePairs_impl = torch_mlu::getMluTensorImpl(indicePairs_contiguous);
  auto indiceNum_impl = torch_mlu::getMluTensorImpl(indiceNum_contiguous);
  auto indice_workspace_impl =
      torch_mlu::getMluTensorImpl(indice_workspace_size);

  auto indices_ptr = indices_impl->cnnlMalloc();
  auto out_indices_ptr = out_indices_impl->cnnlMalloc();
  auto indicePairs_ptr = indicePairs_impl->cnnlMalloc();
  auto indiceNum_ptr = indiceNum_impl->cnnlMalloc();
  auto indice_workspace_ptr = indice_workspace_impl->cnnlMalloc();

  mluOpGetIndicePairs(handle, sparse_conv_desc, indices_desc.desc(),
                      indices_ptr, indice_workspace_ptr, workspace_size,
                      indicePairs_desc.desc(), indicePairs_ptr,
                      out_indices_desc.desc(), out_indices_ptr,
                      indiceNum_desc.desc(), indiceNum_ptr);
  int num_act_out = 0;
  mluOpGetSparseConvolutionNumActOut(sparse_conv_desc, &num_act_out);
  mluOpDestroySparseConvolutionDescriptor(sparse_conv_desc);
  if (!sub_m) {
    return {out_indices.slice(0, 0, num_act_out), indicePairs, indiceNum};
  } else {
    return {indices, indicePairs, indiceNum};
  }
}
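
// Forward indice convolution: gathers input features through the precomputed
// indice pairs, multiplies them by the filters, and scatter-accumulates the
// results into the dense output feature rows.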
torch::Tensor IndiceConvForwardMLUKernelLauncher(
    torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs,
    torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
    int64_t _subM) {
  auto indice_num_cpu = indiceNum.to({torch::kCPU});
  auto indice_num_cpu_64 = indice_num_cpu.data_ptr<int>();
  int indice_num_len = indiceNum.numel();
  int64_t indice_num[indice_num_len];
  for (int i = 0; i < indice_num_len; ++i) {
    indice_num[i] = (int64_t)(((int *)indice_num_cpu_64)[i]);
  }
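  // indiceNum is int32 on device, but the mluOp host API takes an int64_t
  // array, hence the copy to CPU and the widening loop above.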

  // generate empty output
  int C = filters.dim() == 4 ? filters.size(3) : filters.size(4);
  torch::Tensor output =
      at::zeros({numActOut, C}, features.options().dtype(at::kFloat));
  // generate descriptor
  auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      features, at::MemoryFormat::Contiguous);
  auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      filters, at::MemoryFormat::Contiguous);
  auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indicePairs, at::MemoryFormat::Contiguous);
  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      output, at::MemoryFormat::Contiguous);

  MluOpTensorDescriptor features_desc, filters_desc, indice_pairs_desc,
      output_desc;
  features_desc.set(features_contiguous);
  filters_desc.set(filters_contiguous);
  indice_pairs_desc.set(indice_pairs_contiguous);
  output_desc.set(output_contiguous);

  // set layout
  {
    mluOpTensorLayout_t layout;
    mluOpDataType_t dtype;
    int dim;
    int dims[8];

    // features_desc
    mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims);
    mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
                             dim, dims);

    // filters_desc
    mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims);
    mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
                             dim, dims);

    // indice_pairs_desc
    mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
                             dims);
    mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
                             dtype, dim, dims);

    // output_desc
    mluOpGetTensorDescriptor(output_desc.desc(), &layout, &dtype, &dim, dims);
    mluOpSetTensorDescriptor(output_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim,
                             dims);
  }
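  // Query the forward kernel's scratch requirement, allocate it as a byte
  // tensor on the same device, then launch.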

  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size = 0;
  mluOpGetIndiceConvolutionForwardWorkspaceSize(
      handle, features_desc.desc(), filters_desc.desc(),
      indice_pairs_desc.desc(), output_desc.desc(), indice_num, numActOut,
      _inverse, _subM, &workspace_size);

  auto workspace =
      at::empty(workspace_size, features.options().dtype(at::kByte));

  auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous);
  auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous);
  auto indice_pairs_impl = torch_mlu::getMluTensorImpl(indice_pairs_contiguous);
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);

  auto features_ptr = features_impl->cnnlMalloc();
  auto filters_ptr = filters_impl->cnnlMalloc();
  auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc();
  auto workspace_ptr = workspace_impl->cnnlMalloc();

  // outputs
  auto output_impl = torch_mlu::getMluTensorImpl(output);
  auto output_ptr = output_impl->cnnlMalloc();
  mluOpIndiceConvolutionForward(
      handle, features_desc.desc(), features_ptr, filters_desc.desc(),
      filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
      numActOut, _inverse, _subM, workspace_ptr, workspace_size,
      output_desc.desc(), output_ptr);

  return output;
}
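
// Backward indice convolution: one launcher produces both the gradient
// w.r.t. the input features and the gradient w.r.t. the filters.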
std::vector<torch::Tensor> IndiceConvBackwardMLUKernelLauncher(
    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
    int64_t _subM) {
  auto indice_num_cpu = indiceNum.to({torch::kCPU});
  auto indice_num_cpu_64 = indice_num_cpu.data_ptr<int>();
  int indice_num_len = indiceNum.numel();
  int64_t indice_num[indice_num_len];
  for (int i = 0; i < indice_num_len; ++i) {
    indice_num[i] = (int64_t)(((int *)(indice_num_cpu_64))[i]);
  }

  // generate empty input_grad
  torch::Tensor input_grad = at::zeros({features.size(0), features.size(1)},
                                       features.options().dtype(at::kFloat));
  torch::Tensor filters_grad;
  if (filters.dim() == 4) {
    int h = filters.size(0);
    int w = filters.size(1);
    int c = filters.size(2);
    int n = filters.size(3);
    filters_grad = at::zeros({h, w, c, n}, filters.options().dtype(at::kFloat));
  } else if (filters.dim() == 5) {
    int d = filters.size(0);
    int h = filters.size(1);
    int w = filters.size(2);
    int c = filters.size(3);
    int n = filters.size(4);
    filters_grad =
        at::zeros({d, h, w, c, n}, filters.options().dtype(at::kFloat));
  }

  auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      features, at::MemoryFormat::Contiguous);
  auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      filters, at::MemoryFormat::Contiguous);
  auto output_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      outGrad, at::MemoryFormat::Contiguous);
  auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indicePairs, at::MemoryFormat::Contiguous);
  auto input_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      features, at::MemoryFormat::Contiguous);
  auto filters_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      filters, at::MemoryFormat::Contiguous);
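  // input_grad_contiguous / filters_grad_contiguous are derived from features
  // and filters, which share the gradients' shapes and options; they only
  // seed the gradient descriptors set below.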

  MluOpTensorDescriptor features_desc, output_grad_desc, filters_desc,
      indice_pairs_desc, input_grad_desc, filters_grad_desc;
  features_desc.set(features_contiguous);
  filters_desc.set(filters_contiguous);
  output_grad_desc.set(output_grad_contiguous);
  indice_pairs_desc.set(indice_pairs_contiguous);
  input_grad_desc.set(input_grad_contiguous);
  filters_grad_desc.set(filters_grad_contiguous);

  // need to set desc layout with mluOp functions
  {
    mluOpTensorLayout_t layout;
    mluOpDataType_t dtype;
    int dim;
    int dims[8];

    // features_desc
    mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims);
    mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
                             dim, dims);

    // filters_desc
    mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims);
    if (dim == 4) {
      mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_HWCN, dtype,
                               dim, dims);
    } else {
      mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
                               dim, dims);
    }

    // output_grad_desc
    mluOpGetTensorDescriptor(output_grad_desc.desc(), &layout, &dtype, &dim,
                             dims);
    mluOpSetTensorDescriptor(output_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
                             dim, dims);

    // indice_pairs_desc
    mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
                             dims);
    mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
                             dtype, dim, dims);

    // input_grad_desc
    mluOpGetTensorDescriptor(input_grad_desc.desc(), &layout, &dtype, &dim,
                             dims);
    mluOpSetTensorDescriptor(input_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
                             dim, dims);
  }
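  // Two separate scratch buffers: one for the data-gradient kernel, one for
  // the filter-gradient kernel.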

  auto handle = mluOpGetCurrentHandle();
  size_t data_workspace_size = 0;
  mluOpGetIndiceConvolutionBackwardDataWorkspaceSize(
      handle, output_grad_desc.desc(), filters_desc.desc(),
      indice_pairs_desc.desc(), input_grad_desc.desc(), indice_num, _inverse,
      &data_workspace_size);

  size_t filters_workspace_size = 0;
  mluOpGetIndiceConvolutionBackwardFilterWorkspaceSize(
      handle, features_desc.desc(), output_grad_desc.desc(),
      indice_pairs_desc.desc(), filters_grad_desc.desc(), indice_num, _inverse,
      _subM, &filters_workspace_size);

  auto indice_convbpdata_workspace =
      at::empty(data_workspace_size, features.options().dtype(at::kByte));
  auto indice_convbpfilter_workspace =
      at::empty(filters_workspace_size, filters.options().dtype(at::kByte));

  auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous);
  auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous);
  auto output_grad_impl = torch_mlu::getMluTensorImpl(output_grad_contiguous);
  auto indice_pairs_impl = torch_mlu::getMluTensorImpl(indice_pairs_contiguous);
  auto indice_convbpdata_workspace_impl =
      torch_mlu::getMluTensorImpl(indice_convbpdata_workspace);
  auto indice_convbpfilter_workspace_impl =
      torch_mlu::getMluTensorImpl(indice_convbpfilter_workspace);

  auto features_ptr = features_impl->cnnlMalloc();
  auto filters_ptr = filters_impl->cnnlMalloc();
  auto output_grad_ptr = output_grad_impl->cnnlMalloc();
  auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc();
  auto indice_convbpdata_workspace_ptr =
      indice_convbpdata_workspace_impl->cnnlMalloc();
  auto indice_convbpfilter_workspace_ptr =
      indice_convbpfilter_workspace_impl->cnnlMalloc();

  // outputs
  auto input_grad_impl = torch_mlu::getMluTensorImpl(input_grad);
  auto input_grad_ptr = input_grad_impl->cnnlMalloc();
  auto filters_grad_impl = torch_mlu::getMluTensorImpl(filters_grad);
  auto filters_grad_ptr = filters_grad_impl->cnnlMalloc();

  mluOpIndiceConvolutionBackwardData(
      handle, output_grad_desc.desc(), output_grad_ptr, filters_desc.desc(),
      filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
      _inverse, _subM, indice_convbpdata_workspace_ptr, data_workspace_size,
      input_grad_desc.desc(), input_grad_ptr);

  mluOpIndiceConvolutionBackwardFilter(
      handle, features_desc.desc(), features_ptr, output_grad_desc.desc(),
      output_grad_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
      _inverse, _subM, indice_convbpfilter_workspace_ptr,
      filters_workspace_size, filters_grad_desc.desc(), filters_grad_ptr);

  std::vector<torch::Tensor> result;
  result.push_back(input_grad);
  result.push_back(filters_grad);
  return result;
}
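
// Thin wrappers with the signatures expected by mmcv's device registry, so
// the launchers can be dispatched through REGISTER_DEVICE_IMPL below.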
torch::Tensor indice_conv_forward_mlu(torch::Tensor features,
                                      torch::Tensor filters,
                                      torch::Tensor indicePairs,
                                      torch::Tensor indiceNum,
                                      int64_t numActOut, int64_t _inverse,
                                      int64_t _subM) {
  return IndiceConvForwardMLUKernelLauncher(
      features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM);
}

std::vector<torch::Tensor> indice_conv_backward_mlu(
    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
    int64_t _subM) {
  return IndiceConvBackwardMLUKernelLauncher(
      features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM);
}

torch::Tensor indice_conv_forward_impl(torch::Tensor features,
                                       torch::Tensor filters,
                                       torch::Tensor indicePairs,
                                       torch::Tensor indiceNum,
                                       int64_t numActOut, int64_t _inverse,
                                       int64_t _subM);

std::vector<torch::Tensor> indice_conv_backward_impl(
    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
    int64_t _subM);

REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MLU, indice_conv_forward_mlu);
REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MLU, indice_conv_backward_mlu);
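
// Explicit instantiations for the spatial ranks the dispatcher can request
// (2-D, 3-D and 4-D sparse convolutions).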
template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<2>(
    torch::Tensor indices, int64_t batchSize,
    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
    std::vector<int64_t> padding, std::vector<int64_t> dilation,
    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);

template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<3>(
    torch::Tensor indices, int64_t batchSize,
    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
    std::vector<int64_t> padding, std::vector<int64_t> dilation,
    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);

template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<4>(
    torch::Tensor indices, int64_t batchSize,
    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
    std::vector<int64_t> padding, std::vector<int64_t> dilation,
    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
@@ -35,6 +35,26 @@ std::vector<torch::Tensor> get_indice_pairs_forward_cuda(
       padding, dilation, outPadding, _subM, _transpose);
 };
 
+template <unsigned NDim>
+std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
+    torch::Tensor indices, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
+
+template <unsigned NDim>
+std::vector<torch::Tensor> get_indice_pairs_forward_mlu(
+    torch::Tensor indices, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
+  return GetIndicePairsForwardMLUKernelLauncher<NDim>(
+      indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride,
+      padding, dilation, outPadding, _subM, _transpose);
+}
+
 template <unsigned NDim>
 std::vector<torch::Tensor> GetIndicePairsBackwardCUDAKernelLauncher(
     torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
@@ -71,6 +91,12 @@ std::vector<torch::Tensor> get_indice_pairs_forward(
         padding, dilation, outPadding, _subM, _transpose);
 #else
     AT_ERROR("get_indice_pairs is not compiled with GPU support");
+#endif
+#ifdef MMCV_WITH_MLU
+  } else if (indices.device().type() == at::kMLU) {
+    return get_indice_pairs_forward_mlu<NDim>(
+        indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride,
+        padding, dilation, outPadding, _subM, _transpose);
 #endif
   } else {
     AT_ERROR("get_indice_pairs is not implemented on CPU");
setup.py
@@ -388,7 +388,7 @@ def get_extensions():
         glob.glob(
             './mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True)
         extra_objects = glob.glob(
-            './mlu-ops/bangc-ops/kernels/*/x86_64/*.o')
+            './mlu-ops/bangc-ops/kernels/kernel_wrapper/*.o')
         extension = MLUExtension
         include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
         include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
@@ -10,6 +10,8 @@ from mmcv.ops import (SparseConvTensor, SparseInverseConv3d, SparseSequential,
 if torch.__version__ == 'parrots':
     pytest.skip('not supported in parrots now', allow_module_level=True)
 
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+
 
 def make_sparse_convmodule(in_channels,
                            out_channels,
@@ -76,21 +78,29 @@ def make_sparse_convmodule(in_channels,
     return layers
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_make_sparse_convmodule():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+])
+def test_make_sparse_convmodule(device):
     torch.cuda.empty_cache()
     voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
                                    [6.8162713, -2.480431, -1.3616394, 0.36],
                                    [11.643568, -4.744306, -1.3580885, 0.16],
                                    [23.482342, 6.5036807, 0.5806964, 0.35]],
                                   dtype=torch.float32,
-                                  device='cuda')  # n, point_features
+                                  device=device)  # n, point_features
     coordinates = torch.tensor(
         [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
          [1, 35, 930, 469]],
         dtype=torch.int32,
-        device='cuda')  # n, 4(batch, ind_x, ind_y, ind_z)
+        device=device)  # n, 4(batch, ind_x, ind_y, ind_z)
 
     # test
     input_sp_tensor = SparseConvTensor(voxel_features, coordinates,
@@ -105,7 +115,7 @@ def test_make_sparse_convmodule():
         padding=0,
         conv_type='SubMConv3d',
         norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
-        order=('conv', 'norm', 'act')).cuda()
+        order=('conv', 'norm', 'act')).to(device)
     assert isinstance(sparse_block0[0], SubMConv3d)
     assert sparse_block0[0].in_channels == 4
     assert sparse_block0[0].out_channels == 16
@@ -118,6 +128,8 @@ def test_make_sparse_convmodule():
     out_features = sparse_block0(input_sp_tensor)
     assert out_features.features.shape == torch.Size([4, 16])
 
-    sparse_block1 = make_sparse_convmodule(
-        4,
-        16,
+    # device == mlu: not support inverse==1 yet
+    if device != 'mlu':
+        sparse_block1 = make_sparse_convmodule(
+            4,
+            16,
@@ -127,7 +139,7 @@ def test_make_sparse_convmodule():
-        padding=0,
-        conv_type='SparseInverseConv3d',
-        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
-        order=('norm', 'act', 'conv')).cuda()
-    assert isinstance(sparse_block1[0], torch.nn.BatchNorm1d)
-    assert isinstance(sparse_block1[1], torch.nn.ReLU)
-    assert isinstance(sparse_block1[2], SparseInverseConv3d)
+            padding=0,
+            conv_type='SparseInverseConv3d',
+            norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
+            order=('norm', 'act', 'conv')).to(device)
+        assert isinstance(sparse_block1[2], SparseInverseConv3d)
+        assert isinstance(sparse_block1[0], torch.nn.BatchNorm1d)
+        assert isinstance(sparse_block1[1], torch.nn.ReLU)