mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Replace the implementation of deform_roi_pool with mlu-ops (#2598)

* [Feature] Replace the implementation of deform_roi_pool with mlu-ops
* [Feature] Modify code

---------

Co-authored-by: budefei <budefei@cambricon.com>

pull/2697/head
parent 0d1b224fb1
commit 9fcf48c4c8
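In short, the commit deletes the hand-written BANG kernels below and routes the PyTorch launchers through the mlu-ops library instead. The following is a condensed sketch of the new forward call path, pieced together from the diff; the wrapper function name is illustrative, and the dtype/shape checks and the deformable-offset branch are omitted:

void deform_roi_pool_forward_sketch(Tensor input, Tensor rois, Tensor output,
                                    int pooled_height, int pooled_width,
                                    float spatial_scale, int sampling_ratio,
                                    float gamma) {
  // Convert to channels-last (NHWC) storage, which the mlu-ops kernel expects.
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
  auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
  auto rois_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
  auto output_ = torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format);

  // Describe the tensors for mlu-ops.
  MluOpTensorDescriptor input_desc, rois_desc, output_desc;
  input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);
  rois_desc.set(rois_);
  output_desc.set_with_layout(output_, MLUOP_LAYOUT_NHWC);

  // Delegate the actual computation to the library (offset passed as NULL here).
  auto handle = mluOpGetCurrentHandle();
  mluOpDeformRoiPoolForward(
      handle, input_desc.desc(),
      torch_mlu::getMluTensorImpl(input_)->cnnlMalloc(), rois_desc.desc(),
      torch_mlu::getMluTensorImpl(rois_)->cnnlMalloc(), /*offset_desc=*/NULL,
      /*offset=*/NULL, pooled_height, pooled_width, spatial_scale,
      sampling_ratio, gamma, output_desc.desc(),
      torch_mlu::getMluTensorImpl(output_)->cnnlMalloc());
  output.copy_(output_);
}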
@@ -1,712 +0,0 @@
/*************************************************************************
 * Copyright (C) 2022 Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <iostream>

#include "common_mlu_helper.hpp"

#define ROI_OFFSET 5
#define FOURSPLIT 4
#define FIVESPLIT 5
#define NINESPLIT 9
#define THIRTEENSPLIT 13

__nram__ char nram_buffer[MAX_NRAM_SIZE];

template <typename T>
static __mlu_func__ void bilinearInterpolate(const int input_width, T y, T x,
                                             T *w1, T *w2, T *w3, T *w4,
                                             int *x_low, int *x_high,
                                             const int y_low, bool *is_empty) {
  if (x < -1.0 || x > input_width) {
    *is_empty = true;
    return;
  }

  if (x <= 0) x = 0;

  *x_low = int(x);

  if (*x_low >= input_width - 1) {
    *x_high = *x_low = input_width - 1;
    x = T(*x_low);
  } else {
    *x_high = *x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - *x_low;
  T hy = 1.0 - ly;
  T hx = 1.0 - lx;
  *w1 = hy * hx;
  *w2 = hy * lx;
  *w3 = ly * hx;
  *w4 = ly * lx;
}

template <typename T>
__mlu_func__ void MLUUnion1DeformRoIPoolForward(
    const T *input, const T *rois, const T *offset, T *output,
    const int channels, const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const T spatial_scale,
    const int sampling_ratio, const T gamma) {
  for (int bin_index = taskId;
       bin_index < num_rois * pooled_width * pooled_height;
       bin_index += taskDim) {
    int out_batch = bin_index / pooled_width / pooled_height;
    int out_height = bin_index / pooled_width % pooled_height;
    int out_width = bin_index % pooled_width;
    const T *cur_roi = rois + out_batch * ROI_OFFSET;
    T *nram_rois = (T *)nram_buffer;
    __memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
             GDRAM2NRAM);
    const int roi_batch = nram_rois[0];
    T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
    T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
    const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
    const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
    const T roi_width = roi_x_max - roi_x_min;
    const T roi_height = roi_y_max - roi_y_min;
    const T bin_width = roi_width / static_cast<T>(pooled_width);
    const T bin_height = roi_height / static_cast<T>(pooled_height);
    const T *offset_input = input + roi_batch * height * width * channels;
    int roi_bin_grid_height =
        (sampling_ratio > 0)
            ? sampling_ratio
            : static_cast<int>(ceilf(roi_height / pooled_height));
    int roi_bin_grid_width =
        (sampling_ratio > 0)
            ? sampling_ratio
            : static_cast<int>(ceilf(roi_width / pooled_width));
    if (offset != NULL) {
      const T *offset_cur = offset +
                            out_batch * pooled_width * pooled_height * 2 +
                            out_height * pooled_width + out_width;
      roi_x_min += gamma * roi_width * offset_cur[0];
      roi_y_min +=
          gamma * roi_height * offset_cur[pooled_width * pooled_height];
    }
    int type_align = NFU_ALIGN_SIZE / sizeof(T);
    int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
    int channels_nram_split =
        channels_max_num_nram / NINESPLIT / type_align * type_align;
    int channel_rem = channels % channels_nram_split;
    int channel_loops =
        channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
    for (int channel_loop_index = 0; channel_loop_index < channel_loops;
         ++channel_loop_index) {
      int channels_num =
          channels_nram_split >= channels ? channels : channels_nram_split;
      const int channel_offset = channel_loop_index * channels_num;
      if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
        channels_num = channel_rem;
      }
      int channels_align = CEIL_ALIGN(channels_num, type_align);
      int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - channels_align) >> 1;
      int c_slice = nram_limit / FOURSPLIT / type_align * type_align;
      int c_slice_align = 0;

      /* NRAM partition
       *
       * |          |       ping        |       pong        |
       * |----------|-------------------|-------------------|
       * | nram_out | p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 |
       *
       */

      T *nram_out = (T *)nram_buffer;
      T *nram_ping = nram_out + channels_align;
      T *nram_pong = nram_ping + nram_limit;
      __bang_write_value((T *)nram_out, channels_align, (T)0);
      __bang_write_value((T *)nram_ping, FOURSPLIT * c_slice, (T)0);
      __bang_write_value((T *)nram_pong, FOURSPLIT * c_slice, (T)0);
      const T num_bins =
          static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
      const T value_div = 1.0f / num_bins;
      bool is_ping_empty = true;
      for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
        T y = roi_y_min + out_height * bin_height +
              static_cast<T>(iy + .5f) * bin_height /
                  static_cast<T>(roi_bin_grid_height);
        if (y < -1.0 || y > height) {
          is_ping_empty = true;
          continue;
        }
        if (y <= 0) {
          y = 0;
        }
        int y_low = 0, y_high = 0;
        y_low = int(y);
        if (y_low >= height - 1) {
          y_high = y_low = height - 1;
          y = T(y_low);
        } else {
          y_high = y_low + 1;
        }
        for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
          T x = roi_x_min + out_width * bin_width +
                static_cast<T>(ix + .5f) * bin_width /
                    static_cast<T>(roi_bin_grid_width);
          const int sample_index = iy * roi_bin_grid_width + ix;
          int c_rem = channels_num;
          c_slice = nram_limit / FOURSPLIT / type_align * type_align;
          c_slice_align = 0;
          bool is_empty = false;
          T w1, w2, w3, w4;
          int x_low = 0, x_high = 0;
          bilinearInterpolate(width, y, x, &w1, &w2, &w3, &w4, &x_low, &x_high,
                              y_low, &is_empty);
          if (is_empty) {
            is_ping_empty = true;
            continue;
          }
          if (is_ping_empty) {
            c_slice = c_slice > c_rem ? c_rem : c_slice;
            c_slice_align = CEIL_ALIGN(c_slice, type_align);
            __bang_write_value(nram_ping, FOURSPLIT * c_slice_align, (T)0);
            __asm__ volatile("sync;");
            __memcpy(nram_ping,
                     offset_input + y_low * width * channels +
                         x_low * channels + channel_offset,
                     c_slice * sizeof(T), GDRAM2NRAM);
            __memcpy(nram_ping + c_slice_align,
                     offset_input + y_low * width * channels +
                         x_high * channels + channel_offset,
                     c_slice * sizeof(T), GDRAM2NRAM);
            __memcpy(nram_ping + 2 * c_slice_align,
                     offset_input + y_high * width * channels +
                         x_low * channels + channel_offset,
                     c_slice * sizeof(T), GDRAM2NRAM);
            __memcpy(nram_ping + 3 * c_slice_align,
                     offset_input + y_high * width * channels +
                         x_high * channels + channel_offset,
                     c_slice * sizeof(T), GDRAM2NRAM);
            is_ping_empty = false;
          }
          int c_offset = 0;
          int pongc_slice = 0;
          int pongc_slice_align = 0;
          while (c_rem > 0) {
            c_slice = c_slice > c_rem ? c_rem : c_slice;
            c_slice_align = CEIL_ALIGN(c_slice, type_align);
            if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
              int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
              int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
              y = roi_y_min + out_height * bin_height +
                  static_cast<T>(iy_tmp + .5f) * bin_height /
                      static_cast<T>(roi_bin_grid_height);
              x = roi_x_min + out_width * bin_width +
                  static_cast<T>(ix_tmp + .5f) * bin_width /
                      static_cast<T>(roi_bin_grid_width);
              if (y < -1.0 || y > height) {
                is_empty = true;
              } else {
                T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
                if (y <= 0) {
                  y = 0;
                }
                y_low = int(y);
                if (y_low >= height - 1) {
                  y_high = y_low = height - 1;
                  y = T(y_low);
                } else {
                  y_high = y_low + 1;
                }
                bilinearInterpolate(width, y, x, &w1_tmp, &w2_tmp, &w3_tmp,
                                    &w4_tmp, &x_low, &x_high, y_low, &is_empty);
              }
              pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
              pongc_slice =
                  pongc_slice > channels_num ? channels_num : pongc_slice;
              pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
              __bang_write_value(nram_pong, FOURSPLIT * pongc_slice_align,
                                 (T)0);
              __asm__ volatile("sync;");
              if (!is_empty) {
                __memcpy_async(nram_pong,
                               offset_input + y_low * width * channels +
                                   x_low * channels + channel_offset,
                               pongc_slice * sizeof(T), GDRAM2NRAM);
                __memcpy_async(nram_pong + pongc_slice_align,
                               offset_input + y_low * width * channels +
                                   x_high * channels + channel_offset,
                               pongc_slice * sizeof(T), GDRAM2NRAM);
                __memcpy_async(nram_pong + 2 * pongc_slice_align,
                               offset_input + y_high * width * channels +
                                   x_low * channels + channel_offset,
                               pongc_slice * sizeof(T), GDRAM2NRAM);
                __memcpy_async(nram_pong + 3 * pongc_slice_align,
                               offset_input + y_high * width * channels +
                                   x_high * channels + channel_offset,
                               pongc_slice * sizeof(T), GDRAM2NRAM);
              }
            }
            __bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
            __bang_mul_scalar(nram_ping + c_slice_align,
                              nram_ping + c_slice_align, w2, c_slice_align);
            __bang_add(nram_ping, nram_ping, nram_ping + c_slice_align,
                       c_slice_align);
            __bang_mul_scalar(nram_ping + 2 * c_slice_align,
                              nram_ping + 2 * c_slice_align, w3, c_slice_align);
            __bang_add(nram_ping, nram_ping, nram_ping + 2 * c_slice_align,
                       c_slice_align);
            __bang_mul_scalar(nram_ping + 3 * c_slice_align,
                              nram_ping + 3 * c_slice_align, w4, c_slice_align);
            __bang_add(nram_ping, nram_ping, nram_ping + 3 * c_slice_align,
                       c_slice_align);
            __bang_add(nram_out + c_offset, nram_out + c_offset, nram_ping,
                       c_slice_align);
            T *nram_tmp = nram_ping;
            nram_ping = nram_pong;
            nram_pong = nram_tmp;
            c_rem -= c_slice;
            c_offset += c_slice;
            __asm__ volatile("sync;");
          }
        }
      }
      __bang_mul_scalar(nram_out, nram_out, value_div, channels_align);
      __memcpy(output + channels * bin_index + channel_offset, nram_out,
               channels_num * sizeof(T), NRAM2GDRAM);
    }
  }
}

__mlu_global__ void MLUKernelDeformRoIPoolForward(
    cnrtDataType_t data_type, const void *input, const void *rois,
    const void *offset, void *output, const int channels, const int height,
    const int width, const int num_rois, const int pooled_height,
    const int pooled_width, const float spatial_scale, const int sampling_ratio,
    const float gamma) {
  switch (data_type) {
    case CNRT_FLOAT16: {
      MLUUnion1DeformRoIPoolForward((half *)input, (half *)rois, (half *)offset,
                                    (half *)output, channels, height, width,
                                    num_rois, pooled_height, pooled_width,
                                    static_cast<half>(spatial_scale),
                                    sampling_ratio, static_cast<half>(gamma));
    }; break;
    case CNRT_FLOAT32: {
      MLUUnion1DeformRoIPoolForward(
          (float *)input, (float *)rois, (float *)offset, (float *)output,
          channels, height, width, num_rois, pooled_height, pooled_width,
          static_cast<float>(spatial_scale), sampling_ratio,
          static_cast<float>(gamma));
    }; break;
    default: {
      break;
    }
  }
}

void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                                cnrtQueue_t queue, cnrtDataType_t data_type,
                                const void *input, const void *rois,
                                const void *offset, void *output,
                                const int channels, const int height,
                                const int width, const int num_rois,
                                const int pooled_height, const int pooled_width,
                                const float spatial_scale,
                                const int sampling_ratio, const float gamma) {
  MLUKernelDeformRoIPoolForward<<<k_dim, k_type, queue>>>(
      data_type, input, rois, offset, output, channels, height, width,
      num_rois, pooled_height, pooled_width, spatial_scale, sampling_ratio,
      gamma);
}

template <typename T>
__mlu_func__ void MLUUnion1DeformRoIPoolBackward(
    const T *grad_output, const T *input, const T *rois, const T *offset,
    T *grad_input, T *grad_offset, const int channels, const int height,
    const int width, const int num_rois, const int pooled_height,
    const int pooled_width, const T spatial_scale, const int sampling_ratio,
    const T gamma) {
  for (int bin_index = taskId;
       bin_index < num_rois * pooled_width * pooled_height;
       bin_index += taskDim) {
    int out_batch = bin_index / pooled_width / pooled_height;
    int out_height = bin_index / pooled_width % pooled_height;
    int out_width = bin_index % pooled_width;
    const T *cur_roi = rois + out_batch * ROI_OFFSET;
    T *nram_rois = (T *)nram_buffer;
    __memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
             GDRAM2NRAM);
    const int roi_batch = nram_rois[0];
    T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
    T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
    const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
    const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
    const T roi_width = roi_x_max - roi_x_min;
    const T roi_height = roi_y_max - roi_y_min;
    const T bin_width = roi_width / static_cast<T>(pooled_width);
    const T bin_height = roi_height / static_cast<T>(pooled_height);
    const T *offset_input = input + roi_batch * height * width * channels;
    T *offset_grad_input = grad_input + roi_batch * height * width * channels;
    int roi_bin_grid_height =
        (sampling_ratio > 0)
            ? sampling_ratio
            : static_cast<int>(ceilf(roi_height / pooled_height));
    int roi_bin_grid_width =
        (sampling_ratio > 0)
            ? sampling_ratio
            : static_cast<int>(ceilf(roi_width / pooled_width));
    if (offset != NULL) {
      const T *offset_cur = offset +
                            out_batch * pooled_width * pooled_height * 2 +
                            out_height * pooled_width + out_width;
      roi_x_min += gamma * roi_width * offset_cur[0];
      roi_y_min +=
          gamma * roi_height * offset_cur[pooled_width * pooled_height];
    }

    /* NRAM partition
     *
     * If offset != NULL, NRAM partition belows.
     * |                                                                     |    ping   |    pong   |
     * |---------------------------------------------------------------------|-----------|-----------|
     * |nram_tmp1|nram_tmp2|nram_tmp3|nram_tmp4|nram_grad_output|nram_sum_tmp|p1|p2|p3|p4|p1|p2|p3|p4|
     *
     * If offset == NULL, ping and pang will not be needed.
     * |                                                                  |
     * |------------------------------------------------------------------|
     * | nram_tmp1 | nram_tmp2 | nram_tmp3 | nram_tmp4 | nram_grad_output |
     *
     */

    int type_align = NFU_ALIGN_SIZE / sizeof(T);
    int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
    int channels_nram_split =
        channels_max_num_nram / FIVESPLIT / type_align * type_align;
    int channel_rem = channels % channels_nram_split;
    int channel_loops =
        channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
    if (offset != NULL) {
      channels_nram_split =
          channels_max_num_nram / THIRTEENSPLIT / type_align * type_align;
      channel_rem = channels % channels_nram_split;
      channel_loops =
          channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
    }

    for (int channel_loop_index = 0; channel_loop_index < channel_loops;
         ++channel_loop_index) {
      int channels_num =
          channels_nram_split >= channels ? channels : channels_nram_split;
      const int channel_offset = channel_loop_index * channels_num;
      if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
        channels_num = channel_rem;
      }
      int channels_align = CEIL_ALIGN(channels_num, type_align);
      const int32_t nram_sum_tmp_channel = NFU_ALIGN_SIZE / sizeof(T);
      int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - 5 * channels_align -
                        nram_sum_tmp_channel) >>
                       1;
      int c_slice = 0;
      int c_slice_align = 0;
      T *nram_tmp1 = (T *)nram_buffer;
      T *nram_tmp2 = (T *)nram_buffer + channels_align;
      T *nram_tmp3 = (T *)nram_buffer + 2 * channels_align;
      T *nram_tmp4 = (T *)nram_buffer + 3 * channels_align;
      T *nram_grad_output = nram_tmp4 + channels_align;
      T *nram_sum_tmp = NULL;
      T *nram_ping_input = NULL;
      T *nram_pong_input = NULL;
      __bang_write_value((T *)nram_grad_output, channels_align, (T)0);
      __asm__ volatile("sync;");

      if (offset != NULL) {
        c_slice = nram_limit / FOURSPLIT / type_align * type_align;
        nram_sum_tmp = nram_grad_output + channels_align;
        nram_ping_input = nram_sum_tmp + nram_sum_tmp_channel;
        nram_pong_input = nram_ping_input + FOURSPLIT * c_slice;
        __bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
        __bang_write_value((T *)nram_ping_input, FOURSPLIT * c_slice, (T)0);
        __bang_write_value((T *)nram_pong_input, FOURSPLIT * c_slice, (T)0);
        __asm__ volatile("sync;");
      }
      const T num_bins =
          static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
      const T value_div = 1.0f / num_bins;
      bool is_ping_empty = true;
      __memcpy(nram_grad_output,
               grad_output + channels * bin_index + channel_offset,
               channels_num * sizeof(T), GDRAM2NRAM);
      __bang_mul_scalar(nram_grad_output, nram_grad_output, value_div,
                        channels_align);
      for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
        T y = roi_y_min + out_height * bin_height +
              static_cast<T>(iy + .5f) * bin_height /
                  static_cast<T>(roi_bin_grid_height);
        T y_tmp = y;
        if (y_tmp < -1.0 || y_tmp > height) {
          is_ping_empty = true;
          continue;
        }
        if (y_tmp <= 0) {
          y_tmp = 0;
        }
        int y_low = 0, y_high = 0;
        y_low = int(y_tmp);
        if (y_low >= height - 1) {
          y_high = y_low = height - 1;
          y_tmp = T(y_low);
        } else {
          y_high = y_low + 1;
        }
        for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
          T x = roi_x_min + out_width * bin_width +
                static_cast<T>(ix + .5f) * bin_width /
                    static_cast<T>(roi_bin_grid_width);
          const int sample_index = iy * roi_bin_grid_width + ix;
          int c_rem = channels_num;
          bool is_empty = false;
          T w1, w2, w3, w4;
          int x_low = 0, x_high = 0;
          bilinearInterpolate(width, y_tmp, x, &w1, &w2, &w3, &w4, &x_low,
                              &x_high, y_low, &is_empty);
          if (is_empty) {
            is_ping_empty = true;
            continue;
          }
          __bang_mul_scalar((T *)nram_tmp1, (T *)nram_grad_output, w1,
                            channels_align);
          __bang_mul_scalar((T *)nram_tmp2, (T *)nram_grad_output, w2,
                            channels_align);
          __bang_mul_scalar((T *)nram_tmp3, (T *)nram_grad_output, w3,
                            channels_align);
          __bang_mul_scalar((T *)nram_tmp4, (T *)nram_grad_output, w4,
                            channels_align);
          __asm__ volatile("sync;");
          __bang_atomic_add(
              (T *)nram_tmp1,
              (T *)(offset_grad_input + (y_low * width + x_low) * channels +
                    channel_offset),
              (T *)nram_tmp1, channels_num);
          __bang_atomic_add(
              (T *)nram_tmp2,
              (T *)(offset_grad_input + (y_low * width + x_high) * channels +
                    channel_offset),
              (T *)nram_tmp2, channels_num);
          __bang_atomic_add(
              (T *)nram_tmp3,
              (T *)(offset_grad_input + (y_high * width + x_low) * channels +
                    channel_offset),
              (T *)nram_tmp3, channels_num);
          __bang_atomic_add(
              (T *)nram_tmp4,
              (T *)(offset_grad_input + (y_high * width + x_high) * channels +
                    channel_offset),
              (T *)nram_tmp4, channels_num);
          if (offset != NULL) {
            c_slice = nram_limit / FOURSPLIT / type_align * type_align;
            c_slice_align = 0;
            if (is_ping_empty) {
              c_slice = c_slice > c_rem ? c_rem : c_slice;
              c_slice_align = CEIL_ALIGN(c_slice, type_align);
              __bang_write_value(nram_ping_input, FOURSPLIT * c_slice_align,
                                 (T)0);
              __asm__ volatile("sync;");
              const T *src_offset1 = offset_input + y_low * width * channels +
                                     x_low * channels + channel_offset;
              const T *src_offset2 = offset_input + y_low * width * channels +
                                     x_high * channels + channel_offset;
              const T *src_offset3 = offset_input + y_high * width * channels +
                                     x_low * channels + channel_offset;
              const T *src_offset4 = offset_input + y_high * width * channels +
                                     x_high * channels + channel_offset;
              __memcpy(nram_ping_input, src_offset1, c_slice * sizeof(T),
                       GDRAM2NRAM);
              __memcpy(nram_ping_input + c_slice_align, src_offset2,
                       c_slice * sizeof(T), GDRAM2NRAM);
              __memcpy(nram_ping_input + 2 * c_slice_align, src_offset3,
                       c_slice * sizeof(T), GDRAM2NRAM);
              __memcpy(nram_ping_input + 3 * c_slice_align, src_offset4,
                       c_slice * sizeof(T), GDRAM2NRAM);
              is_ping_empty = false;
            }
            int c_offset = 0;
            int pongc_slice = 0;
            int pongc_slice_align = 0;
            while (c_rem > 0) {
              c_slice = c_slice > c_rem ? c_rem : c_slice;
              c_slice_align = CEIL_ALIGN(c_slice, type_align);
              if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
                int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
                int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
                T y_tmp = roi_y_min + out_height * bin_height +
                          static_cast<T>(iy_tmp + .5f) * bin_height /
                              static_cast<T>(roi_bin_grid_height);
                T x_tmp = roi_x_min + out_width * bin_width +
                          static_cast<T>(ix_tmp + .5f) * bin_width /
                              static_cast<T>(roi_bin_grid_width);
                int x_low_tmp = 0, x_high_tmp = 0, y_low_tmp = 0,
                    y_high_tmp = 0;
                if (y_tmp < -1.0 || y_tmp > height) {
                  is_empty = true;
                } else {
                  T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
                  if (y_tmp <= 0) {
                    y_tmp = 0;
                  }
                  y_low_tmp = int(y_tmp);
                  if (y_low_tmp >= height - 1) {
                    y_high_tmp = y_low_tmp = height - 1;
                    y_tmp = T(y_low_tmp);
                  } else {
                    y_high_tmp = y_low_tmp + 1;
                  }
                  bilinearInterpolate(width, y_tmp, x_tmp, &w1_tmp, &w2_tmp,
                                      &w3_tmp, &w4_tmp, &x_low_tmp,
                                      &x_high_tmp, y_low_tmp, &is_empty);
                }
                pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
                pongc_slice =
                    pongc_slice > channels_num ? channels_num : pongc_slice;
                pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
                __bang_write_value(nram_pong_input,
                                   FOURSPLIT * pongc_slice_align, (T)0);
                __asm__ volatile("sync;");
                if (!is_empty) {
                  const T *src_offset1 = offset_input +
                                         y_low_tmp * width * channels +
                                         x_low_tmp * channels + channel_offset;
                  const T *src_offset2 = offset_input +
                                         y_low_tmp * width * channels +
                                         x_high_tmp * channels + channel_offset;
                  const T *src_offset3 = offset_input +
                                         y_high_tmp * width * channels +
                                         x_low_tmp * channels + channel_offset;
                  const T *src_offset4 = offset_input +
                                         y_high_tmp * width * channels +
                                         x_high_tmp * channels + channel_offset;
                  __memcpy_async(nram_pong_input, src_offset1,
                                 pongc_slice * sizeof(T), GDRAM2NRAM);
                  __memcpy_async(nram_pong_input + pongc_slice_align,
                                 src_offset2, pongc_slice * sizeof(T),
                                 GDRAM2NRAM);
                  __memcpy_async(nram_pong_input + 2 * pongc_slice_align,
                                 src_offset3, pongc_slice * sizeof(T),
                                 GDRAM2NRAM);
                  __memcpy_async(nram_pong_input + 3 * pongc_slice_align,
                                 src_offset4, pongc_slice * sizeof(T),
                                 GDRAM2NRAM);
                }
              }

              __bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
                                y - y_low, c_slice_align);
              __bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
                                y_high - y, c_slice_align);
              __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
              __bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
                                y_low - y, c_slice_align);
              __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
              __bang_mul_scalar(nram_tmp2, nram_ping_input, y - y_high,
                                c_slice_align);
              __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
              __bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_width,
                                c_slice_align);
              __bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
              const int32_t kernel_width =
                  c_slice_align / nram_sum_tmp_channel +
                  (int32_t)(c_slice_align % nram_sum_tmp_channel > 0);
              __bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
                             kernel_width, 1, kernel_width, kernel_width, 1);
              __bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
                                nram_sum_tmp_channel);
              __bang_atomic_add(
                  (T *)nram_sum_tmp,
                  (T *)(grad_offset +
                        out_batch * pooled_width * pooled_height * 2 +
                        out_height * pooled_width + out_width),
                  (T *)nram_sum_tmp, 1);
              __bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
              __bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
                                x - x_low, c_slice_align);
              __bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
                                x_high - x, c_slice_align);
              __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
              __bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
                                x_low - x, c_slice_align);
              __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
              __bang_mul_scalar(nram_tmp2, nram_ping_input, x - x_high,
                                c_slice_align);
              __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
              __bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_height,
                                c_slice_align);
              __bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
              __bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
                             kernel_width, 1, kernel_width, kernel_width, 1);
              __bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
                                NFU_ALIGN_SIZE / sizeof(T));
              __bang_atomic_add(
                  (T *)nram_sum_tmp,
                  (T *)(grad_offset +
                        out_batch * pooled_width * pooled_height * 2 +
                        pooled_width * pooled_height +
                        out_height * pooled_width + out_width),
                  (T *)nram_sum_tmp, 1);

              T *nram_tmp = nram_ping_input;
              nram_ping_input = nram_pong_input;
              nram_pong_input = nram_tmp;
              c_rem -= c_slice;
              c_offset += c_slice;
              __asm__ volatile("sync;");
            }
          }
        }
      }
    }
  }
}

__mlu_global__ void MLUKernelDeformRoIPoolBackward(
    cnrtDataType_t data_type, const void *grad_output, const void *input,
    const void *rois, const void *offset, void *grad_input, void *grad_offset,
    const int channels, const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const float spatial_scale,
    const int sampling_ratio, const float gamma) {
  switch (data_type) {
    case CNRT_FLOAT16: {
      MLUUnion1DeformRoIPoolBackward(
          (half *)grad_output, (half *)input, (half *)rois, (half *)offset,
          (half *)grad_input, (half *)grad_offset, channels, height, width,
          num_rois, pooled_height, pooled_width,
          static_cast<half>(spatial_scale), sampling_ratio,
          static_cast<half>(gamma));
    }; break;
    case CNRT_FLOAT32: {
      MLUUnion1DeformRoIPoolBackward(
          (float *)grad_output, (float *)input, (float *)rois, (float *)offset,
          (float *)grad_input, (float *)grad_offset, channels, height, width,
          num_rois, pooled_height, pooled_width,
          static_cast<float>(spatial_scale), sampling_ratio,
          static_cast<float>(gamma));
    }; break;
    default: {
      break;
    }
  }
}

void KernelDeformRoIPoolBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    cnrtDataType_t data_type, const void *grad_output, const void *input,
    const void *rois, const void *offset, void *grad_input, void *grad_offset,
    const int channels, const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const float spatial_scale,
    const int sampling_ratio, const float gamma) {
  MLUKernelDeformRoIPoolBackward<<<k_dim, k_type, queue>>>(
      data_type, grad_output, input, rois, offset, grad_input, grad_offset,
      channels, height, width, num_rois, pooled_height, pooled_width,
      spatial_scale, sampling_ratio, gamma);
}
@@ -9,254 +9,59 @@
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                                cnrtQueue_t queue, cnrtDataType_t data_type,
                                const void *input, const void *rois,
                                const void *offset, void *output,
                                const int channels, const int height,
                                const int width, const int num_rois,
                                const int pooled_height, const int pooled_width,
                                const float spatial_scale,
                                const int sampling_ratio, const float gamma);

void KernelDeformRoIPoolBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    cnrtDataType_t data_type, const void *grad_output, const void *input,
    const void *rois, const void *offset, void *grad_input, void *grad_offset,
    const int channels, const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const float spatial_scale,
    const int sampling_ratio, const float gamma);

// policy function for forward and backward
static void policyFunc(const int bin_num, cnrtDim3_t *k_dim,
                       cnrtFunctionType_t *k_type) {
  const size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  ;
  const size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  const size_t bin_num_align = CEIL_ALIGN(bin_num, core_limit);
  k_dim->x = core_limit;
  k_dim->y = (bin_num_align / core_limit) > cluster_limit
                 ? cluster_limit
                 : (bin_num_align / core_limit);
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
}
#include "mlu_common_helper.h"

void DeformRoIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois,
                                           Tensor offset, Tensor output,
                                           int pooled_height, int pooled_width,
                                           float spatial_scale,
                                           int sampling_ratio, float gamma) {
  // Check dtype.
  TORCH_CHECK(
      input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
      "input type should be Float or Half, got ", input.scalar_type());
  TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
              "rois should have the same type as input");

  // Check shape.
  TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
              "D.");
  TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
              "D.");
  if (offset.defined() && offset.numel() > 0) {
    TORCH_CHECK(input.scalar_type() == offset.scalar_type(),
                "offset should have the same type as input");
    TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ",
                offset.dim(), "D.");
    TORCH_CHECK(
        (offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0),
        "while rois.size(0)) = ", rois.size(0), ". They should be the same.");
    TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ",
                "but now offset.size(1) = ", offset.size(1), ".");
    TORCH_CHECK((offset.size(2) == output.size(2)),
                "offset.size(2) = ", offset.size(2),
                "while output.size(2)) = ", output.size(2),
                ". They should be the same.");
    TORCH_CHECK((offset.size(3) == output.size(3)),
                "offset.size(3) = ", offset.size(3),
                "while output.size(3)) = ", output.size(3),
                ". They should be the same.");
  }

  TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
              "spatial_scale should be within (0, 1], got ", spatial_scale,
              ".");

  // compute kernel params
  auto height = input.size(2);
  auto width = input.size(3);
  auto channels = input.size(1);
  auto num_rois = output.size(0);

  if (output.numel() == 0) {
    output = at::zeros({num_rois, channels, pooled_height, pooled_width},
                       input.options());
    return;
  }

  // zero element check
  TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ",
              input.size(0));
  TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
              rois.numel());
  if (input.numel() == 0 || output.numel() == 0) {
    return;
  }

  // large tensor check
  const size_t max_input_num = 2147483648;  // 2^31, 2G num
  TORCH_CHECK(input.numel() < max_input_num,
              "input.numel() should be less than 2147483648, got ",
              input.numel());
  TORCH_CHECK(rois.numel() < max_input_num,
              "rois.numel() should be less than 2147483648, got ",
              rois.numel());
  TORCH_CHECK(output.numel() < max_input_num,
              "output.numel() should be less than 2147483648, got ",
              output.numel());
  TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num,
              "offset.numel() should be less than 2147483648, got ",
              offset.numel());

  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
  auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
  auto rois_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
  auto output_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format);

  at::Tensor output_ =
      at::empty({num_rois, channels, pooled_height, pooled_width},
                input.options(), memory_format);
  MluOpTensorDescriptor input_desc, rois_desc, offset_desc, output_desc;
  input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);
  rois_desc.set(rois_contiguous);
  output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);

  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type);

  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  mluOpTensorDescriptor_t offset_real_desc = NULL;
  void *offset_ptr = NULL;
  if (offset.defined() && offset.numel() > 0) {
    auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
        offset, offset.suggest_memory_format());
    offset_desc.set(offset_contiguous);
    offset_real_desc = offset_desc.desc();
    auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous);
    offset_ptr = offset_impl->cnnlMalloc();
  }

  // get ptr of tensors
  auto input_impl = torch_mlu::getMluTensorImpl(input_);
  auto input_ptr = input_impl->cnnlMalloc();
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto offset_impl = torch_mlu::getMluTensorImpl(offset);
  auto offset_ptr = offset_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output_);
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
  auto output_ptr = output_impl->cnnlMalloc();

  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input_.dtype());
  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  mluOpDeformRoiPoolForward(
      handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
      offset_real_desc, offset_ptr, pooled_height, pooled_width, spatial_scale,
      sampling_ratio, gamma, output_desc.desc(), output_ptr);

  // launch kernel
  CNLOG(INFO) << "Launch Kernel MLUKernelDeformRoIPoolForward<<<" << k_dim.x
              << ", " << k_dim.y << ", " << k_dim.z << ">>>";

  KernelDeformRoIPoolForward(k_dim, k_type, queue, data_type, input_ptr,
                             rois_ptr, offset_ptr, output_ptr, channels, height,
                             width, num_rois, pooled_height, pooled_width,
                             spatial_scale, sampling_ratio, gamma);

  output.copy_(output_);
  output.copy_(output_contiguous);
}

void DeformRoIPoolBackwardMLUKernelLauncher(
    Tensor grad_output, Tensor input, Tensor rois, Tensor offset,
    Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width,
    float spatial_scale, int sampling_ratio, float gamma) {
  // Check dtype.
  TORCH_CHECK(
      input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
      "input type should be Float or Half, got ", input.scalar_type());
  TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(),
              "grad_output should have the same type as input");
  TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
              "rois should have the same type as input");
  TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(),
              "grad_input should have the same type as input");

  // Check shape.
  TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ",
              grad_output.dim(), "D.");
  TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
              "D.");
  TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
              "D.");
  if (offset.defined() && offset.numel() > 0) {
    TORCH_CHECK(input.scalar_type() == offset.scalar_type(),
                "offset should have the same type as input");
    TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ",
                offset.dim(), "D.");
    TORCH_CHECK(
        (offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0),
        "while rois.size(0)) = ", rois.size(0), ". They should be the same.");
    TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ",
                "but now offset.size(1) = ", offset.size(1), ".");
    TORCH_CHECK((offset.size(2) == grad_output.size(2)),
                "offset.size(2) = ", offset.size(2),
                "while grad_output.size(2)) = ", grad_output.size(2),
                ". They should be the same.");
    TORCH_CHECK((offset.size(3) == grad_output.size(3)),
                "offset.size(3) = ", offset.size(3),
                "while grad_output.size(3)) = ", grad_output.size(3),
                ". They should be the same.");
  }

  TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
              "spatial_scale should be within (0, 1], got ", spatial_scale);

  // Check relationship between tensors.
  TORCH_CHECK((grad_output.size(0) == rois.size(0)),
              "grad_output.size(0) = ", grad_output.size(0),
              "while rois.size(0)) = ", rois.size(0),
              ". They should be the same.");
  TORCH_CHECK((grad_output.size(1) == input.size(1)),
              "grad_output.size(1) = ", grad_output.size(1),
              "while input.size(1)) = ", input.size(1),
              ". They should be the same.");
  TORCH_CHECK((grad_output.size(2) == pooled_height),
              "grad_output.size(2) = ", grad_output.size(2),
              "while pooled_height = ", pooled_height,
              ". They should be the same.");
  TORCH_CHECK((grad_output.size(3) == pooled_width),
              "grad_output.size(3) = ", grad_output.size(3),
              "while pooled_width = ", pooled_width,
              ". They should be the same.");

  // compute kernel params
  auto batch = input.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);
  auto num_rois = grad_output.size(0);

  // zero element check
  TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ",
              input.size(0));
  TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
              rois.numel());
  if (input.numel() == 0 || grad_output.numel() == 0) {
    return;
  }

  // large tensor check
  const size_t max_input_num = 2147483648;  // 2^31, 2G num
  TORCH_CHECK(input.numel() < max_input_num,
              "input.numel() should be less than 2147483648, got ",
              input.numel());
  TORCH_CHECK(rois.numel() < max_input_num,
              "rois.numel() should be less than 2147483648, got ",
              rois.numel());
  TORCH_CHECK(grad_output.numel() < max_input_num,
              "grad_output.numel() should be less than 2147483648, got ",
              grad_output.numel());
  TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num,
              "offset.numel() should be less than 2147483648, got ",
              offset.numel());

  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim());
  auto grad_output_ =
@@ -264,45 +69,56 @@ void DeformRoIPoolBackwardMLUKernelLauncher(
  memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
  auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
  at::Tensor grad_input_ = at::empty({batch, channels, height, width},
                                     input.options(), memory_format)
                               .zero_();

  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type);

  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  auto rois_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
  auto grad_input_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(grad_input, memory_format);

  // get ptr of tensors
  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_);
  auto grad_output_ptr = grad_output_impl->cnnlMalloc();
  auto input_impl = torch_mlu::getMluTensorImpl(input_);
  auto input_ptr = input_impl->cnnlMalloc();
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto offset_impl = torch_mlu::getMluTensorImpl(offset);
  auto offset_ptr = offset_impl->cnnlMalloc();
  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
  auto grad_input_ptr = grad_input_impl->cnnlMalloc();
  auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset);
  auto grad_offset_ptr = grad_offset_impl->cnnlMalloc();

  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype());

  // launch kernel
  CNLOG(INFO) << "Launch Kernel KernelDeformRoIPoolBackward<<<" << k_dim.x
              << ", " << k_dim.y << ", " << k_dim.z << ">>>";

  KernelDeformRoIPoolBackward(k_dim, k_type, queue, data_type, grad_output_ptr,
                              input_ptr, rois_ptr, offset_ptr, grad_input_ptr,
                              grad_offset_ptr, channels, height, width,
                              num_rois, pooled_height, pooled_width,
                              spatial_scale, sampling_ratio, gamma);
  MluOpTensorDescriptor grad_output_desc, input_desc, rois_desc, offset_desc,
      grad_input_desc, grad_offset_desc;
  grad_output_desc.set_with_layout(grad_output_, MLUOP_LAYOUT_NHWC);
  input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);
  rois_desc.set(rois_contiguous);
  grad_input_desc.set_with_layout(grad_input_, MLUOP_LAYOUT_NHWC);
  mluOpTensorDescriptor_t offset_real_desc = NULL;
  void *offset_ptr = NULL;
  if (offset.defined() && offset.numel() > 0) {
    auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
        offset, offset.suggest_memory_format());
    offset_desc.set(offset_contiguous);
    offset_real_desc = offset_desc.desc();
    auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous);
    offset_ptr = offset_impl->cnnlMalloc();
  }
  mluOpTensorDescriptor_t grad_offset_real_desc = NULL;
  void *grad_offset_ptr = NULL;
  if (grad_offset.defined() && grad_offset.numel() > 0) {
    auto grad_offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
        grad_offset, grad_offset.suggest_memory_format());
    grad_offset_desc.set(grad_offset_contiguous);
    grad_offset_real_desc = grad_offset_desc.desc();
    auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset_contiguous);
    grad_offset_ptr = grad_offset_impl->cnnlMalloc();
  }

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  mluOpDeformRoiPoolBackward(
      handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(),
      input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr,
      pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma,
      grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc,
      grad_offset_ptr);
  grad_input.copy_(grad_input_);
}
@@ -72,6 +72,39 @@ void MluOpTensorDescriptor::set(Tensor t) {
  set_desc(t, layout, data_type, dim_array);
}

void MluOpTensorDescriptor::set_with_layout(Tensor t,
                                            mluOpTensorLayout_t layout) {
  mluOpDataType_t data_type = getMluOpDataType(t.dtype());
  int t_dim = t.dim();
  std::vector<int> shape_info = checkUpperBoundAndCastTo<int>(t.sizes().vec());
  std::vector<int> stride_info =
      checkUpperBoundAndCastTo<int>(t.strides().vec());
  if (layout == MLUOP_LAYOUT_NHWC || layout == MLUOP_LAYOUT_NDHWC ||
      layout == MLUOP_LAYOUT_NLC) {
    convertShapeAndStride(shape_info, stride_info);
  } else if (layout == MLUOP_LAYOUT_HWCN) {
    auto convertDepthWiseConvShapeStride = [](const std::vector<int64_t>& vec,
                                              std::vector<int>& target_vec,
                                              std::vector<int>& stride_vec) {
      // NCHW --> HWCN
      target_vec[0] = static_cast<int>(vec[2]);
      target_vec[1] = static_cast<int>(vec[3]);
      target_vec[2] = static_cast<int>(vec[1]);
      target_vec[3] = static_cast<int>(vec[0]);
      // Calculate Stride just like contiguous of HWCN.
      stride_vec[3] = 1;
      stride_vec[2] = target_vec[3] * stride_vec[3];
      stride_vec[1] = target_vec[2] * stride_vec[2];
      stride_vec[0] = target_vec[1] * stride_vec[1];
    };
    convertDepthWiseConvShapeStride(t.sizes().vec(), shape_info, stride_info);
  }
  TORCH_CHECK(mluOpSetTensorDescriptorEx(
                  desc_, layout, data_type, t_dim, shape_info.data(),
                  stride_info.data()) == MLUOP_STATUS_SUCCESS,
              "mluOpSetTensorDescriptorEx execution failed.");
}

void MluOpTensorDescriptor::set_desc(const at::Tensor& t,
                                     mluOpTensorLayout_t layout,
                                     mluOpDataType_t dtype,
@@ -30,6 +30,7 @@ class MluOpTensorDescriptor {
  ~MluOpTensorDescriptor() { mluOpDestroyTensorDescriptor(desc_); }

  void set(at::Tensor);
  void set_with_layout(at::Tensor, mluOpTensorLayout_t layout);
  mluOpTensorDescriptor_t desc() { return desc_; }

 private:
@@ -52,3 +53,47 @@ class MluOpHandle {
  void setQueue(cnrtQueue_t queue) { mluOpSetQueue(handle, queue); }
  mluOpHandle_t handle;
};

// Reorder tensor sizes and strides from channels_first order to
// channels_last or channels_last_3d order. This differs from the original
// PyTorch layout: the resulting order reflects how the data is actually
// stored in memory.
// Example: convert a channels_last tensor's dims to an NHWC tensor desc.
// N C H W      -->  N H W C
// C*H*W 1 W C  -->  C*H*W W C 1
template <typename T>
void convertShapeAndStride(std::vector<T>& shape_info,
                           std::vector<T>& stride_info) {
  TORCH_MLU_CHECK(shape_info.size() == stride_info.size(),
                  "shape size need equal to stride size.");
  const int dim = shape_info.size();
  std::vector<T> temp_shape_info(dim);
  std::vector<T> temp_stride_info(dim);
  temp_shape_info[0] = shape_info[0];
  temp_stride_info[0] = stride_info[0];
  for (size_t i = 0; i < dim - 1; ++i) {
    const int index = (i + 1) % (dim - 1) + 1;
    temp_shape_info[i + 1] = shape_info[index];
    temp_stride_info[i + 1] = stride_info[index];
  }
  shape_info.assign(temp_shape_info.begin(), temp_shape_info.end());
  stride_info.assign(temp_stride_info.begin(), temp_stride_info.end());
}

// Torch tensors provide int64_t shapes and strides, but the mlu-ops
// descriptor requires int32. Use this function to cast safely, or report an
// error if a value is out of range.
template <typename DST_T, typename SRC_T>
std::vector<DST_T> checkUpperBoundAndCastTo(const std::vector<SRC_T>& input) {
  std::vector<DST_T> output;
  output.reserve(input.size());
  for (const auto& val : input) {
    if (val > std::numeric_limits<DST_T>::max()) {
      TORCH_MLU_CHECK(false, "Requires dim size not greater than ",
                      std::numeric_limits<DST_T>::max(), ". But got ", val,
                      ".");
    }
    output.push_back(static_cast<DST_T>(val));
  }
  return output;
}
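For instance, under the rotation implemented by convertShapeAndStride above, a logical NCHW shape stored channels_last is rewritten into the order the data actually occupies in memory. A minimal usage sketch with hypothetical values (the concrete numbers are illustrative, not taken from the diff):

  // Logical NCHW shape (2, 3, 4, 5) stored channels_last:
  std::vector<int> shape{2, 3, 4, 5};
  std::vector<int> stride{60, 1, 15, 3};
  convertShapeAndStride(shape, stride);
  // shape  is now {2, 4, 5, 3}   (NHWC)
  // stride is now {60, 15, 3, 1} (contiguous in storage order)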