[Fix] Fix roi_align_rotated op of MLU (#2210)

* [Fix] roi_align_rotated codes

* [Fix]: fix code style

* [Fix]: fix code style

* [Fix]: fix code style
pull/2223/head
tudejiang79 2022-08-19 21:19:40 +08:00 committed by GitHub
parent 730a53a062
commit 9b11e560f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 70 deletions

View File

@ -25,10 +25,10 @@ __mlu_func__ void swap(T &a, T &b) {
template <typename T>
__mlu_func__ void bilinearInterpolate(const int input_height,
const int input_width, T x, T y,
const T zero_sign, T *w1, T *w2, T *w3,
T *w4, int *x_low, int *x_high,
int *y_low, int *y_high, bool *empty) {
const int input_width, T x, T y, T *w1,
T *w2, T *w3, T *w4, int *x_low,
int *x_high, int *y_low, int *y_high,
bool *empty) {
// deal with case that the point is out of feature map boundary
if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) {
*empty = true;
@ -58,10 +58,11 @@ __mlu_func__ void bilinearInterpolate(const int input_height,
T lx = x - *x_low;
T hy = 1.0 - ly;
T hx = 1.0 - lx;
*w1 = hy * hx * zero_sign;
*w2 = hy * lx * zero_sign;
*w3 = ly * hx * zero_sign;
*w4 = ly * lx * zero_sign;
*w1 = hy * hx;
*w2 = hy * lx;
*w3 = ly * hx;
*w4 = ly * lx;
return;
}
template <typename T>
@ -141,7 +142,7 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
int dst_offset = 0;
int c_rem, c_slice, c_slice_align, pongc_slice, pongc_slice_align;
for (int c_offset = 0; c_offset < channel; c_offset += channel_align) {
__nramset(nram_out, channel_align, (T)0);
__bang_write_value(nram_out, channel_align, (T)0);
c_rem = channel - c_offset;
c_slice = channel_align > c_rem ? c_rem : channel_align;
c_slice_align = CEIL_ALIGN(c_slice, align_base_128);
@ -159,9 +160,8 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
T w1, w2, w3, w4;
bool empty = false;
int x_low, x_high, y_low, y_high;
bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3,
&w4, &x_low, &x_high, &y_low, &y_high, &empty);
int sample_wdim = x_high - x_low + 1;
bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
&x_high, &y_low, &y_high, &empty);
/*******************************************************
| ping | pong |
|------|-----|-----|-----|-----|-----|-----|-----|-----|
@ -170,22 +170,32 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
********************************************************/
if (is_first_sample && !empty) {
// load input data from dram to nram
__nramset(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
for (int h = y_low; h <= y_high; ++h) {
src_offset =
(batch_idx * height * width + h * width + x_low) * channel +
c_offset;
dst_offset = (h - y_low) * SAMPLING_NUM * c_slice_align / 2;
if (c_slice_align == channel) {
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
sample_wdim * channel * sizeof(T), GDRAM2NRAM);
} else {
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM,
c_slice_align * sizeof(T), channel * sizeof(T),
sample_wdim - 1);
}
}
__bang_write_value(nram_ping, SAMPLING_NUM * c_slice_align, (T)0);
src_offset =
(batch_idx * height * width + y_low * width + x_low) * channel +
c_offset;
dst_offset = 0;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset = (batch_idx * height * width + y_low * width + x_high) *
channel +
c_offset;
dst_offset = c_slice_align;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset = (batch_idx * height * width + y_high * width + x_low) *
channel +
c_offset;
dst_offset = c_slice_align * 2;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + y_high * width + x_high) *
channel +
c_offset;
dst_offset = c_slice_align * 3;
__memcpy(nram_ping + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
}
// load next input data to nram
if (sample_i + 1 < bin_dim) {
@ -200,56 +210,65 @@ __mlu_func__ void roiAlignRotatedForward(const T *input_dram,
T p_w1, p_w2, p_w3, p_w4;
bool p_empty = false;
int p_x_low, p_x_high, p_y_low, p_y_high;
bilinearInterpolate(height, width, p_x, p_y, zero_sign, &p_w1,
&p_w2, &p_w3, &p_w4, &p_x_low, &p_x_high,
&p_y_low, &p_y_high, &p_empty);
int p_sample_wdim = p_x_high - p_x_low + 1;
bilinearInterpolate(height, width, p_x, p_y, &p_w1, &p_w2, &p_w3,
&p_w4, &p_x_low, &p_x_high, &p_y_low, &p_y_high,
&p_empty);
pongc_slice = c_slice;
pongc_slice_align = c_slice_align;
if (!p_empty) {
__nramset(nram_pong, SAMPLING_NUM * pongc_slice_align, (T)0);
for (int h = p_y_low; h <= p_y_high; ++h) {
src_offset =
(batch_idx * height * width + h * width + p_x_low) *
channel +
c_offset;
dst_offset =
(h - p_y_low) * SAMPLING_NUM * pongc_slice_align / 2;
if (pongc_slice_align == channel) {
__memcpy_async(
nram_pong + dst_offset, input_dram + src_offset,
p_sample_wdim * channel * sizeof(T), GDRAM2NRAM);
} else {
__memcpy_async(nram_pong + dst_offset,
input_dram + src_offset,
pongc_slice * sizeof(T), GDRAM2NRAM,
pongc_slice_align * sizeof(T),
channel * sizeof(T), p_sample_wdim - 1);
}
}
__bang_write_value(nram_pong, SAMPLING_NUM * pongc_slice_align,
(T)0);
src_offset =
(batch_idx * height * width + p_y_low * width + p_x_low) *
channel +
c_offset;
dst_offset = 0;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + p_y_low * width + p_x_high) *
channel +
c_offset;
dst_offset = pongc_slice_align;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + p_y_high * width + p_x_low) *
channel +
c_offset;
dst_offset = pongc_slice_align * 2;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
src_offset =
(batch_idx * height * width + p_y_high * width + p_x_high) *
channel +
c_offset;
dst_offset = pongc_slice_align * 3;
__memcpy(nram_pong + dst_offset, input_dram + src_offset,
c_slice * sizeof(T), GDRAM2NRAM);
}
}
T *tmp_sum = nram_ping + 3 * c_slice_align;
if (empty) {
__nramset(tmp_sum, c_slice_align, T(0));
__bang_write_value(tmp_sum, c_slice_align, T(0));
} else {
__bang_mul_const(nram_ping, nram_ping, w1, c_slice_align);
__bang_mul_const(nram_ping + c_slice_align,
nram_ping + c_slice_align, w2, c_slice_align);
__bang_mul_const(nram_ping + 2 * c_slice_align,
nram_ping + 2 * c_slice_align, w3, c_slice_align);
__bang_mul_const(nram_ping + 3 * c_slice_align,
nram_ping + 3 * c_slice_align, w4, c_slice_align);
__bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
__bang_mul_scalar(nram_ping + c_slice_align,
nram_ping + c_slice_align, w2, c_slice_align);
__bang_mul_scalar(nram_ping + 2 * c_slice_align,
nram_ping + 2 * c_slice_align, w3, c_slice_align);
__bang_mul_scalar(nram_ping + 3 * c_slice_align,
nram_ping + 3 * c_slice_align, w4, c_slice_align);
__bang_sumpool(tmp_sum, nram_ping, c_slice_align, 1, SAMPLING_NUM,
1, SAMPLING_NUM, 1, 1);
}
__bang_add(nram_out, nram_out, tmp_sum, c_slice_align);
swap(nram_ping, nram_pong);
__asm__ volatile("sync;");
is_first_sample = false;
}
}
__bang_mul_scalar(nram_out, nram_out, zero_sign, c_slice_align);
// store the result to dram
int output_offset =
((roi_n * params.pooled_height + ph) * params.pooled_width + pw) *
@ -310,7 +329,6 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
T cos_theta = std::cos(theta);
T sin_theta = std::sin(theta);
T zero_sign = 1.0f / bin_dim;
int c_rem, c_slice, pongc_slice, c_offset;
c_rem = channel;
c_offset = 0;
@ -369,30 +387,30 @@ __mlu_func__ void roiAlignRotatedBackward(const T *top_grad_dram,
T w1, w2, w3, w4;
bool empty = false;
int x_low, x_high, y_low, y_high;
bilinearInterpolate(height, width, x, y, zero_sign, &w1, &w2, &w3,
&w4, &x_low, &x_high, &y_low, &y_high, &empty);
bilinearInterpolate(height, width, x, y, &w1, &w2, &w3, &w4, &x_low,
&x_high, &y_low, &y_high, &empty);
if (empty) {
continue;
} else {
__bang_mul_const(nram_output, nram_ping, w1, c_limit);
__bang_mul_scalar(nram_output, nram_ping, w1 * zero_sign, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_low * width * channel + x_low * channel + c_offset,
(T *)nram_output, c_slice);
__bang_mul_const(nram_output, nram_ping, w2, c_limit);
__bang_mul_scalar(nram_output, nram_ping, w2 * zero_sign, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_low * width * channel + x_high * channel + c_offset,
(T *)nram_output, c_slice);
__bang_mul_const(nram_output, nram_ping, w3, c_limit);
__bang_mul_scalar(nram_output, nram_ping, w3 * zero_sign, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +
y_high * width * channel + x_low * channel + c_offset,
(T *)nram_output, c_slice);
__bang_mul_const(nram_output, nram_ping, w4, c_limit);
__bang_mul_scalar(nram_output, nram_ping, w4 * zero_sign, c_limit);
__bang_atomic_add(
(T *)nram_output,
bottom_grad_dram + batch_idx * height * width * channel +

View File

@ -99,8 +99,8 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
auto input_tensor =
torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
at::Tensor output_tmp =
at::empty({batch, channel, pooled_height, pooled_width}, input.options(),
memory_format);
at::empty({rois_nums, channel, pooled_height, pooled_width},
input.options(), memory_format);
// get compute queue
auto queue = torch_mlu::getCurQueue();