From e663670a7462a65dac2c9cc2deb8873e56d9da07 Mon Sep 17 00:00:00 2001 From: bdf <36697723+defei-coder@users.noreply.github.com> Date: Sun, 21 Aug 2022 23:21:54 +0800 Subject: [PATCH] [Feature] Support MaskedConv2d with cambricon MLU backend (#2202) * [Feature] Support MaskedConv2d with cambricon MLU backend * [Refactor] Refactor test masked_conv2d code Co-authored-by: budefei --- docs/en/understand_mmcv/ops.md | 2 +- docs/zh_cn/understand_mmcv/ops.md | 2 +- .../ops/csrc/common/mlu/common_mlu_helper.hpp | 10 + .../common/mlu/masked_conv2d_mlu_kernel.mlu | 181 ++++++++++++++ .../csrc/pytorch/mlu/masked_conv2d_mlu.cpp | 226 ++++++++++++++++++ .../masked_conv2d_for_bias.npy | Bin 0 -> 140 bytes .../masked_conv2d_for_input.npy | Bin 0 -> 3200 bytes .../masked_conv2d_for_mask.npy | Bin 0 -> 1152 bytes .../masked_conv2d_for_output.npy | Bin 0 -> 3200 bytes .../masked_conv2d_for_weight.npy | Bin 0 -> 452 bytes tests/test_ops/test_masked_conv2d.py | 40 +++- 11 files changed, 452 insertions(+), 9 deletions(-) create mode 100755 mmcv/ops/csrc/common/mlu/masked_conv2d_mlu_kernel.mlu create mode 100755 mmcv/ops/csrc/pytorch/mlu/masked_conv2d_mlu.cpp create mode 100644 tests/data/for_masked_conv2d/masked_conv2d_for_bias.npy create mode 100644 tests/data/for_masked_conv2d/masked_conv2d_for_input.npy create mode 100644 tests/data/for_masked_conv2d/masked_conv2d_for_mask.npy create mode 100644 tests/data/for_masked_conv2d/masked_conv2d_for_output.npy create mode 100644 tests/data/for_masked_conv2d/masked_conv2d_for_weight.npy diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index f33b7d42e..15aadd298 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -28,7 +28,7 @@ We implement common ops used in detection, segmentation, etc. | GroupPoints | | √ | | | | Iou3d | | √ | | | | KNN | | √ | | | -| MaskedConv | | √ | | | +| MaskedConv | | √ | √ | | | MergeCells | | √ | | | | MinAreaPolygon | | √ | | | | ModulatedDeformConv2d | √ | √ | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index cebad9bca..fdcc8bab5 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -28,7 +28,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | GroupPoints | | √ | | | | Iou3d | | √ | | | | KNN | | √ | | | -| MaskedConv | | √ | | | +| MaskedConv | | √ | √ | | | MergeCells | | √ | | | | MinAreaPolygon | | √ | | | | ModulatedDeformConv2d | √ | √ | | | diff --git a/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp b/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp index 89d015109..e59099ae8 100644 --- a/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp +++ b/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp @@ -35,6 +35,16 @@ #define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y)) +template +__mlu_func__ inline scalar_t min(scalar_t a, scalar_t b) { + return a < b ? a : b; +} + +template +__mlu_func__ inline scalar_t max(scalar_t a, scalar_t b) { + return a > b ? a : b; +} + /*! * @brief loads data from global DRAM to NRAM with 2D pattern. * diff --git a/mmcv/ops/csrc/common/mlu/masked_conv2d_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/masked_conv2d_mlu_kernel.mlu new file mode 100755 index 000000000..1356a799a --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/masked_conv2d_mlu_kernel.mlu @@ -0,0 +1,181 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "common_mlu_helper.hpp" + +__nram__ char nram_buffer[MAX_NRAM_SIZE]; + +template +__mlu_func__ void MLUUnion1MaskedIm2colForward( + const T *feature, const int height, const int width, const int channels, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int32_t *mask_h_idx, const int32_t *mask_w_idx, const int mask_cnt, + T *data_col) { + for (int index = taskId; index < mask_cnt; index += taskDim) { + const int h_col = mask_h_idx[index]; + const int w_col = mask_w_idx[index]; + const int h_offset = h_col - pad_h; + const int w_offset = w_col - pad_w; + int h_start = h_offset; + int h_end = h_offset + kernel_h - 1; + int w_start = w_offset; + int w_end = w_start + kernel_w - 1; + if (h_start >= height || w_start >= width || h_end < 0 || w_end < 0) { + continue; + } else { + int h_start_valid = max(0, h_start); + int h_end_valid = min(height - 1, h_end); + int w_start_valid = max(0, w_start); + int w_end_valid = min(width - 1, w_end); + __memcpy( + data_col + index * kernel_h * kernel_w * channels + + ((h_start_valid - h_start) * kernel_w + + (w_start_valid - w_start)) * + channels, + feature + h_start_valid * width * channels + w_start_valid * channels, + (w_end_valid - w_start_valid + 1) * channels * sizeof(T), GDRAM2GDRAM, + kernel_w * channels * sizeof(T), width * channels * sizeof(T), + h_end_valid - h_start_valid); + } + } +} + +template +__mlu_func__ void MLUUnion1MaskedCol2imForward(const T *col, const int height, + const int width, + const int channels, + const int32_t *mask_h_idx, + const int32_t *mask_w_idx, + const int mask_cnt, T *im) { + const int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T); + if (channels <= channels_max_num_nram) { + const int deal_num = channels_max_num_nram / channels; + int mask_per_core = mask_cnt / taskDim; + const int mask_remain = mask_cnt % taskDim; + mask_per_core += taskId < mask_remain ? 1 : 0; + int index_start = taskId < mask_remain + ? 
taskId * mask_per_core + : taskId * mask_per_core + mask_remain; + int loop = mask_per_core / deal_num; + int remain_num = mask_per_core % deal_num; + T *nram_col = (T *)nram_buffer; + for (int index = 0; index < loop; ++index) { + int cur_index = index_start + index * deal_num; + __memcpy(nram_col, col + cur_index * channels, + deal_num * channels * sizeof(T), GDRAM2NRAM); + for (int i = 0; i < deal_num; ++i) { + int mask_index = cur_index + i; + const int h_im = mask_h_idx[mask_index]; + const int w_im = mask_w_idx[mask_index]; + // if(h_im>=height || w_im>=width) continue; + __memcpy(im + (h_im * width + w_im) * channels, nram_col + i * channels, + channels * sizeof(T), NRAM2GDRAM); + } + } + if (remain_num > 0) { + int cur_index = index_start + loop * deal_num; + __memcpy(nram_col, col + cur_index * channels, + remain_num * channels * sizeof(T), GDRAM2NRAM); + for (int i = 0; i < remain_num; ++i) { + int mask_index = cur_index + i; + const int h_im = mask_h_idx[mask_index]; + const int w_im = mask_w_idx[mask_index]; + // if(h_im>=height || w_im>=width) continue; + __memcpy(im + (h_im * width + w_im) * channels, nram_col + i * channels, + channels * sizeof(T), NRAM2GDRAM); + } + } + } else { + for (int index = taskId; index < mask_cnt; index += taskDim) { + const int m_index = index % mask_cnt; + const int h_im = mask_h_idx[m_index]; + const int w_im = mask_w_idx[m_index]; + // if(h_im>=height || w_im>=width) continue; + __memcpy(im + (h_im * width + w_im) * channels, col + index * channels, + channels * sizeof(T), GDRAM2GDRAM); + } + } +} + +__mlu_global__ void MLUKernelMaskedIm2colForward( + const void *feature, const int height, const int width, const int channels, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const void *mask_h_idx, const void *mask_w_idx, const int mask_cnt, + void *data_col, const cnrtDataType_t data_dtype) { + if (coreId == 0x80) { + return; + } + + switch (data_dtype) { + case CNRT_FLOAT16: { + MLUUnion1MaskedIm2colForward((half *)feature, height, width, channels, + kernel_h, kernel_w, pad_h, pad_w, + (int32_t *)mask_h_idx, (int32_t *)mask_w_idx, + mask_cnt, (half *)data_col); + }; break; + case CNRT_FLOAT32: { + MLUUnion1MaskedIm2colForward((float *)feature, height, width, channels, + kernel_h, kernel_w, pad_h, pad_w, + (int32_t *)mask_h_idx, (int32_t *)mask_w_idx, + mask_cnt, (float *)data_col); + }; break; + default: { + break; + } + } +} + +__mlu_global__ void MLUKernelMaskedCol2imForward( + const void *col, const int height, const int width, const int channels, + const void *mask_h_idx, const void *mask_w_idx, const int mask_cnt, + void *im, const cnrtDataType_t data_dtype) { + if (coreId == 0x80) { + return; + } + switch (data_dtype) { + case CNRT_FLOAT16: { + MLUUnion1MaskedCol2imForward((half *)col, height, width, channels, + (int32_t *)mask_h_idx, (int32_t *)mask_w_idx, + mask_cnt, (half *)im); + }; break; + case CNRT_FLOAT32: { + MLUUnion1MaskedCol2imForward((float *)col, height, width, channels, + (int32_t *)mask_h_idx, (int32_t *)mask_w_idx, + mask_cnt, (float *)im); + }; break; + default: { + break; + } + } +} + +void KernelMaskedIm2colForward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + cnrtDataType_t k_dtype, const void *im_ptr, const int height, + const int width, const int channels, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const void *mask_h_idx_ptr, + const void *mask_w_idx_ptr, const int mask_cnt, void *col_ptr) { + MLUKernelMaskedIm2colForward<<>>( + 
im_ptr, height, width, channels, kernel_h, kernel_w, pad_h, pad_w, + mask_h_idx_ptr, mask_w_idx_ptr, mask_cnt, col_ptr, k_dtype); +} + +void KernelMaskedCol2imForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, cnrtDataType_t k_dtype, + const void *col_ptr, const int height, + const int width, const int channels, + const void *mask_h_idx_ptr, + const void *mask_w_idx_ptr, const int mask_cnt, + void *im_ptr) { + MLUKernelMaskedCol2imForward<<>>( + col_ptr, height, width, channels, mask_h_idx_ptr, mask_w_idx_ptr, + mask_cnt, im_ptr, k_dtype); +} diff --git a/mmcv/ops/csrc/pytorch/mlu/masked_conv2d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/masked_conv2d_mlu.cpp new file mode 100755 index 000000000..e7842b3a1 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/masked_conv2d_mlu.cpp @@ -0,0 +1,226 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" + +void KernelMaskedIm2colForward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + cnrtDataType_t k_dtype, const void *im_ptr, const int height, + const int width, const int channels, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const void *mask_h_idx_ptr, + const void *mask_w_idx_ptr, const int mask_cnt, void *col_ptr); + +void KernelMaskedCol2imForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, cnrtDataType_t k_dtype, + const void *col_ptr, const int height, + const int width, const int channels, + const void *mask_h_idx_ptr, + const void *mask_w_idx_ptr, const int mask_cnt, + void *im_ptr); + +// policy function +static void policyFunc(const int mask_cnt, cnrtDim3_t *k_dim, + cnrtFunctionType_t *k_type) { + const size_t cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + const size_t core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + const size_t task_dim = CEIL_ALIGN(mask_cnt, core_num); + k_dim->x = core_num; + k_dim->y = + (task_dim / core_num) > cluster_num ? cluster_num : (task_dim / core_num); + k_dim->z = 1; + *k_type = CNRT_FUNC_TYPE_UNION1; +} + +void MaskedIm2colForwardMLUKernelLauncher(const Tensor im, + const Tensor mask_h_idx, + const Tensor mask_w_idx, Tensor col, + const int kernel_h, + const int kernel_w, const int pad_h, + const int pad_w) { + // Check dtype. 
+ TORCH_CHECK(im.scalar_type() == at::kFloat || im.scalar_type() == at::kHalf, + "im type should be Float or Half, got ", im.scalar_type(), "."); + TORCH_CHECK(mask_h_idx.scalar_type() == at::kInt || + mask_h_idx.scalar_type() == at::kLong, + "mask_h_idx type should be Int or Long, got ", + mask_h_idx.scalar_type(), "."); + TORCH_CHECK(mask_w_idx.scalar_type() == at::kInt || + mask_w_idx.scalar_type() == at::kLong, + "mask_w_idx type should be Int or Long, got ", + mask_w_idx.scalar_type(), "."); + TORCH_CHECK(kernel_h > 0, "kernel_h should greater than 0, got ", kernel_h, + "."); + TORCH_CHECK(kernel_w > 0, "kernel_w should greater than 0, got ", kernel_w, + "."); + + // zero element check + TORCH_CHECK(im.numel() > 0, "im.numel should greater than zero, got ", + im.numel(), "."); + TORCH_CHECK(col.size(0) > 0, "col.size(0) should greater than zero, got ", + col.size(0), "."); + + // large tensor check + const size_t max_input_num = 2147483648; // 2^31, 2G num + TORCH_CHECK(im.numel() < max_input_num, + "im.numel() should be less than 2147483648, got ", im.numel(), + "."); + TORCH_CHECK(col.numel() < max_input_num, + "col.numel() should be less than 2147483648, got ", col.numel(), + "."); + + const int channels = im.size(1); + const int height = im.size(2); + const int width = im.size(3); + const int mask_cnt = mask_h_idx.size(0); + + // auto im_t = im.permute({0, 2, 3, 1}).contiguous(); + auto memory_format = + torch_mlu::cnnl::ops::get_channels_last_memory_format(im.dim()); + auto im_ = torch_mlu::cnnl::ops::cnnl_contiguous(im, memory_format); + auto col_ = + at::zeros({mask_cnt, kernel_h * kernel_w, channels}, col.options()); + // calculate task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + policyFunc(mask_cnt, &k_dim, &k_type); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + // get ptr of tensors + auto im_impl = torch_mlu::getMluTensorImpl(im_); + auto im_ptr = im_impl->cnnlMalloc(); + auto mask_h_idx_impl = torch_mlu::getMluTensorImpl(mask_h_idx); + auto mask_h_idx_ptr = mask_h_idx_impl->cnnlMalloc(); + auto mask_w_idx_impl = torch_mlu::getMluTensorImpl(mask_w_idx); + auto mask_w_idx_ptr = mask_w_idx_impl->cnnlMalloc(); + auto col_impl = torch_mlu::getMluTensorImpl(col_); + auto col_ptr = col_impl->cnnlMalloc(); + + // get comput dtype of input + cnrtDataType_t data_type = torch_mlu::toCnrtDtype(im.dtype()); + + // launch kernel + CNLOG(INFO) << "Launch Kernel MLUKernelMaskedIm2colForward<<<" << k_dim.x + << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + KernelMaskedIm2colForward(k_dim, k_type, queue, data_type, im_ptr, height, + width, channels, kernel_h, kernel_w, pad_h, pad_w, + mask_h_idx_ptr, mask_w_idx_ptr, mask_cnt, col_ptr); + + col.copy_(col_.permute({2, 1, 0}) + .reshape({channels * kernel_h * kernel_w, mask_cnt}) + .contiguous()); +} + +void MaskedCol2imForwardMLUKernelLauncher(const Tensor col, + const Tensor mask_h_idx, + const Tensor mask_w_idx, Tensor im, + const int height, const int width, + const int channels) { + // Check dtype. 
+ TORCH_CHECK(col.scalar_type() == at::kFloat || col.scalar_type() == at::kHalf, + "col type should be Float or Half, got ", col.scalar_type(), "."); + TORCH_CHECK(mask_h_idx.scalar_type() == at::kInt || + mask_h_idx.scalar_type() == at::kLong, + "mask_h_idx type should be Int or Long, got ", + mask_h_idx.scalar_type(), "."); + TORCH_CHECK(mask_w_idx.scalar_type() == at::kInt || + mask_w_idx.scalar_type() == at::kLong, + "mask_w_idx type should be Int or Long, got ", + mask_w_idx.scalar_type(), "."); + + // zero element check + TORCH_CHECK(im.numel() > 0, "im.numel should greater than zero, got ", + im.numel(), "."); + TORCH_CHECK(col.size(0) > 0, "col.size(0) should greater than zero, got ", + col.size(0), "."); + + // large tensor check + const size_t max_input_num = 2147483648; // 2^31, 2G num + TORCH_CHECK(im.numel() < max_input_num, + "im.numel() should be less than 2147483648, got ", im.numel(), + "."); + TORCH_CHECK(col.numel() < max_input_num, + "col.numel() should be less than 2147483648, got ", col.numel(), + "."); + + auto memory_format = + torch_mlu::cnnl::ops::get_channels_last_memory_format(im.dim()); + at::Tensor im_ = + at::empty({1, channels, height, width}, im.options(), memory_format) + .zero_(); + + auto col_t = torch_mlu::cnnl::ops::cnnl_contiguous(col.transpose(0, 1)); + + const int mask_cnt = mask_h_idx.size(0); + // calculate task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + policyFunc(mask_cnt, &k_dim, &k_type); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + // get ptr of tensors + auto im_impl = torch_mlu::getMluTensorImpl(im_); + auto im_ptr = im_impl->cnnlMalloc(); + auto mask_h_idx_impl = torch_mlu::getMluTensorImpl(mask_h_idx); + auto mask_h_idx_ptr = mask_h_idx_impl->cnnlMalloc(); + auto mask_w_idx_impl = torch_mlu::getMluTensorImpl(mask_w_idx); + auto mask_w_idx_ptr = mask_w_idx_impl->cnnlMalloc(); + auto col_t_impl = torch_mlu::getMluTensorImpl(col_t); + auto col_t_ptr = col_t_impl->cnnlMalloc(); + + // get comput dtype of input + cnrtDataType_t data_type = torch_mlu::toCnrtDtype(col.dtype()); + + // launch kernel + CNLOG(INFO) << "Launch Kernel MLUKernelMaskedCol2imForward<<<" << k_dim.x + << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + + KernelMaskedCol2imForward(k_dim, k_type, queue, data_type, col_t_ptr, height, + width, channels, mask_h_idx_ptr, mask_w_idx_ptr, + mask_cnt, im_ptr); + + im.copy_(im_); +} + +void masked_im2col_forward_mlu(const Tensor im, const Tensor mask_h_idx, + const Tensor mask_w_idx, Tensor col, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w) { + // im: (n, ic, h, w), kernel size (kh, kw) + // kernel: (oc, ic * kh * kw), col: (kh * kw * ic, ow * oh) + MaskedIm2colForwardMLUKernelLauncher(im, mask_h_idx, mask_w_idx, col, + kernel_h, kernel_w, pad_h, pad_w); +} + +void masked_col2im_forward_mlu(const Tensor col, const Tensor mask_h_idx, + const Tensor mask_w_idx, Tensor im, int height, + int width, int channels) { + // im: (n, ic, h, w), kernel size (kh, kw) + // kernel: (oc, ic * kh * kh), col: (kh * kw * ic, ow * oh) + MaskedCol2imForwardMLUKernelLauncher(col, mask_h_idx, mask_w_idx, im, height, + width, channels); +} + +void masked_im2col_forward_impl(const Tensor im, const Tensor mask_h_idx, + const Tensor mask_w_idx, Tensor col, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w); + +void masked_col2im_forward_impl(const Tensor col, const Tensor mask_h_idx, + const Tensor mask_w_idx, Tensor im, int height, + int width, int channels); + 
+REGISTER_DEVICE_IMPL(masked_im2col_forward_impl, MLU, + masked_im2col_forward_mlu); +REGISTER_DEVICE_IMPL(masked_col2im_forward_impl, MLU, + masked_col2im_forward_mlu); diff --git a/tests/data/for_masked_conv2d/masked_conv2d_for_bias.npy b/tests/data/for_masked_conv2d/masked_conv2d_for_bias.npy new file mode 100644 index 0000000000000000000000000000000000000000..c60951a1d64d67dc8717284753e1bd33f95f6a48 GIT binary patch literal 140 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= iXCxM+0{I%oI+{8PwF(pfE}aMZ`&H^E?B{%9Zw~;zoglFQ literal 0 HcmV?d00001 diff --git a/tests/data/for_masked_conv2d/masked_conv2d_for_input.npy b/tests/data/for_masked_conv2d/masked_conv2d_for_input.npy new file mode 100644 index 0000000000000000000000000000000000000000..f45c03457d84e47a6c737dc2f362a279634d0108 GIT binary patch literal 3200 zcmb7_`CH9-AI3`@!ed0*wTxCOQB%tKe%^;kPJ@(^^rQtTD(w-YRD?=YqM_18HE9!C zobTs-TBc1Zl_`Z5GDRXOp?J)n@Z8sR|8T#4y|3%O6D)Rav)nBz5+rh5W528aK0gf; zWex5D0}X9ujRW3(0e&uroxJ__yZ+0~Ts-|r56AL%fQlbOf_(zCrzcPF%EriX|YQ*KiYkKI)QsMZ9GIC|m5`!Hg zFgU9NCv|M_Tfj~_CirvPB5%iq#s!<*DIZ8sdaH$~UX3lrLN>Tp+4G&J|6u&sKBuvEey3^v>* zXF~hPNvBPyG%X5;`L`fRr<%NUJjlMs>B81!l>`i@BPoX0xBd6hO#;9JawcJ@N2T_cDVO%uhWWcFIF zKk#ExnFV{cqWN(lmg%*kS1lJSig$x;Ct*|u=Lm}>^-(>{3yhA{5qFoBVAWko_quO_ zaJM-8`1Ul|tDazI+)g9!Wdj7;Y$Rf)_t8CkidO$BkLU8VK{2@qYy=|=vwka0Fm^?k zkz{m9@WHI;i>T?=LubsVz*tiSkLxF~#rngWaIS=3Nhaz19t__L-{GBu%kXJ?AikN`OQyWvfQtvy?4MSm;~eh~uarzDV*I>h}Ye*sev8&5;D4PZKpC2GfdiAh^5o$BubyXbQm zbJ-Cz_e%=5ge&rXzhh064i{17KrwhXHBOj{PUc0g2aey1!bOAP$S*wwm4X#S?}{q4 zKMF%{k!SST5kJV(h=U>iUqpMU0(g)~5C+fV?Vq=rdM>+-;iG2-dYovC-4X@W`xA(7 z=Lxne-we`1GVy(UA{uMyL!L`Bs!XUtnCBImGIJa>wUUW<0Wx>}#V~v7HoR8OMgP=t zc#`>v)_zH+#v@me-+U0?FW3N%W689E&8M+S8T34BLz7pWMP2h0D38cz(>#_y>iZC*B7U!OR%kK{4W4E0b@@F(`%38JH2MYOC( z9wyua1Y->sA?LCHH;O+bBTeD-_oh37m*!DK)}#yr{er+_`73fD?h#{C$HC+}z@`&R zcp6^s@t{X0ep=oLH+L7I!|*wfQI1B_MOj#~JR4V~h@nEO30*M~O6%X>1c%4ExO0gS z)MTpR%PY}P_2eW`dDII@jb~8irZaMvnBa{37Lfnh6z9h7rRJX>(k#ndER8CJ)`$1V z^O6Sa=h`DbEf0*Z)&buz6;!Wsz)iCB_S#4~8MFq=RP2ei!a0aeki_s0skF>qoZig~g-;1- zn7E^w7CFY^T=7w9 zE3`1ZpM$ts>#v~;qmIMc5#&>HI(-a@u*p;mjlC~}U6Kdon7pHx2E>ISqYKGLmo@CB z_{(^~NJ4OZZV>8sxWdZt<=EhN6V&q`Q@JY%jQ!F*wEN}^{3KQZPEq-+#K~$>BYF(2 z&RgIFk%pnIda%_Y3&W1ccE%i4*~svE;jlP4tHZ-Vjiw}w?& z`WPjiVx(EmYF-K2aR+ zB^D9uY5HQ8ZlCU?MVeD&=yVQoiQ5ZZ4wAfUv-=@qZWh&@;gf-$Bk=CCn(*bsBqZ{3Ql*H;P(no`1O)A?rVJ$v5K8wbgpPKRg^g84cGrB(HGyI+Z7%%Tj zg5Pitemh?V#=}`KJJo=fR+EASMrmwZ*LutvxDFqmX2Ni(6>R4|Vr*|cpl(VAKL_@%d+s4Xdn9I6F7-F4_CelFWK zJ&yUaE)aELH&K+Ur_9Sc&}?>*tIg~t7cJ81QQvhC^WipI^uU8`C_N>JIU9+x^82V- zSs46hR~U+^_u$L;J@~1x8UIihp?`=d(!9wq;M(cHriME7d}RU2{n?l-Z3KzYQK*>P z4mlcQ3}?Y&{HMbiU&0~MRuq5>GyJf(QdPK4QW5%oZ-(oRACY5v1MQ6(@M=>aSf<8XJhoij?e z(HP|JI{>cF?ICG$BMda>(YpBAsJFk0sU02mbgF?iem-O8G;WyYUFtl=Gvrj>G7qasgH1WymtmWxzT$(F5_JG&&>$b3$z) zWo zahn(*Zm&nkgvc~D{;kSdO~+uB^J3uU8&j@AI{fE`2JM$W&Te$eWSYkWw5@wBR#!iU IN+)mlABX)c8UO$Q literal 0 HcmV?d00001 diff --git a/tests/data/for_masked_conv2d/masked_conv2d_for_mask.npy b/tests/data/for_masked_conv2d/masked_conv2d_for_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..4c074471e7b7411868d1ac62ed5a56a93ab6f216 GIT binary patch literal 1152 zcmbW0`A-ve7>7Be1BO*%+?GiIhqQpWawue^{l1+;97hq9b1IfW3v=aa9mD~RV3A|t z+Ca2`r5-3EAUG~dzb`|X5FijMEJn)qvSo9{+u6>xY;ltGh%`F!P?9uO^tSs%OJt(eE{l&&7OlF2tFx`E$6Fko zZ8L5E&&)ZKI_fcH0_z&CAhYCqY9@8YiA$BVqHQbw@lm^WOyR*BG}wz5P8HJQbNk?w z9~W0no1@&{29El!P^My;4(zhQM;HMWzD9i8DOJ~W@FX-XE3kJDn@UVS=yqT4iB3aR zu)j0{yxZ^5&C_4vCUHN6UfiI~$g6>&s1Dwqpb)~6ne$!l_d%@jIm`*s&?s2~tegLb 
zr`9fLd0!nS+*t?Qa?ljk&!^(Ljk%=QT7_l$08B5;#+?s}Vf2KK==*Pz)A$rQQcK(s zewRB(m6-XY6sM|du=BbNsx8D|eaRRezp2N>xlO#?=Vr+Zu{VsL&ml~r#&TW^ZE}LZxV`X@~?HS2V&7vpgFM{++ zF79aU1$E$2Xz$me@Y+jlvW$cAzb!)Qd~X zCPVndCJyDdQg9?xh~BDcn(;*wZM0q_tuK4PqT~b)3>D*!Z}oW7-4K12jk-apz~}oy zx_&k4f!Wjv>a_hHRJ@Y#w%t3A=i8MuYN`u*p&TfWqMe*UWCtdR$-8kV{I`Re=4a9C zj^AL3D223c4MAaDIZP-rP&aZASmQ}_RC1Zrjvat^?6MJ$IpW5aDh%zg!=HaK#zQW3 z_~4EVYMVCj_ogVROe2IoK_!f`@@UzQf!M(Qm@cpjvF@rwGZ~(V&)6e4oW&sh58X)3 z?^V=#Z6_YJ8X$DbF-YHELy9Kk5Ehe&NtIFfX;>{bc2C1%r!Ndj9nq|>0zJHSylQNQ zv+u8G(3Jvnob@|ppeeirZ_V~QY5 literal 0 HcmV?d00001 diff --git a/tests/data/for_masked_conv2d/masked_conv2d_for_output.npy b/tests/data/for_masked_conv2d/masked_conv2d_for_output.npy new file mode 100644 index 0000000000000000000000000000000000000000..4741265afb3411221c184b2665737b5dab755692 GIT binary patch literal 3200 zcmb7FdpwnS8h>@s#C~c@%#6qhp@Vjs#TbkCIWZc>w3D)3#**7vV{Dy8mz8Eh4CTWS z$+RvhH6@`>W|ic9PDz8NoMJ+w6nC4St?gVMn?uN4_FZy5=y#9U_?~tpdJ+#s0tnA zReeXg1>dH0S${7lS2*l6I2^Ez5k zerl_9=9Zo&oYkq#@8gn$oM*Xc+qFoVV}22(oX4MS(PmN|4vu25*mn!+*8K`YjN^9u zS!=|;HNae9Xw-De9C;@dzA6FR_hB3ZSS!h z=^{pZcnMM4iRG*4_TNExVM+(Ufx5u z$e2Bor>ByT^xE=|StixNS@UyLi7RO%VY2-VxZ3Z83^$9m!;VVm|d2wQcR zorQeZ=I{EGrw6Lvm*c#k3>da%L3?#+eUYSjJDvffwbn?PvL80#%^7*p9MXyo7~$@( zj{x=midDipeH?DjeTQ8KY1_710#!mo?jP9uO}&jTCoC=KVY#=h5`i?rZFf0R-MM%% ztK6@xLzTqma6YUL)IOM#1KZuS*i5spyfCErF#ojm+nPzME|a`!1%q1BAkN9%Re z6c?Tr!nxz9Q}I;roqJShIJ&6D!@rcBW9pJgK)IJ;H~j+?`7Mu$nmV5_#M?@ji~8xj zbL=waO}nKUc2C4S)%#;zrRV6`FR(o6R^jf$)+EHFVrr5vd*3H_B}(^+vJ7UmE?H|* z@Hyl51V3SOH;&WCu`3fH$t@n|XS!q7*&-?9r1eNhJLbW7yVf+8_i4^xyz|)~hwWck zr6n}@1cH?rGM^j7;}Ztb=Z`qSbYLq?OfEp`$InZ^r>O}T(^3nx)UKZSSVb=+!qvAG9D@;CI~W!Dq@ z_p?ZR=d1@1Xgdn19|;@Mbu|u42GJ|2yCgL9Dxl)t z7yRV|=C;^}|tYMo~&Hep>QG zFj^HPWtMeYpqR~`Sqowxv_23Ey9?MoF~4dionejte3l8iUCd@EGXkVF;##%64$FU& zdj;F{b|UHLdh+ncP#rY3X+p2`4buJ|jtd2w4{tCYT&jnw<&OnboG&DIHM-^pbPCRT zQv|=TPuPrI1FydeIk<^=v`3=LHQ1a|&xiJ^LEUAJJ3k)9yevfpyAH~{is+Ln_H0T2 z^umkp#XVQ|NM|Ik&@77YIT*>bV;TP*=8=E$pUu#Ez*Ew%j7tF0ji$w5oPPsoE$V~Z z-WuWbTAu0TIV2R|oc4!|7hh4aYFw7fYU%3ucc~3w;s4@OCaJ-0)IJ-uVeS$UTSu6dBw+ zs}@9F{@*gNsgDAYU%&oD>ID-O-*jFJTWb%&j)9@?2+hV2w$}7~HPV{gt_HByHDo<; z!UexPcF*>0S&x<45z<~mF2`9T1Shikls`$dE^t8#iz6LocpBOmAEI?A=R9L(TZ=f4PHAdd4#Da0)#xr zq2B`CN;~1K{4_A*4gbdK?89^aQ=n+OWUmyI6(&%TX#y#3bBFC?^1)2(evlw#kN=zo zi9TZ)4?k0npgc1yf9IYBq^{Ox@3@t=2FG`Q$#-6RBxT}1+9vhrZSO3ep}0{Ap_^Cm z#C^C95cT~Q1FVm|V}hieKE8`p>P5Vjxdo};_*f4v%}xdCJKQgWTX_{6K35yBRH3;ze{;P5^G literal 0 HcmV?d00001 diff --git a/tests/data/for_masked_conv2d/masked_conv2d_for_weight.npy b/tests/data/for_masked_conv2d/masked_conv2d_for_weight.npy new file mode 100644 index 0000000000000000000000000000000000000000..50f04b53f01297845c95a690bfbf87ce3a88e523 GIT binary patch literal 452 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%oItoyvsiRP_kBXhe)GVs`=$z> zv`@XXeBYb7mG%>2lI&|j%bf@83Kp+5Xh~W&7-nZtmkLx3Iq+&~GniUAOPgLM3}Gp{;g@m45I4 z+23Wq*~i{iW8s8-fsEz*f4+aa|M05@yDzp!_PJV*i+TqJ6P*xBcbn zaJ$37Z|t_oSnj{Fu+jcvw5Z*WgbVxhR+igH*KD?vvfpI4AZ>~5v2=AirQ;LqeO4CR zT?~@5PuAVI??0D`U7O65eSR}4_Ae0*+&3xwtlf13&ixe!P3_*f`r9x5bHa{qC;$Gk zHu=5Bv|9Gh4N=&CDZzaIiGZ$sk@LmvpB?VoXVCH9E}-D;{@Pi`?W^{D-5*`wxeoxl C$-Nr@ literal 0 HcmV?d00001 diff --git a/tests/test_ops/test_masked_conv2d.py b/tests/test_ops/test_masked_conv2d.py index 4516b22e9..a292f6a4f 100644 --- a/tests/test_ops/test_masked_conv2d.py +++ b/tests/test_ops/test_masked_conv2d.py @@ -1,15 +1,41 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import numpy as np +import pytest import torch +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE + class TestMaskedConv2d: - def test_masked_conv2d(self): - if not torch.cuda.is_available(): - return + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) + ]) + def test_masked_conv2d_all_close(self, device): from mmcv.ops import MaskedConv2d - input = torch.randn(1, 3, 16, 16, requires_grad=True, device='cuda') - mask = torch.randn(1, 16, 16, requires_grad=True, device='cuda') - conv = MaskedConv2d(3, 3, 3).cuda() + np_input = np.load( + 'tests/data/for_masked_conv2d/masked_conv2d_for_input.npy') + np_mask = np.load( + 'tests/data/for_masked_conv2d/masked_conv2d_for_mask.npy') + np_weight = np.load( + 'tests/data/for_masked_conv2d/masked_conv2d_for_weight.npy') + np_bias = np.load( + 'tests/data/for_masked_conv2d/masked_conv2d_for_bias.npy') + np_output = np.load( + 'tests/data/for_masked_conv2d/masked_conv2d_for_output.npy') + input = torch.tensor(np_input, dtype=torch.float, device=device) + mask = torch.tensor(np_mask, dtype=torch.float, device=device) + weight = torch.tensor(np_weight, dtype=torch.float, device=device) + bias = torch.tensor(np_bias, dtype=torch.float, device=device) + conv = MaskedConv2d(3, 3, 3, 1, 1).to(device) + conv.weight = torch.nn.Parameter(weight) + conv.bias = torch.nn.Parameter(bias) output = conv(input, mask) - assert output is not None + assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
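
For reference, a minimal NumPy sketch of the masked im2col layout that MaskedIm2colForwardMLUKernelLauncher ends up producing after its final permute/reshape. This is illustrative only and not part of the patch: the helper name masked_im2col_ref is hypothetical, and it uses a plain (channels, height, width) array for readability, whereas the MLU kernel itself operates on a channels-last copy of the input before the launcher permutes the result back.

    import numpy as np


    def masked_im2col_ref(feature, mask_h_idx, mask_w_idx, kernel_h, kernel_w,
                          pad_h, pad_w):
        # feature: (channels, height, width) array.
        # Returns col of shape (channels * kernel_h * kernel_w, mask_cnt),
        # one column per masked output location.
        channels, height, width = feature.shape
        mask_cnt = len(mask_h_idx)
        col = np.zeros((channels, kernel_h * kernel_w, mask_cnt),
                       dtype=feature.dtype)
        for idx, (h_col, w_col) in enumerate(zip(mask_h_idx, mask_w_idx)):
            # Top-left corner of the receptive field in the unpadded feature map.
            h_start = int(h_col) - pad_h
            w_start = int(w_col) - pad_w
            for kh in range(kernel_h):
                for kw in range(kernel_w):
                    h, w = h_start + kh, w_start + kw
                    if 0 <= h < height and 0 <= w < width:
                        col[:, kh * kernel_w + kw, idx] = feature[:, h, w]
        return col.reshape(channels * kernel_h * kernel_w, mask_cnt)

Each masked position contributes one column; kernel taps that fall into the zero padding are left at zero, mirroring the at::zeros initialization of col_ in the launcher before the kernel copies in only the valid rows of each patch.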