From 9b11e560f3f8b4deb87afc85d1287c12d20e99a2 Mon Sep 17 00:00:00 2001
From: tudejiang79 <57201278+tudejiang79@users.noreply.github.com>
Date: Fri, 19 Aug 2022 21:19:40 +0800
Subject: [PATCH] [Fix] Fix roi_align_rotated op of MLU (#2210)

* [Fix] roi_align_rotated codes

* [Fix]: fix code style

* [Fix]: fix code style

* [Fix]: fix code style
---
 .../mlu/roi_align_rotated_mlu_kernel.mlu      | 154 ++++++++++--------
 .../pytorch/mlu/roi_align_rotated_mlu.cpp     |   4 +-
 2 files changed, 88 insertions(+), 70 deletions(-)
 mode change 100644 => 100755 mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp

diff --git a/mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu
index 7f05b525a..9356776c5 100644
--- a/mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu
+++ b/mmcv/ops/csrc/common/mlu/roi_align_rotated_mlu_kernel.mlu
@@ -25,10 +25,10 @@ __mlu_func__ void swap(T &a, T &b) {
 
 template <typename T>
 __mlu_func__ void bilinearInterpolate(const int input_height,
-                                      const int input_width, T x, T y,
-                                      const T zero_sign, T *w1, T *w2, T *w3,
-                                      T *w4, int *x_low, int *x_high,
-                                      int *y_low, int *y_high, bool *empty) {
+                                      const int input_width, T x, T y, T *w1,
+                                      T *w2, T *w3, T *w4, int *x_low,
+                                      int *x_high, int *y_low, int *y_high,
+                                      bool *empty) {
   // deal with case that the point is out of feature map boundary
   if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) {
     *empty = true;
@@ -58,10 +58,11 @@ __mlu_func__ void bilinearInterpolate(const int input_height,
   T lx = x - *x_low;
   T hy = 1.0 - ly;
   T hx = 1.0 - lx;
-  *w1 = hy * hx * zero_sign;
-  *w2 = hy * lx * zero_sign;
-  *w3 = ly * hx * zero_sign;
-  *w4 = ly * lx * zero_sign;
+  *w1 = hy * hx;
+  *w2 = hy * lx;
+  *w3 = ly * hx;
+  *w4 = ly * lx;
+  return;
 }
 
 template <typename T>
@@ -141,7 +142,7 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
     int dst_offset = 0;
     int c_rem, c_slice, c_slice_align, pongc_slice, pongc_slice_align;
     for (int c_offset = 0; c_offset < channel; c_offset += channel_align) {
-      __nramset(nram_out, channel_align, (T)0);
+      __bang_write_value(nram_out, channel_align, (T)0);
       c_rem = channel - c_offset;
       c_slice = channel_align > c_rem ? c_rem : channel_align;
       c_slice_align = CEIL_ALIGN(c_slice, align_base_128);
@@ -159,9 +160,8 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
           T w1, w2, w3, w4;
           bool empty = false;
           int x_low, x_high, y_low, y_high;
-          bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3,
-                              &w4, &x_low, &x_high, &y_low, &y_high, &empty);
-          int sample_wdim = x_high - x_low + 1;
+          bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
+                              &x_high, &y_low, &y_high, &empty);
           /*******************************************************
            | ping | pong |
            |------|-----|-----|-----|-----|-----|-----|-----|-----|
@@ -170,22 +170,32 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
            ********************************************************/
          if (is_first_sample && !empty) {
            // load input data from dram to nram
-            __nramset(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
-            for (int h = y_low; h <= y_high; ++h) {
-              src_offset =
-                  (batch_idx * height * width + h * width + x_low) * channel +
-                  c_offset;
-              dst_offset = (h - y_low) * SAMPLING_NUM * c_slice_align / 2;
-              if (c_slice_align == channel) {
-                __memcpy(nram_ping + dst_offset, input_dram + src_offset,
-                         sample_wdim * channel * sizeof(T), GDRAM2NRAM);
-              } else {
-                __memcpy(nram_ping + dst_offset, input_dram + src_offset,
-                         c_slice * sizeof(T), GDRAM2NRAM,
-                         c_slice_align * sizeof(T), channel * sizeof(T),
-                         sample_wdim - 1);
-              }
-            }
+            __bang_write_value(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
+            src_offset =
+                (batch_idx * height * width + y_low * width + x_low) * channel +
+                c_offset;
+            dst_offset = 0;
+            __memcpy(nram_ping + dst_offset, input_dram + src_offset,
+                     c_slice * sizeof(T), GDRAM2NRAM);
+            src_offset = (batch_idx * height * width + y_low * width + x_high) *
+                             channel +
+                         c_offset;
+            dst_offset = c_slice_align;
+            __memcpy(nram_ping + dst_offset, input_dram + src_offset,
+                     c_slice * sizeof(T), GDRAM2NRAM);
+            src_offset = (batch_idx * height * width + y_high * width + x_low) *
+                             channel +
+                         c_offset;
+            dst_offset = c_slice_align * 2;
+            __memcpy(nram_ping + dst_offset, input_dram + src_offset,
+                     c_slice * sizeof(T), GDRAM2NRAM);
+            src_offset =
+                (batch_idx * height * width + y_high * width + x_high) *
+                    channel +
+                c_offset;
+            dst_offset = c_slice_align * 3;
+            __memcpy(nram_ping + dst_offset, input_dram + src_offset,
+                     c_slice * sizeof(T), GDRAM2NRAM);
          }
          // load next input data to nram
          if (sample_i + 1 < bin_dim) {
@@ -200,56 +210,65 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
             T p_w1, p_w2, p_w3, p_w4;
             bool p_empty = false;
             int p_x_low, p_x_high, p_y_low, p_y_high;
-            bilinearInterpolate(height, width, p_x, p_y, zero_sign, &p_w1,
-                                &p_w2, &p_w3, &p_w4, &p_x_low, &p_x_high,
-                                &p_y_low, &p_y_high, &p_empty);
-            int p_sample_wdim = p_x_high - p_x_low + 1;
+            bilinearInterpolate(height, width, p_x, p_y, &p_w1, &p_w2, &p_w3,
+                                &p_w4, &p_x_low, &p_x_high, &p_y_low, &p_y_high,
+                                &p_empty);
             pongc_slice = c_slice;
             pongc_slice_align = c_slice_align;
             if (!p_empty) {
-              __nramset(nram_pong, SAMPLING_NUM * pongc_slice_align, (T)0);
-              for (int h = p_y_low; h <= p_y_high; ++h) {
-                src_offset =
-                    (batch_idx * height * width + h * width + p_x_low) *
-                        channel +
-                    c_offset;
-                dst_offset =
-                    (h - p_y_low) * SAMPLING_NUM * pongc_slice_align / 2;
-                if (pongc_slice_align == channel) {
-                  __memcpy_async(
-                      nram_pong + dst_offset, input_dram + src_offset,
-                      p_sample_wdim * channel * sizeof(T), GDRAM2NRAM);
-                } else {
-                  __memcpy_async(nram_pong + dst_offset,
-                                 input_dram + src_offset,
-                                 pongc_slice * sizeof(T), GDRAM2NRAM,
-                                 pongc_slice_align * sizeof(T),
-                                 channel * sizeof(T), p_sample_wdim - 1);
-                }
-              }
+              __bang_write_value(nram_pong, SAMPLING_NUM * pongc_slice_align,
+                                 (T)0);
+              src_offset =
+                  (batch_idx * height * width + p_y_low * width + p_x_low) *
+                      channel +
+                  c_offset;
+              dst_offset = 0;
+              __memcpy(nram_pong + dst_offset, input_dram + src_offset,
+                       c_slice * sizeof(T), GDRAM2NRAM);
+              src_offset =
+                  (batch_idx * height * width + p_y_low * width + p_x_high) *
+                      channel +
+                  c_offset;
+              dst_offset = pongc_slice_align;
+              __memcpy(nram_pong + dst_offset, input_dram + src_offset,
+                       c_slice * sizeof(T), GDRAM2NRAM);
+              src_offset =
+                  (batch_idx * height * width + p_y_high * width + p_x_low) *
+                      channel +
+                  c_offset;
+              dst_offset = pongc_slice_align * 2;
+              __memcpy(nram_pong + dst_offset, input_dram + src_offset,
+                       c_slice * sizeof(T), GDRAM2NRAM);
+              src_offset =
+                  (batch_idx * height * width + p_y_high * width + p_x_high) *
+                      channel +
+                  c_offset;
+              dst_offset = pongc_slice_align * 3;
+              __memcpy(nram_pong + dst_offset, input_dram + src_offset,
+                       c_slice * sizeof(T), GDRAM2NRAM);
             }
           }
 
           T *tmp_sum = nram_ping + 3 * c_slice_align;
           if (empty) {
-            __nramset(tmp_sum, c_slice_align, T(0));
+            __bang_write_value(tmp_sum, c_slice_align, T(0));
           } else {
-            __bang_mul_const(nram_ping, nram_ping, w1, c_slice_align);
-            __bang_mul_const(nram_ping + c_slice_align,
-                             nram_ping + c_slice_align, w2, c_slice_align);
-            __bang_mul_const(nram_ping + 2 * c_slice_align,
-                             nram_ping + 2 * c_slice_align, w3, c_slice_align);
-            __bang_mul_const(nram_ping + 3 * c_slice_align,
-                             nram_ping + 3 * c_slice_align, w4, c_slice_align);
+            __bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
+            __bang_mul_scalar(nram_ping + c_slice_align,
+                              nram_ping + c_slice_align, w2, c_slice_align);
+            __bang_mul_scalar(nram_ping + 2 * c_slice_align,
+                              nram_ping + 2 * c_slice_align, w3, c_slice_align);
+            __bang_mul_scalar(nram_ping + 3 * c_slice_align,
+                              nram_ping + 3 * c_slice_align, w4, c_slice_align);
             __bang_sumpool(tmp_sum, nram_ping, c_slice_align, 1, SAMPLING_NUM,
                            1, SAMPLING_NUM, 1, 1);
           }
           __bang_add(nram_out, nram_out, tmp_sum, c_slice_align);
           swap(nram_ping, nram_pong);
-          __asm__ volatile("sync;");
           is_first_sample = false;
         }
       }
+      __bang_mul_scalar(nram_out, nram_out, zero_sign, c_slice_align);
      // store the result to dram
      int output_offset =
          ((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
@@ -310,7 +329,6 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
     T cos_theta = std::cos(theta);
     T sin_theta = std::sin(theta);
     T zero_sign = 1.0f / bin_dim;
-
     int c_rem, c_slice, pongc_slice, c_offset;
     c_rem = channel;
     c_offset = 0;
@@ -369,30 +387,30 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
         T w1, w2, w3, w4;
         bool empty = false;
         int x_low, x_high, y_low, y_high;
-        bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3,
-                            &w4, &x_low, &x_high, &y_low, &y_high, &empty);
+        bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
+                            &x_high, &y_low, &y_high, &empty);
         if (empty) {
           continue;
         } else {
-          __bang_mul_const(nram_output, nram_ping, w1, c_limit);
+          __bang_mul_scalar(nram_output, nram_ping, w1 * zero_sign, c_limit);
           __bang_atomic_add(
               (T *)nram_output,
               bottom_grad_dram + batch_idx * height * width * channel +
                   y_low * width * channel + x_low * channel + c_offset,
              (T *)nram_output, c_slice);
-          __bang_mul_const(nram_output, nram_ping, w2, c_limit);
+          __bang_mul_scalar(nram_output, nram_ping, w2 * zero_sign, c_limit);
           __bang_atomic_add(
               (T *)nram_output,
               bottom_grad_dram + batch_idx * height * width * channel +
                   y_low * width * channel + x_high * channel + c_offset,
               (T *)nram_output, c_slice);
-          __bang_mul_const(nram_output, nram_ping, w3, c_limit);
+          __bang_mul_scalar(nram_output, nram_ping, w3 * zero_sign, c_limit);
           __bang_atomic_add(
               (T *)nram_output,
               bottom_grad_dram + batch_idx * height * width * channel +
                   y_high * width * channel + x_low * channel + c_offset,
               (T *)nram_output, c_slice);
-          __bang_mul_const(nram_output, nram_ping, w4, c_limit);
+          __bang_mul_scalar(nram_output, nram_ping, w4 * zero_sign, c_limit);
           __bang_atomic_add(
               (T *)nram_output,
               bottom_grad_dram + batch_idx * height * width * channel +
diff --git a/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
old mode 100644
new mode 100755
index 255aefdd9..c3058c01f
--- a/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
@@ -99,8 +99,8 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
   auto input_tensor =
       torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
   at::Tensor output_tmp =
-      at::empty({batch, channel, pooled_height, pooled_width}, input.options(),
-                memory_format);
+      at::empty({rois_nums, channel, pooled_height, pooled_width},
+                input.options(), memory_format);
   // get compute queue
   auto queue = torch_mlu::getCurQueue();
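
For readers who want the gist of the kernel change without parsing the BANG intrinsics: the patch stops folding the averaging factor `zero_sign` (1 / number of samples per bin) into the four bilinear weights inside `bilinearInterpolate`. The forward pass now accumulates the raw weighted samples and multiplies the accumulated bin sum by `zero_sign` once, and the backward pass applies `w1 * zero_sign` ... `w4 * zero_sign` when scattering gradients. The sketch below mirrors that per-bin arithmetic in plain scalar C++; it is an illustration only, and the helper names (`bilinearWeights`, `averageBinSamples`) and the single-channel layout are assumptions, not code from the kernel.

    // Scalar sketch of the per-bin forward computation after this patch.
    // Assumptions: plain C++, one channel, invented helper names; the real
    // kernel is vectorized BANG C operating on channel slices in NRAM.
    #include <utility>
    #include <vector>

    // Bilinear weights without the averaging factor, as in the patched
    // bilinearInterpolate(): returns false for samples outside the map.
    static bool bilinearWeights(int height, int width, float y, float x,
                                float *w1, float *w2, float *w3, float *w4,
                                int *x_low, int *x_high, int *y_low,
                                int *y_high) {
      if (y < -1.0f || y > height || x < -1.0f || x > width) return false;
      if (y <= 0.0f) y = 0.0f;
      if (x <= 0.0f) x = 0.0f;
      *y_low = static_cast<int>(y);
      *x_low = static_cast<int>(x);
      if (*y_low >= height - 1) {
        *y_low = *y_high = height - 1;
        y = static_cast<float>(*y_low);
      } else {
        *y_high = *y_low + 1;
      }
      if (*x_low >= width - 1) {
        *x_low = *x_high = width - 1;
        x = static_cast<float>(*x_low);
      } else {
        *x_high = *x_low + 1;
      }
      const float ly = y - *y_low, lx = x - *x_low;
      const float hy = 1.0f - ly, hx = 1.0f - lx;
      *w1 = hy * hx;  // zero_sign is no longer baked into w1..w4
      *w2 = hy * lx;
      *w3 = ly * hx;
      *w4 = ly * lx;
      return true;
    }

    // One output bin: accumulate w1*v1 + w2*v2 + w3*v3 + w4*v4 over the
    // sampling grid, then apply 1 / bin_dim once, mirroring
    // __bang_mul_scalar(nram_out, nram_out, zero_sign, c_slice_align).
    static float averageBinSamples(
        const float *feat, int height, int width,
        const std::vector<std::pair<float, float>> &samples_yx) {
      float acc = 0.0f;
      for (const auto &s : samples_yx) {
        float w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;
        if (!bilinearWeights(height, width, s.first, s.second, &w1, &w2, &w3,
                             &w4, &x_low, &x_high, &y_low, &y_high)) {
          continue;  // empty samples add zero but still count in bin_dim
        }
        acc += w1 * feat[y_low * width + x_low] +
               w2 * feat[y_low * width + x_high] +
               w3 * feat[y_high * width + x_low] +
               w4 * feat[y_high * width + x_high];
      }
      return samples_yx.empty() ? 0.0f
                                : acc / static_cast<float>(samples_yx.size());
    }

Applying the 1 / bin_dim factor once per accumulated bin, rather than folding it into every weight, also keeps the forward path consistent with the backward hunks above, which scatter gradients with the weights scaled by `zero_sign`.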