mmcv/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp

/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
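
// CARAFE (Content-Aware ReAssembly of FEatures) forward and backward kernel
// launchers for Cambricon MLU devices, built on top of the mluOp library.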
void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
                                    Tensor rinput, Tensor routput, Tensor rmask,
                                    Tensor output, const int kernel_size,
                                    const int group_size,
                                    const int scale_factor) {
  // check tensor data type
  TORCH_CHECK(
      input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
      "Data type of input should be Float or Half. But now input type is ",
      input.scalar_type(), ".");
  TORCH_CHECK(mask.scalar_type() == input.scalar_type(),
              "Data types of input and mask should be the same, but got ",
              input.scalar_type(), " and ", mask.scalar_type());
  // check number of dimensions
  TORCH_CHECK(input.dim() == 4, "input should be a 4-D tensor, but has ",
              input.dim(), "D.");
  TORCH_CHECK(mask.dim() == 4, "mask should be a 4-D tensor, but has ",
              mask.dim(), "D.");
  // return fast on zero-element tensor
  if (output.numel() == 0) {
    output = at::zeros(output.sizes().vec(), output.options());
    return;
  }
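  // mluOp kernels expect channels-last (NHWC) layout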
  // convert NCHW to NHWC
  auto memory_format_input_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
  auto rinput_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format_input_nhwc);
  auto memory_format_mask_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(mask.dim());
  auto rmask_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(mask, memory_format_mask_nhwc);
  auto memory_format_output_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(output.dim());
  auto routput_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format_output_nhwc);
  // set tensor descriptor
  MluOpTensorDescriptor input_desc, mask_desc, output_desc;
  input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
  mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
  output_desc.set_with_layout(routput_, MLUOP_LAYOUT_NHWC);
  // get ptr of tensors
  auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
  auto input_ptr = input_impl->cnnlMalloc();
  auto mask_impl = torch_mlu::getMluTensorImpl(rmask_);
  auto mask_ptr = mask_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(routput_);
  auto output_ptr = output_impl->cnnlMalloc();
  // set op descriptor
  auto handle = mluOpGetCurrentHandle();
  mluOpCarafeDescriptor_t carafe_desc;
  TORCH_MLUOP_CHECK(mluOpCreateCarafeDescriptor(&carafe_desc));
  TORCH_MLUOP_CHECK(mluOpSetCarafeDescriptor(
      carafe_desc, input.dim(), kernel_size, group_size, scale_factor));
  // launch kernel
  TORCH_MLUOP_CHECK(mluOpCarafeForward(handle, carafe_desc, input_desc.desc(),
                                       input_ptr, mask_desc.desc(), mask_ptr,
                                       output_desc.desc(), output_ptr));
  // destroy op descriptor
  TORCH_MLUOP_CHECK(mluOpDestroyCarafeDescriptor(carafe_desc));
  // copy the NHWC results back: rinput keeps the channels-last copy of the
  // input for the backward pass, and routput_ is copied back into the NCHW
  // output
  rinput.copy_(rinput_);
  output.copy_(routput_);
}

void CARAFEBackwardMLUKernelLauncher(
    const Tensor grad_output, const Tensor rinput, const Tensor mask,
    Tensor rgrad_output, Tensor rgrad_input_hs, Tensor rgrad_input,
    Tensor rgrad_mask, Tensor grad_input, Tensor grad_mask,
    const int kernel_size, const int group_size, const int scale_factor) {
  // data type check
  TORCH_CHECK(grad_output.scalar_type() == at::kFloat ||
                  grad_output.scalar_type() == at::kHalf,
              "grad_output type should be Float or Half, got ",
              grad_output.scalar_type());
  TORCH_CHECK(grad_output.scalar_type() == mask.scalar_type(),
              "mask should have the same type as grad_output");
  // dim check
  TORCH_CHECK(grad_output.dim() == 4,
              "grad_output should be a 4-D tensor, got ", grad_output.dim(),
              "D");
  // param check
  TORCH_CHECK(kernel_size < 137, "kernel_size should be less than 137, got ",
              kernel_size);
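  // as in the forward pass, mluOp expects channels-last (NHWC) layout;
  // grad_input is zero-initialized below before the kernel fills it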
  // convert NCHW to NHWC
  auto memory_format_input_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(rinput.dim());
  auto rinput_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(rinput, memory_format_input_nhwc);
  auto memory_format_mask_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(mask.dim());
  auto rmask_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(mask, memory_format_mask_nhwc);
  auto memory_format_grad_output_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim());
  auto rgrad_output_ = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_output, memory_format_grad_output_nhwc);
  auto memory_format_grad_input_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_input.dim());
  auto rgrad_input_ = torch_mlu::cnnl::ops::cnnl_contiguous(
                          grad_input, memory_format_grad_input_nhwc)
                          .zero_();
  auto memory_format_grad_mask_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_mask.dim());
  auto rgrad_mask_ = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_mask, memory_format_grad_mask_nhwc);
  // set tensor descriptor
  MluOpTensorDescriptor input_desc, mask_desc;
  input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
  mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
  MluOpTensorDescriptor grad_output_desc, grad_input_desc, grad_mask_desc;
  grad_output_desc.set_with_layout(rgrad_output_, MLUOP_LAYOUT_NHWC);
  grad_input_desc.set_with_layout(rgrad_input_, MLUOP_LAYOUT_NHWC);
  grad_mask_desc.set_with_layout(rgrad_mask_, MLUOP_LAYOUT_NHWC);
  // get ptr of tensors
  auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
  auto input_ptr = input_impl->cnnlMalloc();
  auto mask_impl = torch_mlu::getMluTensorImpl(rmask_);
  auto mask_ptr = mask_impl->cnnlMalloc();
  auto grad_output_impl = torch_mlu::getMluTensorImpl(rgrad_output_);
  auto grad_output_ptr = grad_output_impl->cnnlMalloc();
  auto grad_input_impl = torch_mlu::getMluTensorImpl(rgrad_input_);
  auto grad_input_ptr = grad_input_impl->cnnlMalloc();
  auto grad_mask_impl = torch_mlu::getMluTensorImpl(rgrad_mask_);
  auto grad_mask_ptr = grad_mask_impl->cnnlMalloc();
  // set op descriptor
  auto handle = mluOpGetCurrentHandle();
  mluOpCarafeDescriptor_t carafe_desc;
  TORCH_MLUOP_CHECK(mluOpCreateCarafeDescriptor(&carafe_desc));
  TORCH_MLUOP_CHECK(mluOpSetCarafeDescriptor(
      carafe_desc, grad_output.dim(), kernel_size, group_size, scale_factor));
  // launch kernel
  TORCH_MLUOP_CHECK(mluOpCarafeBackward(
      handle, carafe_desc, input_desc.desc(), input_ptr, mask_desc.desc(),
      mask_ptr, grad_output_desc.desc(), grad_output_ptr,
      grad_input_desc.desc(), grad_input_ptr, grad_mask_desc.desc(),
      grad_mask_ptr));
  // destroy op descriptor
  TORCH_MLUOP_CHECK(mluOpDestroyCarafeDescriptor(carafe_desc));
  // copy the NHWC gradients back into the NCHW grad_input and grad_mask
  grad_input.copy_(rgrad_input_);
  grad_mask.copy_(rgrad_mask_);
}
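
// Thin wrappers with the signatures expected by the device dispatch
// registration below.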
void carafe_forward_mlu(Tensor features, Tensor masks, Tensor rfeatures,
                        Tensor routput, Tensor rmasks, Tensor output,
                        int kernel_size, int group_size, int scale_factor) {
  CARAFEForwardMLUKernelLauncher(features, masks, rfeatures, routput, rmasks,
                                 output, kernel_size, group_size, scale_factor);
}

void carafe_backward_mlu(Tensor top_grad, Tensor rfeatures, Tensor masks,
                         Tensor rtop_grad, Tensor rbottom_grad_hs,
                         Tensor rbottom_grad, Tensor rmask_grad,
                         Tensor bottom_grad, Tensor mask_grad, int kernel_size,
                         int group_size, int scale_factor) {
  CARAFEBackwardMLUKernelLauncher(top_grad, rfeatures, masks, rtop_grad,
                                  rbottom_grad_hs, rbottom_grad, rmask_grad,
                                  bottom_grad, mask_grad, kernel_size,
                                  group_size, scale_factor);
}

void carafe_forward_impl(Tensor features, Tensor masks, Tensor rfeatures,
                         Tensor routput, Tensor rmasks, Tensor output,
                         int kernel_size, int group_size, int scale_factor);

void carafe_backward_impl(Tensor top_grad, Tensor rfeatures, Tensor masks,
                          Tensor rtop_grad, Tensor rbottom_grad_hs,
                          Tensor rbottom_grad, Tensor rmask_grad,
                          Tensor bottom_grad, Tensor mask_grad, int kernel_size,
                          int group_size, int scale_factor);

REGISTER_DEVICE_IMPL(carafe_forward_impl, MLU, carafe_forward_mlu);
REGISTER_DEVICE_IMPL(carafe_backward_impl, MLU, carafe_backward_mlu);