mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix roi_align_rotated op of MLU (#2210)
* [Fix] roi_align_rotated codes * [Fix]: fix code style * [Fix]: fix code style * [Fix]: fix code stylepull/2223/head
parent
730a53a062
commit
9b11e560f3
|
@ -25,10 +25,10 @@ __mlu_func__ void swap(T &a, T &b) {
|
|||
|
||||
template <typename T>
|
||||
__mlu_func__ void bilinearInterpolate(const int input_height,
|
||||
const int input_width, T x, T y,
|
||||
const T zero_sign, T *w1, T *w2, T *w3,
|
||||
T *w4, int *x_low, int *x_high,
|
||||
int *y_low, int *y_high, bool *empty) {
|
||||
const int input_width, T x, T y, T *w1,
|
||||
T *w2, T *w3, T *w4, int *x_low,
|
||||
int *x_high, int *y_low, int *y_high,
|
||||
bool *empty) {
|
||||
// deal with case that the point is out of feature map boundary
|
||||
if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) {
|
||||
*empty = true;
|
||||
|
@ -58,10 +58,11 @@ __mlu_func__ void bilinearInterpolate(const int input_height,
|
|||
T lx = x - *x_low;
|
||||
T hy = 1.0 - ly;
|
||||
T hx = 1.0 - lx;
|
||||
*w1 = hy * hx * zero_sign;
|
||||
*w2 = hy * lx * zero_sign;
|
||||
*w3 = ly * hx * zero_sign;
|
||||
*w4 = ly * lx * zero_sign;
|
||||
*w1 = hy * hx;
|
||||
*w2 = hy * lx;
|
||||
*w3 = ly * hx;
|
||||
*w4 = ly * lx;
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -141,7 +142,7 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
|
|||
int dst_offset = 0;
|
||||
int c_rem, c_slice, c_slice_align, pongc_slice, pongc_slice_align;
|
||||
for (int c_offset = 0; c_offset < channel; c_offset += channel_align) {
|
||||
__nramset(nram_out, channel_align, (T)0);
|
||||
__bang_write_value(nram_out, channel_align, (T)0);
|
||||
c_rem = channel - c_offset;
|
||||
c_slice = channel_align > c_rem ? c_rem : channel_align;
|
||||
c_slice_align = CEIL_ALIGN(c_slice, align_base_128);
|
||||
|
@ -159,9 +160,8 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
|
|||
T w1, w2, w3, w4;
|
||||
bool empty = false;
|
||||
int x_low, x_high, y_low, y_high;
|
||||
bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3,
|
||||
&w4, &x_low, &x_high, &y_low, &y_high, &empty);
|
||||
int sample_wdim = x_high - x_low + 1;
|
||||
bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
|
||||
&x_high, &y_low, &y_high, &empty);
|
||||
/*******************************************************
|
||||
| ping | pong |
|
||||
|------|-----|-----|-----|-----|-----|-----|-----|-----|
|
||||
|
@ -170,22 +170,32 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
|
|||
********************************************************/
|
||||
if (is_first_sample && !empty) {
|
||||
// load input data from dram to nram
|
||||
__nramset(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
|
||||
for (int h = y_low; h <= y_high; ++h) {
|
||||
src_offset =
|
||||
(batch_idx * height * width + h * width + x_low) * channel +
|
||||
c_offset;
|
||||
dst_offset = (h - y_low) * SAMPLING_NUM * c_slice_align / 2;
|
||||
if (c_slice_align == channel) {
|
||||
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
|
||||
sample_wdim * channel * sizeof(T), GDRAM2NRAM);
|
||||
} else {
|
||||
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM,
|
||||
c_slice_align * sizeof(T), channel * sizeof(T),
|
||||
sample_wdim - 1);
|
||||
}
|
||||
}
|
||||
__bang_write_value(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
|
||||
src_offset =
|
||||
(batch_idx * height * width + y_low * width + x_low) * channel +
|
||||
c_offset;
|
||||
dst_offset = 0;
|
||||
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
src_offset = (batch_idx * height * width + y_low * width + x_high) *
|
||||
channel +
|
||||
c_offset;
|
||||
dst_offset = c_slice_align;
|
||||
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
src_offset = (batch_idx * height * width + y_high * width + x_low) *
|
||||
channel +
|
||||
c_offset;
|
||||
dst_offset = c_slice_align * 2;
|
||||
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
src_offset =
|
||||
(batch_idx * height * width + y_high * width + x_high) *
|
||||
channel +
|
||||
c_offset;
|
||||
dst_offset = c_slice_align * 3;
|
||||
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
}
|
||||
// load next input data to nram
|
||||
if (sample_i + 1 < bin_dim) {
|
||||
|
@ -200,56 +210,65 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
|
|||
T p_w1, p_w2, p_w3, p_w4;
|
||||
bool p_empty = false;
|
||||
int p_x_low, p_x_high, p_y_low, p_y_high;
|
||||
bilinearInterpolate(height, width, p_x, p_y, zero_sign, &p_w1,
|
||||
&p_w2, &p_w3, &p_w4, &p_x_low, &p_x_high,
|
||||
&p_y_low, &p_y_high, &p_empty);
|
||||
int p_sample_wdim = p_x_high - p_x_low + 1;
|
||||
bilinearInterpolate(height, width, p_x, p_y, &p_w1, &p_w2, &p_w3,
|
||||
&p_w4, &p_x_low, &p_x_high, &p_y_low, &p_y_high,
|
||||
&p_empty);
|
||||
pongc_slice = c_slice;
|
||||
pongc_slice_align = c_slice_align;
|
||||
if (!p_empty) {
|
||||
__nramset(nram_pong, SAMPLING_NUM * pongc_slice_align, (T)0);
|
||||
for (int h = p_y_low; h <= p_y_high; ++h) {
|
||||
src_offset =
|
||||
(batch_idx * height * width + h * width + p_x_low) *
|
||||
channel +
|
||||
c_offset;
|
||||
dst_offset =
|
||||
(h - p_y_low) * SAMPLING_NUM * pongc_slice_align / 2;
|
||||
if (pongc_slice_align == channel) {
|
||||
__memcpy_async(
|
||||
nram_pong + dst_offset, input_dram + src_offset,
|
||||
p_sample_wdim * channel * sizeof(T), GDRAM2NRAM);
|
||||
} else {
|
||||
__memcpy_async(nram_pong + dst_offset,
|
||||
input_dram + src_offset,
|
||||
pongc_slice * sizeof(T), GDRAM2NRAM,
|
||||
pongc_slice_align * sizeof(T),
|
||||
channel * sizeof(T), p_sample_wdim - 1);
|
||||
}
|
||||
}
|
||||
__bang_write_value(nram_pong, SAMPLING_NUM * pongc_slice_align,
|
||||
(T)0);
|
||||
src_offset =
|
||||
(batch_idx * height * width + p_y_low * width + p_x_low) *
|
||||
channel +
|
||||
c_offset;
|
||||
dst_offset = 0;
|
||||
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
src_offset =
|
||||
(batch_idx * height * width + p_y_low * width + p_x_high) *
|
||||
channel +
|
||||
c_offset;
|
||||
dst_offset = pongc_slice_align;
|
||||
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
src_offset =
|
||||
(batch_idx * height * width + p_y_high * width + p_x_low) *
|
||||
channel +
|
||||
c_offset;
|
||||
dst_offset = pongc_slice_align * 2;
|
||||
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
src_offset =
|
||||
(batch_idx * height * width + p_y_high * width + p_x_high) *
|
||||
channel +
|
||||
c_offset;
|
||||
dst_offset = pongc_slice_align * 3;
|
||||
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
}
|
||||
}
|
||||
T *tmp_sum = nram_ping + 3 * c_slice_align;
|
||||
if (empty) {
|
||||
__nramset(tmp_sum, c_slice_align, T(0));
|
||||
__bang_write_value(tmp_sum, c_slice_align, T(0));
|
||||
} else {
|
||||
__bang_mul_const(nram_ping, nram_ping, w1, c_slice_align);
|
||||
__bang_mul_const(nram_ping + c_slice_align,
|
||||
nram_ping + c_slice_align, w2, c_slice_align);
|
||||
__bang_mul_const(nram_ping + 2 * c_slice_align,
|
||||
nram_ping + 2 * c_slice_align, w3, c_slice_align);
|
||||
__bang_mul_const(nram_ping + 3 * c_slice_align,
|
||||
nram_ping + 3 * c_slice_align, w4, c_slice_align);
|
||||
__bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
|
||||
__bang_mul_scalar(nram_ping + c_slice_align,
|
||||
nram_ping + c_slice_align, w2, c_slice_align);
|
||||
__bang_mul_scalar(nram_ping + 2 * c_slice_align,
|
||||
nram_ping + 2 * c_slice_align, w3, c_slice_align);
|
||||
__bang_mul_scalar(nram_ping + 3 * c_slice_align,
|
||||
nram_ping + 3 * c_slice_align, w4, c_slice_align);
|
||||
__bang_sumpool(tmp_sum, nram_ping, c_slice_align, 1, SAMPLING_NUM,
|
||||
1, SAMPLING_NUM, 1, 1);
|
||||
}
|
||||
__bang_add(nram_out, nram_out, tmp_sum, c_slice_align);
|
||||
swap(nram_ping, nram_pong);
|
||||
|
||||
__asm__ volatile("sync;");
|
||||
is_first_sample = false;
|
||||
}
|
||||
}
|
||||
__bang_mul_scalar(nram_out, nram_out, zero_sign, c_slice_align);
|
||||
// store the result to dram
|
||||
int output_offset =
|
||||
((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
|
||||
|
@ -310,7 +329,6 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
|
|||
T cos_theta = std::cos(theta);
|
||||
T sin_theta = std::sin(theta);
|
||||
T zero_sign = 1.0f / bin_dim;
|
||||
|
||||
int c_rem, c_slice, pongc_slice, c_offset;
|
||||
c_rem = channel;
|
||||
c_offset = 0;
|
||||
|
@ -369,30 +387,30 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
|
|||
T w1, w2, w3, w4;
|
||||
bool empty = false;
|
||||
int x_low, x_high, y_low, y_high;
|
||||
bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3,
|
||||
&w4, &x_low, &x_high, &y_low, &y_high, &empty);
|
||||
bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
|
||||
&x_high, &y_low, &y_high, &empty);
|
||||
if (empty) {
|
||||
continue;
|
||||
} else {
|
||||
__bang_mul_const(nram_output, nram_ping, w1, c_limit);
|
||||
__bang_mul_scalar(nram_output, nram_ping, w1 * zero_sign, c_limit);
|
||||
__bang_atomic_add(
|
||||
(T *)nram_output,
|
||||
bottom_grad_dram + batch_idx * height * width * channel +
|
||||
y_low * width * channel + x_low * channel + c_offset,
|
||||
(T *)nram_output, c_slice);
|
||||
__bang_mul_const(nram_output, nram_ping, w2, c_limit);
|
||||
__bang_mul_scalar(nram_output, nram_ping, w2 * zero_sign, c_limit);
|
||||
__bang_atomic_add(
|
||||
(T *)nram_output,
|
||||
bottom_grad_dram + batch_idx * height * width * channel +
|
||||
y_low * width * channel + x_high * channel + c_offset,
|
||||
(T *)nram_output, c_slice);
|
||||
__bang_mul_const(nram_output, nram_ping, w3, c_limit);
|
||||
__bang_mul_scalar(nram_output, nram_ping, w3 * zero_sign, c_limit);
|
||||
__bang_atomic_add(
|
||||
(T *)nram_output,
|
||||
bottom_grad_dram + batch_idx * height * width * channel +
|
||||
y_high * width * channel + x_low * channel + c_offset,
|
||||
(T *)nram_output, c_slice);
|
||||
__bang_mul_const(nram_output, nram_ping, w4, c_limit);
|
||||
__bang_mul_scalar(nram_output, nram_ping, w4 * zero_sign, c_limit);
|
||||
__bang_atomic_add(
|
||||
(T *)nram_output,
|
||||
bottom_grad_dram + batch_idx * height * width * channel +
|
||||
|
|
|
@ -99,8 +99,8 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
|
|||
auto input_tensor =
|
||||
torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
|
||||
at::Tensor output_tmp =
|
||||
at::empty({batch, channel, pooled_height, pooled_width}, input.options(),
|
||||
memory_format);
|
||||
at::empty({rois_nums, channel, pooled_height, pooled_width},
|
||||
input.options(), memory_format);
|
||||
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
|
|
Loading…
Reference in New Issue