mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Feature] Support MaskedConv2d with cambricon MLU backend (#2202)
* [Feature] Support MaskedConv2d with cambricon MLU backend * [Refactor] Refactor test masked_conv2d code Co-authored-by: budefei <budefei@cambricon.com>
This commit is contained in:
parent
9b11e560f3
commit
e663670a74
@ -28,7 +28,7 @@ We implement common ops used in detection, segmentation, etc.
|
||||
| GroupPoints | | √ | | |
|
||||
| Iou3d | | √ | | |
|
||||
| KNN | | √ | | |
|
||||
| MaskedConv | | √ | | |
|
||||
| MaskedConv | | √ | √ | |
|
||||
| MergeCells | | √ | | |
|
||||
| MinAreaPolygon | | √ | | |
|
||||
| ModulatedDeformConv2d | √ | √ | | |
|
||||
|
@ -28,7 +28,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
||||
| GroupPoints | | √ | | |
|
||||
| Iou3d | | √ | | |
|
||||
| KNN | | √ | | |
|
||||
| MaskedConv | | √ | | |
|
||||
| MaskedConv | | √ | √ | |
|
||||
| MergeCells | | √ | | |
|
||||
| MinAreaPolygon | | √ | | |
|
||||
| ModulatedDeformConv2d | √ | √ | | |
|
||||
|
@ -35,6 +35,16 @@
|
||||
|
||||
#define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y))
|
||||
|
||||
template <typename scalar_t>
|
||||
__mlu_func__ inline scalar_t min(scalar_t a, scalar_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__mlu_func__ inline scalar_t max(scalar_t a, scalar_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
/*!
|
||||
* @brief loads data from global DRAM to NRAM with 2D pattern.
|
||||
*
|
||||
|
181
mmcv/ops/csrc/common/mlu/masked_conv2d_mlu_kernel.mlu
Executable file
181
mmcv/ops/csrc/common/mlu/masked_conv2d_mlu_kernel.mlu
Executable file
@ -0,0 +1,181 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) 2022 Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "common_mlu_helper.hpp"
|
||||
|
||||
__nram__ char nram_buffer[MAX_NRAM_SIZE];
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void MLUUnion1MaskedIm2colForward(
|
||||
const T *feature, const int height, const int width, const int channels,
|
||||
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
|
||||
const int32_t *mask_h_idx, const int32_t *mask_w_idx, const int mask_cnt,
|
||||
T *data_col) {
|
||||
for (int index = taskId; index < mask_cnt; index += taskDim) {
|
||||
const int h_col = mask_h_idx[index];
|
||||
const int w_col = mask_w_idx[index];
|
||||
const int h_offset = h_col - pad_h;
|
||||
const int w_offset = w_col - pad_w;
|
||||
int h_start = h_offset;
|
||||
int h_end = h_offset + kernel_h - 1;
|
||||
int w_start = w_offset;
|
||||
int w_end = w_start + kernel_w - 1;
|
||||
if (h_start >= height || w_start >= width || h_end < 0 || w_end < 0) {
|
||||
continue;
|
||||
} else {
|
||||
int h_start_valid = max(0, h_start);
|
||||
int h_end_valid = min(height - 1, h_end);
|
||||
int w_start_valid = max(0, w_start);
|
||||
int w_end_valid = min(width - 1, w_end);
|
||||
__memcpy(
|
||||
data_col + index * kernel_h * kernel_w * channels +
|
||||
((h_start_valid - h_start) * kernel_w +
|
||||
(w_start_valid - w_start)) *
|
||||
channels,
|
||||
feature + h_start_valid * width * channels + w_start_valid * channels,
|
||||
(w_end_valid - w_start_valid + 1) * channels * sizeof(T), GDRAM2GDRAM,
|
||||
kernel_w * channels * sizeof(T), width * channels * sizeof(T),
|
||||
h_end_valid - h_start_valid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void MLUUnion1MaskedCol2imForward(const T *col, const int height,
|
||||
const int width,
|
||||
const int channels,
|
||||
const int32_t *mask_h_idx,
|
||||
const int32_t *mask_w_idx,
|
||||
const int mask_cnt, T *im) {
|
||||
const int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
|
||||
if (channels <= channels_max_num_nram) {
|
||||
const int deal_num = channels_max_num_nram / channels;
|
||||
int mask_per_core = mask_cnt / taskDim;
|
||||
const int mask_remain = mask_cnt % taskDim;
|
||||
mask_per_core += taskId < mask_remain ? 1 : 0;
|
||||
int index_start = taskId < mask_remain
|
||||
? taskId * mask_per_core
|
||||
: taskId * mask_per_core + mask_remain;
|
||||
int loop = mask_per_core / deal_num;
|
||||
int remain_num = mask_per_core % deal_num;
|
||||
T *nram_col = (T *)nram_buffer;
|
||||
for (int index = 0; index < loop; ++index) {
|
||||
int cur_index = index_start + index * deal_num;
|
||||
__memcpy(nram_col, col + cur_index * channels,
|
||||
deal_num * channels * sizeof(T), GDRAM2NRAM);
|
||||
for (int i = 0; i < deal_num; ++i) {
|
||||
int mask_index = cur_index + i;
|
||||
const int h_im = mask_h_idx[mask_index];
|
||||
const int w_im = mask_w_idx[mask_index];
|
||||
// if(h_im>=height || w_im>=width) continue;
|
||||
__memcpy(im + (h_im * width + w_im) * channels, nram_col + i * channels,
|
||||
channels * sizeof(T), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
if (remain_num > 0) {
|
||||
int cur_index = index_start + loop * deal_num;
|
||||
__memcpy(nram_col, col + cur_index * channels,
|
||||
remain_num * channels * sizeof(T), GDRAM2NRAM);
|
||||
for (int i = 0; i < remain_num; ++i) {
|
||||
int mask_index = cur_index + i;
|
||||
const int h_im = mask_h_idx[mask_index];
|
||||
const int w_im = mask_w_idx[mask_index];
|
||||
// if(h_im>=height || w_im>=width) continue;
|
||||
__memcpy(im + (h_im * width + w_im) * channels, nram_col + i * channels,
|
||||
channels * sizeof(T), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int index = taskId; index < mask_cnt; index += taskDim) {
|
||||
const int m_index = index % mask_cnt;
|
||||
const int h_im = mask_h_idx[m_index];
|
||||
const int w_im = mask_w_idx[m_index];
|
||||
// if(h_im>=height || w_im>=width) continue;
|
||||
__memcpy(im + (h_im * width + w_im) * channels, col + index * channels,
|
||||
channels * sizeof(T), GDRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*!
 * @brief Kernel entry point for masked im2col: selects the element type from
 *        `data_dtype` and forwards to MLUUnion1MaskedIm2colForward.
 *
 * Only CNRT_FLOAT16 and CNRT_FLOAT32 are handled; any other dtype makes the
 * kernel return without writing anything. Both index buffers are read as
 * int32_t.
 */
__mlu_global__ void MLUKernelMaskedIm2colForward(
    const void *feature, const int height, const int width, const int channels,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const void *mask_h_idx, const void *mask_w_idx, const int mask_cnt,
    void *data_col, const cnrtDataType_t data_dtype) {
  // Core id 0x80 takes no part in the computation (presumably the memory
  // core - confirm against the BANG C documentation).
  if (coreId == 0x80) {
    return;
  }

  const int32_t *h_idx = (const int32_t *)mask_h_idx;
  const int32_t *w_idx = (const int32_t *)mask_w_idx;

  if (data_dtype == CNRT_FLOAT16) {
    MLUUnion1MaskedIm2colForward((const half *)feature, height, width,
                                 channels, kernel_h, kernel_w, pad_h, pad_w,
                                 h_idx, w_idx, mask_cnt, (half *)data_col);
  } else if (data_dtype == CNRT_FLOAT32) {
    MLUUnion1MaskedIm2colForward((const float *)feature, height, width,
                                 channels, kernel_h, kernel_w, pad_h, pad_w,
                                 h_idx, w_idx, mask_cnt, (float *)data_col);
  }
  // Unsupported dtypes fall through without doing any work.
}
|
||||
|
||||
/*!
 * @brief Kernel entry point for masked col2im: selects the element type from
 *        `data_dtype` and forwards to MLUUnion1MaskedCol2imForward.
 *
 * Only CNRT_FLOAT16 and CNRT_FLOAT32 are handled; any other dtype makes the
 * kernel return without writing anything. Both index buffers are read as
 * int32_t.
 */
__mlu_global__ void MLUKernelMaskedCol2imForward(
    const void *col, const int height, const int width, const int channels,
    const void *mask_h_idx, const void *mask_w_idx, const int mask_cnt,
    void *im, const cnrtDataType_t data_dtype) {
  // Core id 0x80 takes no part in the computation (presumably the memory
  // core - confirm against the BANG C documentation).
  if (coreId == 0x80) {
    return;
  }

  const int32_t *h_idx = (const int32_t *)mask_h_idx;
  const int32_t *w_idx = (const int32_t *)mask_w_idx;

  if (data_dtype == CNRT_FLOAT16) {
    MLUUnion1MaskedCol2imForward((const half *)col, height, width, channels,
                                 h_idx, w_idx, mask_cnt, (half *)im);
  } else if (data_dtype == CNRT_FLOAT32) {
    MLUUnion1MaskedCol2imForward((const float *)col, height, width, channels,
                                 h_idx, w_idx, mask_cnt, (float *)im);
  }
  // Unsupported dtypes fall through without doing any work.
}
|
||||
|
||||
/*!
 * @brief Host-side wrapper that enqueues MLUKernelMaskedIm2colForward.
 *
 * @param[in]  k_dim           task dimensions of the launch.
 * @param[in]  k_type          function type of the launch.
 * @param[in]  queue           queue the kernel is enqueued on.
 * @param[in]  k_dtype         element type of im_ptr / col_ptr.
 * @param[in]  im_ptr          device pointer to the input image.
 * @param[in]  height, width, channels        image dimensions.
 * @param[in]  kernel_h, kernel_w             window size.
 * @param[in]  pad_h, pad_w                   padding.
 * @param[in]  mask_h_idx_ptr, mask_w_idx_ptr device pointers to the indices.
 * @param[in]  mask_cnt        number of masked positions.
 * @param[out] col_ptr         device pointer to the output columns.
 */
void KernelMaskedIm2colForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    cnrtDataType_t k_dtype, const void *im_ptr, const int height,
    const int width, const int channels, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const void *mask_h_idx_ptr,
    const void *mask_w_idx_ptr, const int mask_cnt, void *col_ptr) {
  // The kernel takes the dtype last, after the output pointer.
  MLUKernelMaskedIm2colForward<<<k_dim, k_type, queue>>>(
      im_ptr, height, width, channels,
      kernel_h, kernel_w, pad_h, pad_w,
      mask_h_idx_ptr, mask_w_idx_ptr, mask_cnt,
      col_ptr, k_dtype);
}
|
||||
|
||||
/*!
 * @brief Host-side wrapper that enqueues MLUKernelMaskedCol2imForward.
 *
 * @param[in]  k_dim           task dimensions of the launch.
 * @param[in]  k_type          function type of the launch.
 * @param[in]  queue           queue the kernel is enqueued on.
 * @param[in]  k_dtype         element type of col_ptr / im_ptr.
 * @param[in]  col_ptr         device pointer to the input columns.
 * @param[in]  height, width, channels        image dimensions.
 * @param[in]  mask_h_idx_ptr, mask_w_idx_ptr device pointers to the indices.
 * @param[in]  mask_cnt        number of masked positions.
 * @param[out] im_ptr          device pointer to the output image.
 */
void KernelMaskedCol2imForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                               cnrtQueue_t queue, cnrtDataType_t k_dtype,
                               const void *col_ptr, const int height,
                               const int width, const int channels,
                               const void *mask_h_idx_ptr,
                               const void *mask_w_idx_ptr, const int mask_cnt,
                               void *im_ptr) {
  // The kernel takes the dtype last, after the output pointer.
  MLUKernelMaskedCol2imForward<<<k_dim, k_type, queue>>>(
      col_ptr, height, width, channels,
      mask_h_idx_ptr, mask_w_idx_ptr, mask_cnt,
      im_ptr, k_dtype);
}
|
226
mmcv/ops/csrc/pytorch/mlu/masked_conv2d_mlu.cpp
Executable file
226
mmcv/ops/csrc/pytorch/mlu/masked_conv2d_mlu.cpp
Executable file
@ -0,0 +1,226 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) 2022 Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "pytorch_device_registry.hpp"
|
||||
#include "pytorch_mlu_helper.hpp"
|
||||
|
||||
void KernelMaskedIm2colForward(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
cnrtDataType_t k_dtype, const void *im_ptr, const int height,
|
||||
const int width, const int channels, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const void *mask_h_idx_ptr,
|
||||
const void *mask_w_idx_ptr, const int mask_cnt, void *col_ptr);
|
||||
|
||||
void KernelMaskedCol2imForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
|
||||
cnrtQueue_t queue, cnrtDataType_t k_dtype,
|
||||
const void *col_ptr, const int height,
|
||||
const int width, const int channels,
|
||||
const void *mask_h_idx_ptr,
|
||||
const void *mask_w_idx_ptr, const int mask_cnt,
|
||||
void *im_ptr);
|
||||
|
||||
// policy function
|
||||
static void policyFunc(const int mask_cnt, cnrtDim3_t *k_dim,
|
||||
cnrtFunctionType_t *k_type) {
|
||||
const size_t cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
|
||||
const size_t core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||
const size_t task_dim = CEIL_ALIGN(mask_cnt, core_num);
|
||||
k_dim->x = core_num;
|
||||
k_dim->y =
|
||||
(task_dim / core_num) > cluster_num ? cluster_num : (task_dim / core_num);
|
||||
k_dim->z = 1;
|
||||
*k_type = CNRT_FUNC_TYPE_UNION1;
|
||||
}
|
||||
|
||||
void MaskedIm2colForwardMLUKernelLauncher(const Tensor im,
|
||||
const Tensor mask_h_idx,
|
||||
const Tensor mask_w_idx, Tensor col,
|
||||
const int kernel_h,
|
||||
const int kernel_w, const int pad_h,
|
||||
const int pad_w) {
|
||||
// Check dtype.
|
||||
TORCH_CHECK(im.scalar_type() == at::kFloat || im.scalar_type() == at::kHalf,
|
||||
"im type should be Float or Half, got ", im.scalar_type(), ".");
|
||||
TORCH_CHECK(mask_h_idx.scalar_type() == at::kInt ||
|
||||
mask_h_idx.scalar_type() == at::kLong,
|
||||
"mask_h_idx type should be Int or Long, got ",
|
||||
mask_h_idx.scalar_type(), ".");
|
||||
TORCH_CHECK(mask_w_idx.scalar_type() == at::kInt ||
|
||||
mask_w_idx.scalar_type() == at::kLong,
|
||||
"mask_w_idx type should be Int or Long, got ",
|
||||
mask_w_idx.scalar_type(), ".");
|
||||
TORCH_CHECK(kernel_h > 0, "kernel_h should greater than 0, got ", kernel_h,
|
||||
".");
|
||||
TORCH_CHECK(kernel_w > 0, "kernel_w should greater than 0, got ", kernel_w,
|
||||
".");
|
||||
|
||||
// zero element check
|
||||
TORCH_CHECK(im.numel() > 0, "im.numel should greater than zero, got ",
|
||||
im.numel(), ".");
|
||||
TORCH_CHECK(col.size(0) > 0, "col.size(0) should greater than zero, got ",
|
||||
col.size(0), ".");
|
||||
|
||||
// large tensor check
|
||||
const size_t max_input_num = 2147483648; // 2^31, 2G num
|
||||
TORCH_CHECK(im.numel() < max_input_num,
|
||||
"im.numel() should be less than 2147483648, got ", im.numel(),
|
||||
".");
|
||||
TORCH_CHECK(col.numel() < max_input_num,
|
||||
"col.numel() should be less than 2147483648, got ", col.numel(),
|
||||
".");
|
||||
|
||||
const int channels = im.size(1);
|
||||
const int height = im.size(2);
|
||||
const int width = im.size(3);
|
||||
const int mask_cnt = mask_h_idx.size(0);
|
||||
|
||||
// auto im_t = im.permute({0, 2, 3, 1}).contiguous();
|
||||
auto memory_format =
|
||||
torch_mlu::cnnl::ops::get_channels_last_memory_format(im.dim());
|
||||
auto im_ = torch_mlu::cnnl::ops::cnnl_contiguous(im, memory_format);
|
||||
auto col_ =
|
||||
at::zeros({mask_cnt, kernel_h * kernel_w, channels}, col.options());
|
||||
// calculate task dimension
|
||||
cnrtDim3_t k_dim;
|
||||
cnrtFunctionType_t k_type;
|
||||
policyFunc(mask_cnt, &k_dim, &k_type);
|
||||
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
// get ptr of tensors
|
||||
auto im_impl = torch_mlu::getMluTensorImpl(im_);
|
||||
auto im_ptr = im_impl->cnnlMalloc();
|
||||
auto mask_h_idx_impl = torch_mlu::getMluTensorImpl(mask_h_idx);
|
||||
auto mask_h_idx_ptr = mask_h_idx_impl->cnnlMalloc();
|
||||
auto mask_w_idx_impl = torch_mlu::getMluTensorImpl(mask_w_idx);
|
||||
auto mask_w_idx_ptr = mask_w_idx_impl->cnnlMalloc();
|
||||
auto col_impl = torch_mlu::getMluTensorImpl(col_);
|
||||
auto col_ptr = col_impl->cnnlMalloc();
|
||||
|
||||
// get comput dtype of input
|
||||
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(im.dtype());
|
||||
|
||||
// launch kernel
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernelMaskedIm2colForward<<<" << k_dim.x
|
||||
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
KernelMaskedIm2colForward(k_dim, k_type, queue, data_type, im_ptr, height,
|
||||
width, channels, kernel_h, kernel_w, pad_h, pad_w,
|
||||
mask_h_idx_ptr, mask_w_idx_ptr, mask_cnt, col_ptr);
|
||||
|
||||
col.copy_(col_.permute({2, 1, 0})
|
||||
.reshape({channels * kernel_h * kernel_w, mask_cnt})
|
||||
.contiguous());
|
||||
}
|
||||
|
||||
void MaskedCol2imForwardMLUKernelLauncher(const Tensor col,
|
||||
const Tensor mask_h_idx,
|
||||
const Tensor mask_w_idx, Tensor im,
|
||||
const int height, const int width,
|
||||
const int channels) {
|
||||
// Check dtype.
|
||||
TORCH_CHECK(col.scalar_type() == at::kFloat || col.scalar_type() == at::kHalf,
|
||||
"col type should be Float or Half, got ", col.scalar_type(), ".");
|
||||
TORCH_CHECK(mask_h_idx.scalar_type() == at::kInt ||
|
||||
mask_h_idx.scalar_type() == at::kLong,
|
||||
"mask_h_idx type should be Int or Long, got ",
|
||||
mask_h_idx.scalar_type(), ".");
|
||||
TORCH_CHECK(mask_w_idx.scalar_type() == at::kInt ||
|
||||
mask_w_idx.scalar_type() == at::kLong,
|
||||
"mask_w_idx type should be Int or Long, got ",
|
||||
mask_w_idx.scalar_type(), ".");
|
||||
|
||||
// zero element check
|
||||
TORCH_CHECK(im.numel() > 0, "im.numel should greater than zero, got ",
|
||||
im.numel(), ".");
|
||||
TORCH_CHECK(col.size(0) > 0, "col.size(0) should greater than zero, got ",
|
||||
col.size(0), ".");
|
||||
|
||||
// large tensor check
|
||||
const size_t max_input_num = 2147483648; // 2^31, 2G num
|
||||
TORCH_CHECK(im.numel() < max_input_num,
|
||||
"im.numel() should be less than 2147483648, got ", im.numel(),
|
||||
".");
|
||||
TORCH_CHECK(col.numel() < max_input_num,
|
||||
"col.numel() should be less than 2147483648, got ", col.numel(),
|
||||
".");
|
||||
|
||||
auto memory_format =
|
||||
torch_mlu::cnnl::ops::get_channels_last_memory_format(im.dim());
|
||||
at::Tensor im_ =
|
||||
at::empty({1, channels, height, width}, im.options(), memory_format)
|
||||
.zero_();
|
||||
|
||||
auto col_t = torch_mlu::cnnl::ops::cnnl_contiguous(col.transpose(0, 1));
|
||||
|
||||
const int mask_cnt = mask_h_idx.size(0);
|
||||
// calculate task dimension
|
||||
cnrtDim3_t k_dim;
|
||||
cnrtFunctionType_t k_type;
|
||||
policyFunc(mask_cnt, &k_dim, &k_type);
|
||||
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
// get ptr of tensors
|
||||
auto im_impl = torch_mlu::getMluTensorImpl(im_);
|
||||
auto im_ptr = im_impl->cnnlMalloc();
|
||||
auto mask_h_idx_impl = torch_mlu::getMluTensorImpl(mask_h_idx);
|
||||
auto mask_h_idx_ptr = mask_h_idx_impl->cnnlMalloc();
|
||||
auto mask_w_idx_impl = torch_mlu::getMluTensorImpl(mask_w_idx);
|
||||
auto mask_w_idx_ptr = mask_w_idx_impl->cnnlMalloc();
|
||||
auto col_t_impl = torch_mlu::getMluTensorImpl(col_t);
|
||||
auto col_t_ptr = col_t_impl->cnnlMalloc();
|
||||
|
||||
// get comput dtype of input
|
||||
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(col.dtype());
|
||||
|
||||
// launch kernel
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernelMaskedCol2imForward<<<" << k_dim.x
|
||||
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
|
||||
KernelMaskedCol2imForward(k_dim, k_type, queue, data_type, col_t_ptr, height,
|
||||
width, channels, mask_h_idx_ptr, mask_w_idx_ptr,
|
||||
mask_cnt, im_ptr);
|
||||
|
||||
im.copy_(im_);
|
||||
}
|
||||
|
||||
void masked_im2col_forward_mlu(const Tensor im, const Tensor mask_h_idx,
|
||||
const Tensor mask_w_idx, Tensor col,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w) {
|
||||
// im: (n, ic, h, w), kernel size (kh, kw)
|
||||
// kernel: (oc, ic * kh * kw), col: (kh * kw * ic, ow * oh)
|
||||
MaskedIm2colForwardMLUKernelLauncher(im, mask_h_idx, mask_w_idx, col,
|
||||
kernel_h, kernel_w, pad_h, pad_w);
|
||||
}
|
||||
|
||||
void masked_col2im_forward_mlu(const Tensor col, const Tensor mask_h_idx,
|
||||
const Tensor mask_w_idx, Tensor im, int height,
|
||||
int width, int channels) {
|
||||
// im: (n, ic, h, w), kernel size (kh, kw)
|
||||
// kernel: (oc, ic * kh * kh), col: (kh * kw * ic, ow * oh)
|
||||
MaskedCol2imForwardMLUKernelLauncher(col, mask_h_idx, mask_w_idx, im, height,
|
||||
width, channels);
|
||||
}
|
||||
|
||||
void masked_im2col_forward_impl(const Tensor im, const Tensor mask_h_idx,
|
||||
const Tensor mask_w_idx, Tensor col,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w);
|
||||
|
||||
void masked_col2im_forward_impl(const Tensor col, const Tensor mask_h_idx,
|
||||
const Tensor mask_w_idx, Tensor im, int height,
|
||||
int width, int channels);
|
||||
|
||||
REGISTER_DEVICE_IMPL(masked_im2col_forward_impl, MLU,
|
||||
masked_im2col_forward_mlu);
|
||||
REGISTER_DEVICE_IMPL(masked_col2im_forward_impl, MLU,
|
||||
masked_col2im_forward_mlu);
|
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_bias.npy
Normal file
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_bias.npy
Normal file
Binary file not shown.
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_input.npy
Normal file
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_input.npy
Normal file
Binary file not shown.
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_mask.npy
Normal file
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_mask.npy
Normal file
Binary file not shown.
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_output.npy
Normal file
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_output.npy
Normal file
Binary file not shown.
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_weight.npy
Normal file
BIN
tests/data/for_masked_conv2d/masked_conv2d_for_weight.npy
Normal file
Binary file not shown.
@ -1,15 +1,41 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
|
||||
|
||||
class TestMaskedConv2d:
    """Compares MaskedConv2d with stored reference data on each device.

    The leftover pre-refactor lines (an empty ``test_masked_conv2d`` stub and
    CUDA-only random tensors that were immediately overwritten) are dropped:
    they hard-coded ``device='cuda'`` and would fail when the test is run
    with the ``mlu`` parameter on a machine without CUDA.
    """

    @pytest.mark.parametrize('device', [
        pytest.param(
            'cuda',
            marks=pytest.mark.skipif(
                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
        pytest.param(
            'mlu',
            marks=pytest.mark.skipif(
                not IS_MLU_AVAILABLE, reason='requires MLU support'))
    ])
    def test_masked_conv2d_all_close(self, device):
        # Imported here so that collecting the test module does not require
        # the compiled ops.
        from mmcv.ops import MaskedConv2d

        # Reference inputs, parameters and expected output.
        np_input = np.load(
            'tests/data/for_masked_conv2d/masked_conv2d_for_input.npy')
        np_mask = np.load(
            'tests/data/for_masked_conv2d/masked_conv2d_for_mask.npy')
        np_weight = np.load(
            'tests/data/for_masked_conv2d/masked_conv2d_for_weight.npy')
        np_bias = np.load(
            'tests/data/for_masked_conv2d/masked_conv2d_for_bias.npy')
        np_output = np.load(
            'tests/data/for_masked_conv2d/masked_conv2d_for_output.npy')

        input = torch.tensor(np_input, dtype=torch.float, device=device)
        mask = torch.tensor(np_mask, dtype=torch.float, device=device)
        weight = torch.tensor(np_weight, dtype=torch.float, device=device)
        bias = torch.tensor(np_bias, dtype=torch.float, device=device)

        # 3 -> 3 channels, 3x3 kernel, stride 1, padding 1; the parameters
        # are replaced by the stored reference values.
        conv = MaskedConv2d(3, 3, 3, 1, 1).to(device)
        conv.weight = torch.nn.Parameter(weight)
        conv.bias = torch.nn.Parameter(bias)

        output = conv(input, mask)
        assert output is not None
        # Third positional argument of np.allclose is rtol.
        assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user