[Feature] Support RoiPool with Cambricon MLU backend (#2073)

* [Feature] Support RoiPool with cambricon MLU backend

* [Docs] Update ops.md
pull/2084/head
zhouchenyang 2022-06-29 11:24:00 +08:00 committed by GitHub
parent d71d067da1
commit 2708fac6c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1208 additions and 14 deletions

View File

@ -41,7 +41,7 @@ We implement common ops used in detection, segmentation, etc.
| PSAMask | √ | √ | √ |
| RotatedFeatureAlign | √ | √ | |
| RoIPointPool3d | | √ | |
| RoIPool | | √ | |
| RoIPool | | √ | |
| RoIAlignRotated | √ | √ | √ |
| RiRoIAlignRotated | | √ | |
| RoIAlign | √ | √ | √ |

View File

@ -41,7 +41,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| PSAMask | √ | √ | √ |
| RotatedFeatureAlign | √ | √ | |
| RoIPointPool3d | | √ | |
| RoIPool | | √ | |
| RoIPool | | √ | |
| RoIAlignRotated | √ | √ | √ |
| RiRoIAlignRotated | | √ | |
| RoIAlign | √ | √ | √ |

View File

@ -9,8 +9,8 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef UTILS_H_
#define UTILS_H_
#ifndef COMMON_MLU_HELPER_HPP_
#define COMMON_MLU_HELPER_HPP_
#define NFU_ALIGN_SIZE 128 // Byte
#define REM_FOR_STACK (128 * 1024) // 128KB reserved for cncc
@ -35,4 +35,156 @@
#define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y))
#endif // UTILS_H_
/*!
* @brief Converts int32 to float32 data type.
*
* @param[out] dst
* Pointer to NRAM that stores int32 type data.
* @param[in,out] dst_addition
* Pointer to NRAM as the workspace of dst, which has the same size as dst.
* It allows empty pointer on MLU300 series.
* @param[in] src
* Pointer to NRAM that stores float32 type data.
* @param[in,out] src_addition
* Pointer to NRAM as the workspace of src, which has a size of 128 Bytes.
* It allows empty pointer on MLU300 series.
* @param[in] src_count
* The count of elements in src.
*/
__mlu_func__ void convertInt2Float(float *dst, float *dst_addition, int *src,
float *src_addition, const int src_count) {
#if __BANG_ARCH__ >= 300
__bang_int2float((float *)dst, (int32_t *)src, src_count, 0);
#else
// get sign bit
const float move_23bit = 8388608.0;
// 0x80000000 = 1,000000000,0000000000000000000000000000
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
0x80000000);
__bang_cycle_band((char *)dst_addition, (char *)src, (char *)src_addition,
src_count * sizeof(float), NFU_ALIGN_SIZE);
// get 1 or 0 from sign bit
// judg is Odd
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
0x00000001);
__bang_cycle_bor((char *)dst_addition, (char *)dst_addition,
(char *)src_addition, src_count * sizeof(float),
NFU_ALIGN_SIZE);
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
0x80000001);
__bang_cycle_eq(dst_addition, dst_addition, src_addition, src_count,
NFU_ALIGN_SIZE / sizeof(float));
// minus xor, positive num invariant
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
0xffffffff);
__bang_cycle_mul(dst, dst_addition, src_addition, src_count,
NFU_ALIGN_SIZE / sizeof(float));
__bang_bxor((char *)dst, (char *)src, (char *)dst, src_count * sizeof(float));
// convert int32 to float32
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float), 0x7fffff);
__bang_cycle_band((char *)dst, (char *)dst, (char *)src_addition,
src_count * sizeof(float), NFU_ALIGN_SIZE);
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
0x4b000000);
__bang_cycle_bor((char *)dst, (char *)dst, (char *)src_addition,
src_count * sizeof(float), NFU_ALIGN_SIZE);
__bang_sub_const(dst, dst, move_23bit, src_count);
// add one
__bang_add(dst, dst, dst_addition, src_count);
// set sign for float32
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
0xffffffff);
__bang_cycle_mul(dst_addition, dst_addition, src_addition, src_count,
NFU_ALIGN_SIZE / sizeof(float));
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
0x00000001);
__bang_cycle_add(dst_addition, dst_addition, src_addition, src_count,
NFU_ALIGN_SIZE / sizeof(float));
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
0x80000000);
__bang_cycle_band((char *)dst_addition, (char *)dst_addition,
(char *)src_addition, src_count * 4, 128);
__bang_bor((char *)dst, (char *)dst, (char *)dst_addition, src_count * 4);
#endif // __BANG_ARCH__ >= 300
}
/*!
* @brief Converts float32 to int32 data type with to_zero round mode.
*
* @param[out] dst
* Pointer to NRAM that stores float32 type data.
* @param[in,out] dst_addition
* Pointer to NRAM as the workspace of dst, which has the same size as dst.
* It allows empty pointer on MLU300 series.
* @param[in] src
* Pointer to NRAM that stores int32 type data.
* @param[in,out] src_addition
* Pointer to NRAM as the workspace of src, which has a size of 128 Bytes.
* It allows empty pointer on MLU300 series.
* @param[in] src_count
* The count of elements in src.
*/
__mlu_func__ void convertFloat2Int(int *dst, float *dst_addition, float *src,
float *src_addition, const int src_count) {
#if __BANG_ARCH__ >= 300
__bang_float2int_tz((int32_t *)dst, (float *)src, src_count, 0);
#else
// sign ===> src_addition
// dst=-1.0 : when src[i] is a negative number
// dst=+1.0 : when src[i] is a positive number
const int floatDchar = sizeof(float) / sizeof(char);
__bang_active_sign((float *)dst, src, src_count);
// dst_addition = abs(src)
__bang_mul(dst_addition, src, (float *)dst, src_count);
// if dst_addition < 1.0 , then src_addition + 1, to fix add error.
__nramset((float *)src_addition, NFU_ALIGN_SIZE / sizeof(float), 1.0f);
__bang_cycle_lt(dst_addition, dst_addition, (float *)src_addition, src_count,
NFU_ALIGN_SIZE / sizeof(float));
__bang_add_tz((float *)dst, (float *)dst, (float *)dst_addition, src_count);
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
0xbf800000);
// set negative flag -1.0 = 0xbf80000
__bang_cycle_eq(
(float *)dst, (float *)dst, (float *)src_addition, src_count,
NFU_ALIGN_SIZE / sizeof(float)); // to mark all src in [x<-1.0]
__bang_active_abs(dst_addition, src, src_count);
__nramset((float *)src_addition, NFU_ALIGN_SIZE / sizeof(float), 8388608.0f);
// mask shift move 23
__bang_cycle_add_tz(
dst_addition, dst_addition, src_addition, src_count,
NFU_ALIGN_SIZE / sizeof(float)); // right shift move 23bit
// two`s complement for negatibe
// dst=1.0 , when src <-1.0
// dst=0.0 , when src >=-1.0
__bang_sub(dst_addition, dst_addition, (float *)dst, src_count);
// to fix max value
// 0 1001 0110 111 1111 1111 1111 1111 1111 <=> 0xcb7fffff <=> 16777215.0,
// means max value.
__bang_mul_const((float *)dst, (float *)dst, 16777215.0, src_count);
__bang_bxor((char *)dst_addition, (char *)dst_addition, (char *)dst,
src_count * floatDchar);
// get low 23bit
__nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
(unsigned)0x007fffff);
// mask low 23bit is 1
__bang_cycle_band((char *)dst_addition, (char *)dst_addition,
(char *)src_addition, src_count * floatDchar,
NFU_ALIGN_SIZE / sizeof(char));
// set 9 high bit ===> dst
// -2.0 <=> 0xc0000000 <=> 1100 0000 0000 0000 0000 0000 0000 0000
// 1.0 <=> 0x3f800000 <=> 0011 1111 1000 0000 0000 0000 0000 0000
__nramset(src_addition, NFU_ALIGN_SIZE / sizeof(float), 0x3f800000);
__bang_cycle_and((float *)dst, (float *)dst, src_addition, src_count,
NFU_ALIGN_SIZE / sizeof(float));
// src or dst_addition
__bang_bor((char *)dst_addition, (char *)dst, (char *)dst_addition,
src_count * floatDchar);
__bang_mul_const((float *)dst, (float *)dst, -2.0, src_count);
__bang_bor((char *)dst, (char *)dst, (char *)dst_addition,
src_count * floatDchar);
#endif // __BANG_ARCH__ >= 300
}
#endif // COMMON_MLU_HELPER_HPP_

View File

@ -0,0 +1,749 @@
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#define ALIGN_SIZE 64
#define PIPELINE_COMMON_NUM 2
#define PIPELINE_PINGPONG_NUM 10
__nram__ char nram_buffer[MAX_NRAM_SIZE];
namespace forward {
template <typename T>
__mlu_func__ void getRoiBinInfo(T *input_v, T *rois_v, int bin_i, int height,
int width, int channels, int p_height,
int p_width, T spatial_scale, int *bin_x1,
int *bin_y1, int *bin_x2, int *bin_y2,
int *bin_wdim, int *bin_hdim, int *bin_dims,
T **input_base, bool *is_empty) {
int pw = bin_i % p_width;
int ph = (bin_i / p_width) % p_height;
int roi_n = bin_i / p_width / p_height;
/*roi*/
const T *roi_info = rois_v + roi_n * 5; // {{batch, x1, y1, x2, y2},,,}
int batch_index = (int)roi_info[0];
int roi_x1 = round(roi_info[1] * spatial_scale);
int roi_y1 = round(roi_info[2] * spatial_scale);
int roi_x2 = round(roi_info[3] * spatial_scale);
int roi_y2 = round(roi_info[4] * spatial_scale);
int roi_w = roi_x2 - roi_x1 + 1 > 1 ? roi_x2 - roi_x1 + 1 : 1;
int roi_h = roi_y2 - roi_y1 + 1 > 1 ? roi_y2 - roi_y1 + 1 : 1;
/*bin*/
T bin_w = (T)roi_w / (T)p_width;
T bin_h = (T)roi_h / (T)p_height;
*bin_x1 = (int)floor((T)pw * bin_w) + roi_x1;
*bin_x1 = *bin_x1 > 0 ? *bin_x1 : 0;
*bin_x1 = *bin_x1 < width ? *bin_x1 : width;
*bin_y1 = (int)floor((T)ph * bin_h) + roi_y1;
*bin_y1 = *bin_y1 > 0 ? *bin_y1 : 0;
*bin_y1 = *bin_y1 < height ? *bin_y1 : height;
*bin_x2 = (int)ceil((T)(pw + 1) * bin_w) + roi_x1;
*bin_x2 = *bin_x2 > 0 ? *bin_x2 : 0;
*bin_x2 = *bin_x2 < width ? *bin_x2 : width;
*bin_y2 = (int)ceil((T)(ph + 1) * bin_h) + roi_y1;
*bin_y2 = *bin_y2 > 0 ? *bin_y2 : 0;
*bin_y2 = *bin_y2 < height ? *bin_y2 : height;
*input_base = input_v + batch_index * height * width * channels;
*bin_wdim = *bin_x2 - *bin_x1;
*bin_hdim = *bin_y2 - *bin_y1;
*bin_dims = (*bin_hdim) * (*bin_wdim);
*is_empty = (*bin_y2 <= *bin_y1) || (*bin_x2 <= *bin_x1);
}
template <typename T>
__mlu_func__ void MLUUnion1Roipool(T *input_v, T *rois_v, int batch,
int channels, int height, int width,
int p_height, int p_width, int rois_num,
T spatial_scale, T *output_v, int *argmax) {
/*
* NRAM partition
* |---------------------------------------------------|
* | ping |
* |---------------------------------------------------|
* | pong |
* |---------------------------------------------------|
* | out |
* |---------------------------------------------------|
* | argmax |
* |---------------------------------------------------|
* | a |
* |---------------------------------------------------|
* | b |
* |---------------------------------------------------|
*/
uint32_t is_half = sizeof(T) == sizeof(half) ? true : false;
uint32_t t_size = sizeof(T);
uint32_t float_div = NFU_ALIGN_SIZE / sizeof(float);
uint32_t half_div = NFU_ALIGN_SIZE / sizeof(half);
uint32_t channels_align = PAD_UP(channels, float_div);
uint32_t nram_limit = PAD_DOWN(
(MAX_NRAM_SIZE / sizeof(float) - 4 * channels_align) / 2, half_div);
// nram PING/PONG, output, argamx, a, b
float *nram_ping = (float *)nram_buffer;
float *nram_pong = (float *)nram_buffer + nram_limit;
float *nram_out = (float *)nram_buffer + 2 * nram_limit;
float *nram_argmax = nram_out + channels_align;
float *nram_a = nram_out + 2 * channels_align;
float *nram_b = nram_out + 3 * channels_align;
uint32_t c_bins_num = rois_num * p_height * p_width;
uint32_t task_bins = c_bins_num / taskDim;
uint32_t rem_bins = c_bins_num % taskDim;
if (taskId < rem_bins) {
task_bins += 1;
}
int bin_first =
(c_bins_num / taskDim) * taskId + (taskId > rem_bins ? rem_bins : taskId);
int bins_loop = bin_first + task_bins;
T *input_base = NULL;
T *output_base = output_v + bin_first * channels;
int *argmax_base = NULL != argmax ? argmax + bin_first * channels : NULL;
int bin_x1, bin_y1, bin_x2, bin_y2, bin_wdim, bin_hdim, bin_dims;
int pbin_x1, pbin_y1, pbin_x2, pbin_y2, pbin_wdim, pbin_hdim, pbin_dims;
bool is_empty = false;
bool pong_is_empty = false;
bool is_first_bin = true;
uint32_t src_offset = 0;
uint32_t dst_offset = 0;
uint32_t nram_offset = 0;
uint32_t half_offset =
is_half ? (nram_limit / 2 / half_div * half_div) * 2 : 0;
float *nram_tmp = NULL;
uint32_t c_slice = 0;
uint32_t c_slice_align = 0;
uint32_t pongc_slice = 0;
uint32_t pongc_slice_align = 0;
for (int bin_i = bin_first; bin_i < bins_loop; bin_i++) {
getRoiBinInfo((T *)input_v, (T *)rois_v, bin_i, height, width, channels,
p_height, p_width, (T)spatial_scale, &bin_x1, &bin_y1,
&bin_x2, &bin_y2, &bin_wdim, &bin_hdim, &bin_dims,
&input_base, &is_empty);
uint32_t c_rem = channels;
c_slice = nram_limit / bin_dims / float_div * float_div;
if (is_first_bin && !is_empty) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = PAD_UP(c_slice, float_div);
for (int h = bin_y1; h < bin_y2; h++) {
src_offset = (h * width + bin_x1) * channels;
nram_offset = (h - bin_y1) * bin_wdim * c_slice_align + half_offset;
if (c_slice_align == channels) {
__memcpy((T *)nram_ping + nram_offset, (T *)input_base + src_offset,
bin_wdim * c_slice * t_size, GDRAM2NRAM);
} else {
__memcpy((T *)nram_ping + nram_offset, (T *)input_base + src_offset,
c_slice * t_size, GDRAM2NRAM, c_slice_align * t_size,
channels * t_size, bin_wdim - 1);
}
}
}
uint32_t c_offset = 0;
while (c_rem > 0) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = PAD_UP(c_slice, float_div);
/*__memcpy_async*/
if (c_rem - c_slice > 0 && !is_empty) {
pongc_slice = c_rem - c_slice > c_slice ? c_slice : c_rem - c_slice;
pongc_slice_align = PAD_UP(pongc_slice, float_div);
for (int h = bin_y1; h < bin_y2; h++) {
src_offset = (h * width + bin_x1) * channels + c_offset;
nram_offset =
(h - bin_y1) * bin_wdim * pongc_slice_align + half_offset;
__memcpy_async((T *)nram_pong + nram_offset,
(T *)input_base + src_offset + c_slice,
pongc_slice * t_size, GDRAM2NRAM,
pongc_slice_align * t_size, channels * t_size,
bin_wdim - 1);
}
} else if (bin_i + 1 < bins_loop) {
getRoiBinInfo((T *)input_v, (T *)rois_v, bin_i + 1, height, width,
channels, p_height, p_width, (T)spatial_scale, &pbin_x1,
&pbin_y1, &pbin_x2, &pbin_y2, &pbin_wdim, &pbin_hdim,
&pbin_dims, &input_base, &pong_is_empty);
pongc_slice = PAD_DOWN(nram_limit / pbin_dims, float_div);
pongc_slice = pongc_slice > channels ? channels : pongc_slice;
pongc_slice_align = PAD_UP(pongc_slice, float_div);
if (!pong_is_empty) {
for (int h = pbin_y1; h < pbin_y2; h++) {
src_offset = (h * width + pbin_x1) * channels;
nram_offset =
(h - pbin_y1) * pbin_wdim * pongc_slice_align + half_offset;
if (pongc_slice_align == channels) {
__memcpy_async((T *)nram_pong + nram_offset,
(T *)input_base + src_offset,
pbin_wdim * pongc_slice * t_size, GDRAM2NRAM);
} else {
__memcpy_async((T *)nram_pong + nram_offset,
(T *)input_base + src_offset, pongc_slice * t_size,
GDRAM2NRAM, pongc_slice_align * t_size,
channels * t_size, pbin_wdim - 1);
}
}
}
}
if (is_empty) {
__nramset((T *)nram_out, c_slice_align, (T)0);
__memcpy((T *)output_base + dst_offset + c_offset, (T *)nram_out,
c_slice * t_size, NRAM2GDRAM);
if (NULL != argmax) {
__nramset((int32_t *)nram_out, c_slice_align, (int32_t)(-1));
__memcpy((int32_t *)argmax_base + dst_offset + c_offset,
(int32_t *)nram_out, c_slice * sizeof(int32_t), NRAM2GDRAM);
}
} else {
if (is_half) {
uint32_t bin_align64 = PAD_UP(bin_dims * c_slice_align, half_div);
__bang_half2float((float *)nram_ping, (half *)nram_ping + half_offset,
bin_align64);
}
__bang_maxpool((float *)nram_out, (float *)nram_ping, c_slice_align,
bin_hdim, bin_wdim, bin_hdim, bin_wdim, 1, 1);
if (is_half) {
uint32_t c_align64 = PAD_UP(c_slice_align, half_div);
__bang_float2half_rd((half *)nram_out, (float *)nram_out, c_align64);
}
__memcpy((T *)output_base + dst_offset + c_offset, (T *)nram_out,
c_slice * t_size, NRAM2GDRAM);
if (NULL != argmax) {
/*compute max_index*/
__bang_maxpool_index((uint32_t *)nram_out, (float *)nram_ping,
c_slice_align, bin_hdim, bin_wdim, bin_hdim,
bin_wdim, 1, 1);
convertInt2Float((float *)nram_argmax, (float *)nram_a,
(int32_t *)nram_out, (float *)nram_b, c_slice_align);
/*compute input_h*/
for (int i = 0; i < c_slice; i++) {
nram_out[i] = (float)(((uint32_t *)nram_out)[i] / bin_wdim);
}
__bang_add_const((float *)nram_a, (float *)nram_out, (float)bin_y1,
c_slice_align);
__bang_mul_const((float *)nram_ping, (float *)nram_a, (float)width,
c_slice_align);
/*compute input_w*/
__bang_mul_const((float *)nram_a, (float *)nram_out, (float)bin_wdim,
c_slice_align);
__bang_sub((float *)nram_a, (float *)nram_argmax, (float *)nram_a,
c_slice_align);
__bang_add_const((float *)nram_a, (float *)nram_a, (float)bin_x1,
c_slice_align);
__bang_add((float *)nram_out, (float *)nram_ping, (float *)nram_a,
c_slice_align);
convertFloat2Int((int32_t *)nram_argmax, (float *)nram_a,
(float *)nram_out, (float *)nram_b, c_slice_align);
__memcpy((int32_t *)argmax_base + dst_offset + c_offset,
(int32_t *)nram_argmax, c_slice * sizeof(int32_t),
NRAM2GDRAM);
}
}
nram_tmp = nram_ping;
nram_ping = nram_pong;
nram_pong = nram_tmp;
c_offset += c_slice;
c_rem -= c_slice;
__asm__ volatile("sync;");
}
dst_offset += channels;
is_first_bin = false;
}
}
__mlu_global__ void MLUKernelRoiPool(cnrtDataType_t data_type,
const void *input_data,
const void *input_rois, int batch,
int channels, int height, int width,
int pooled_height, int pooled_width,
int rois_num, float spatial_scale,
void *output_data, int *argmax) {
switch (data_type) {
case CNRT_FLOAT16: {
MLUUnion1Roipool((half *)input_data, (half *)input_rois, batch, channels,
height, width, pooled_height, pooled_width, rois_num,
(half)spatial_scale, (half *)output_data, argmax);
}; break;
case CNRT_FLOAT32: {
MLUUnion1Roipool((float *)input_data, (float *)input_rois, batch,
channels, height, width, pooled_height, pooled_width,
rois_num, (float)spatial_scale, (float *)output_data,
argmax);
}; break;
default: {
break;
}
}
}
} // namespace forward
namespace backward {
// Convert index of argmax from global grads_image to local bin in RoI. Vector
// operations do not support int type, so conversion from int to float is
// performed here.
__mlu_func__ void convertIndex(
int32_t *nram_argmax, int32_t *nram_argmax_fp, int32_t *nram_argmax_fp_bk1,
int32_t *nram_argmax_fp_bk2, int32_t *nram_argmax_int,
int32_t *nram_argmax_int_h, int32_t *nram_argmax_int_w,
int32_t *nram_argmax_fp_h, int32_t *nram_argmax_fp_w,
float *nram_atomic_add, float *nram_grads_image, int width, int height,
int wstart, int hstart, int w_compute, int h_compute, int align_c,
int channels, int loop_flag, int loop_id, int true_limit) {
convertInt2Float((float *)nram_argmax_fp, (float *)nram_argmax_fp_bk1,
(int *)nram_argmax, (float *)nram_argmax_fp_bk2, align_c);
// This step uses scalar division, because the above vector division causes
// rounding accuracy problem.
for (int i = 0; i < channels; ++i) {
*((float *)nram_argmax_fp + i) = *((float *)nram_argmax_fp + i) / width;
}
// Use 'float2int_tz' to perform '*((int32_t*)nram_argmax + i) / width'
// operation.
convertFloat2Int((int *)nram_argmax_int_h, (float *)nram_argmax_fp_bk1,
(float *)nram_argmax_fp, (float *)nram_argmax_fp_bk2,
align_c);
convertInt2Float((float *)nram_argmax_fp, (float *)nram_argmax_fp_bk1,
(int *)nram_argmax_int_h, (float *)nram_argmax_fp_bk2,
align_c);
// Perform 'temp_result - hstart' operation
__bang_sub_const((float *)nram_argmax_fp_h, (float *)nram_argmax_fp, hstart,
align_c);
// Perform 'temp_result1 - temp_result2 * width' operation
__bang_mul_const((float *)nram_argmax_fp_w, (float *)nram_argmax_fp, width,
align_c);
convertInt2Float((float *)nram_argmax_fp, (float *)nram_argmax_fp_bk1,
(int *)nram_argmax, (float *)nram_argmax_fp_bk2, align_c);
__bang_sub((float *)nram_argmax_fp_w, (float *)nram_argmax_fp,
(float *)nram_argmax_fp_w, align_c);
// Perform 'temp_result - wstart' operation
__bang_sub_const((float *)nram_argmax_fp_w, (float *)nram_argmax_fp_w, wstart,
align_c);
// Perform 'temp_result = h * w_compute + w' operation
__bang_mul_const((float *)nram_argmax_fp_h, (float *)nram_argmax_fp_h,
w_compute, align_c);
__bang_add((float *)nram_argmax_fp_h, (float *)nram_argmax_fp_h,
(float *)nram_argmax_fp_w, align_c);
if (loop_flag == 1) {
__bang_sub_const((float *)nram_argmax_fp_h, (float *)nram_argmax_fp_h,
(loop_id * true_limit), align_c);
}
convertFloat2Int((int *)nram_argmax_int, (float *)nram_argmax_fp_bk1,
(float *)nram_argmax_fp_h, (float *)nram_argmax_fp_bk2,
align_c);
}
template <typename T>
__mlu_func__ void MLUUnion1Roipool(const T *rois, const T *grads,
const int32_t *argmax, T *grads_image,
int channels, int height, int width,
int pooled_height, int pooled_width,
int rois_num, const T spatial_scale,
int high_precision) {
// Calculate the number of rois processed by each core
int bin_num = rois_num * pooled_height * pooled_width;
int loop =
(bin_num % taskDim) ? (bin_num / taskDim + 1) : (bin_num / taskDim);
int tid = taskId * loop;
if (bin_num % taskDim != 0) {
if (tid >= bin_num) {
return;
} else {
// last part is (bin_num - tid).
loop = bin_num - tid < loop ? bin_num - tid : loop;
}
}
int align_c = PAD_UP(channels, ALIGN_SIZE);
// Common part has 2: grads, argmax; ping-pong each is PIPELINE_PINGPONG_NUM.
int data_size =
PAD_DOWN(((MAX_NRAM_SIZE / sizeof(float) - PIPELINE_COMMON_NUM * align_c -
(PIPELINE_PINGPONG_NUM - 1) * align_c * 2) /
2),
ALIGN_SIZE);
int hw_limit = data_size / align_c;
float *nram_grads = (float *)nram_buffer;
for (int idx = tid; idx < tid + loop; ++idx) {
// (n, ph, pw) is a C in the pooled output
int pw = idx % pooled_width;
int ph = (idx / pooled_width) % pooled_height;
int n = idx / pooled_width / pooled_height;
const T *offset_rois = (const T *)(rois + n * 5);
int roi_batch_ind = int(offset_rois[0]);
// Calculate the roi region on feature maps
int roi_start_w = round(offset_rois[1] * spatial_scale);
int roi_start_h = round(offset_rois[2] * spatial_scale);
int roi_end_w = round(offset_rois[3] * spatial_scale);
int roi_end_h = round(offset_rois[4] * spatial_scale);
// Force malformed rois to 1x1
int roi_width =
roi_end_w - roi_start_w + 1 > 1 ? roi_end_w - roi_start_w + 1 : 1;
int roi_height =
roi_end_h - roi_start_h + 1 > 1 ? roi_end_h - roi_start_h + 1 : 1;
T bin_size_h = (T)roi_height / (T)pooled_height;
T bin_size_w = (T)roi_width / (T)pooled_width;
// The corresponding bin region
int hstart = int(floor((T)ph * bin_size_h));
int wstart = int(floor((T)pw * bin_size_w));
int hend = int(ceil((T)(ph + 1) * bin_size_h));
int wend = int(ceil((T)(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries, min(max(A, B), C);
hstart = hstart + roi_start_h > 0 ? hstart + roi_start_h : 0;
hstart = hstart < height ? hstart : height;
hend = hend + roi_start_h > 0 ? hend + roi_start_h : 0;
hend = hend < height ? hend : height;
wstart = wstart + roi_start_w > 0 ? wstart + roi_start_w : 0;
wstart = wstart < width ? wstart : width;
wend = wend + roi_start_w > 0 ? wend + roi_start_w : 0;
wend = wend < width ? wend : width;
bool is_empty = (hend <= hstart) || (wend <= wstart);
if (!is_empty) {
int h_compute = hend - hstart;
int w_compute = wend - wstart;
int true_limit =
hw_limit < h_compute * w_compute ? hw_limit : h_compute * w_compute;
int loop_int = (h_compute * w_compute) / true_limit;
int rem = (h_compute * w_compute) % true_limit;
int32_t *nram_argmax = (int32_t *)nram_grads + align_c;
int32_t *nram_argmax_fp = (int32_t *)nram_argmax + align_c;
int32_t *nram_argmax_fp_bk1 = (int32_t *)nram_argmax_fp + align_c;
int32_t *nram_argmax_fp_bk2 = (int32_t *)nram_argmax_fp_bk1 + align_c;
int32_t *nram_argmax_int = (int32_t *)nram_argmax_fp_bk2 + align_c;
int32_t *nram_argmax_int_h = (int32_t *)nram_argmax_int + align_c;
int32_t *nram_argmax_int_w = (int32_t *)nram_argmax_int_h + align_c;
int32_t *nram_argmax_fp_h = (int32_t *)nram_argmax_int_w + align_c;
int32_t *nram_argmax_fp_w = (int32_t *)nram_argmax_fp_h + align_c;
float *nram_atomic_add = (float *)nram_argmax_fp_w + align_c;
float *nram_grads_image = (float *)nram_atomic_add + align_c;
if (true_limit == h_compute * w_compute) {
/*
* NRAM partition
* |---------------------------------------------------|
* | grads |
* |---------------------------------------------------|
* | argmax |
* |---------------------------------------------------|
* | argmax_temp |
* |---------------------------------------------------|
* | atomic_add |
* |---------------------------------------------------|
* | grads_image |
* |---------------------------------------------------|
*/
// Load the data from GDRAM to NRAM.
__memcpy((T *)nram_grads + align_c * high_precision,
(const T *)grads + (n * pooled_height * pooled_width +
ph * pooled_width + pw) *
channels,
channels * sizeof(T), GDRAM2NRAM);
if (high_precision) {
__bang_half2float((float *)nram_grads,
(half *)nram_grads + align_c * high_precision,
align_c);
}
__memcpy((int32_t *)nram_argmax,
(const int32_t *)argmax + (n * pooled_height * pooled_width +
ph * pooled_width + pw) *
channels,
channels * sizeof(int32_t), GDRAM2NRAM);
// Perform pooling operation on NRAM.
convertIndex(nram_argmax, nram_argmax_fp, nram_argmax_fp_bk1,
nram_argmax_fp_bk2, nram_argmax_int, nram_argmax_int_h,
nram_argmax_int_w, nram_argmax_fp_h, nram_argmax_fp_w,
nram_atomic_add, nram_grads_image, width, height, wstart,
hstart, w_compute, h_compute, align_c, channels, 0, 0, 0);
__bang_maxpool_bp((float *)nram_grads_image, (float *)nram_grads,
(int32_t *)nram_argmax_int, align_c, h_compute,
w_compute, h_compute, w_compute, h_compute,
w_compute);
if (high_precision) {
__bang_float2half_rd((half *)nram_grads_image,
(float *)nram_grads_image,
h_compute * w_compute * align_c);
}
// Store the result on NRAM back to GDRAM.
for (int hc = 0; hc < h_compute; ++hc) {
for (int wc = 0; wc < w_compute; ++wc) {
T *dst = (T *)nram_atomic_add;
int grad_image_offset = (roi_batch_ind * height * width +
(hc + hstart) * width + wc + wstart) *
channels;
T *src1 = (T *)grads_image + grad_image_offset;
int nram_grads_image_offset = (hc * w_compute + wc) * align_c;
T *src2 = (T *)nram_grads_image + nram_grads_image_offset;
__bang_atomic_add(dst, src1, src2, channels);
}
}
} else if (true_limit > 0) {
/*
* NRAM partition
* |---------------------------------------------------|
* | grads |
* |---------------------------------------------------|
* | argmax |
* |--------------------ping_pong----------------------|
* | argmax_temp | argmax_temp |
* |------------------------|--------------------------|
* | atomic_add | atomic_add |
* |------------------------|--------------------------|
* | grads_image | grads_image |
* |---------------------------------------------------|
*/
// Load the data from GDRAM to NRAM.
__memcpy((T *)nram_grads + align_c * high_precision,
(const T *)grads + (n * pooled_height * pooled_width +
ph * pooled_width + pw) *
channels,
channels * sizeof(T), GDRAM2NRAM);
if (high_precision) {
__bang_half2float((float *)nram_grads,
(half *)nram_grads + align_c * high_precision,
align_c);
}
__memcpy((int32_t *)nram_argmax,
(const int32_t *)argmax + (n * pooled_height * pooled_width +
ph * pooled_width + pw) *
channels,
channels * sizeof(int32_t), GDRAM2NRAM);
int ping_pong = 0;
int ping_pong_offset =
(MAX_NRAM_SIZE / sizeof(float) - align_c * PIPELINE_COMMON_NUM) / 2;
for (int loop_id = 0; loop_id <= loop_int; ++loop_id) {
int size = (loop_id == loop_int) ? rem : true_limit;
if (size == 0) {
break;
}
// Perform pooling operation on NRAM.
nram_argmax_fp =
(int32_t *)nram_argmax + align_c + ping_pong * ping_pong_offset;
nram_argmax_fp_bk1 = (int32_t *)nram_argmax_fp + align_c;
nram_argmax_fp_bk2 = (int32_t *)nram_argmax_fp_bk1 + align_c;
nram_argmax_int = (int32_t *)nram_argmax_fp_bk2 + align_c;
nram_argmax_int_h = (int32_t *)nram_argmax_int + align_c;
nram_argmax_int_w = (int32_t *)nram_argmax_int_h + align_c;
nram_argmax_fp_h = (int32_t *)nram_argmax_int_w + align_c;
nram_argmax_fp_w = (int32_t *)nram_argmax_fp_h + align_c;
nram_atomic_add = (float *)nram_argmax_fp_w + align_c;
nram_grads_image = (float *)nram_atomic_add + align_c;
int loop_id_1 = loop_id;
int size_1 = ((loop_id_1) == loop_int) ? rem : true_limit;
if (size_1 == 0) {
break;
}
convertIndex(nram_argmax, nram_argmax_fp, nram_argmax_fp_bk1,
nram_argmax_fp_bk2, nram_argmax_int, nram_argmax_int_h,
nram_argmax_int_w, nram_argmax_fp_h, nram_argmax_fp_w,
nram_atomic_add, nram_grads_image, width, height, wstart,
hstart, w_compute, h_compute, align_c, channels, 1,
loop_id_1, true_limit);
__bang_maxpool_bp((float *)nram_grads_image, (float *)nram_grads,
(int32_t *)nram_argmax_int, align_c, size_1, 1,
size_1, 1, size_1, 1);
if (high_precision) {
__bang_float2half_rd((half *)nram_grads_image,
(float *)nram_grads_image, size_1 * align_c);
}
// Store the result on NRAM back to GDRAM.
for (int index_size = 0; index_size < size; ++index_size) {
int h = (loop_id * true_limit + index_size) / w_compute;
int w = (loop_id * true_limit + index_size) % w_compute;
T *dst = (T *)nram_atomic_add;
T *grads_image_n =
(T *)grads_image + roi_batch_ind * height * width * channels;
T *src1 = (T *)grads_image_n +
((h + hstart) * width + (w + wstart)) * channels;
T *src2 = (T *)nram_grads_image + index_size * align_c;
__bang_atomic_add(dst, src1, src2, channels);
}
ping_pong = 1 - ping_pong;
}
} else {
/*
* NRAM partition
* |---------------------------------------------------|
* | grads |
* |---------------------------------------------------|
* | argmax |
* |--------------------ping_pong----------------------|
* | argmax_temp | argmax_temp |
* |------------------------|--------------------------|
* | atomic_add | atomic_add |
* |------------------------|--------------------------|
* | grads_image | grads_image |
* |---------------------------------------------------|
*/
int c_limit =
PAD_DOWN(MAX_NRAM_SIZE / sizeof(float) /
(PIPELINE_COMMON_NUM + PIPELINE_PINGPONG_NUM * 2),
ALIGN_SIZE);
int loop_int = channels / c_limit;
int rem = channels % c_limit;
int ping_pong = 0;
int ping_pong_offset =
(MAX_NRAM_SIZE / sizeof(float) - c_limit * PIPELINE_COMMON_NUM) / 2;
for (int loop_id = 0; loop_id <= loop_int; ++loop_id) {
int size = (loop_id == loop_int) ? rem : c_limit;
if (size == 0) {
break;
}
nram_argmax_fp =
(int32_t *)nram_argmax + c_limit + ping_pong * ping_pong_offset;
nram_argmax_fp_bk1 = (int32_t *)nram_argmax_fp + c_limit;
nram_argmax_fp_bk2 = (int32_t *)nram_argmax_fp_bk1 + c_limit;
nram_argmax_int = (int32_t *)nram_argmax_fp_bk2 + c_limit;
nram_argmax_int_h = (int32_t *)nram_argmax_int + c_limit;
nram_argmax_int_w = (int32_t *)nram_argmax_int_h + c_limit;
nram_argmax_fp_h = (int32_t *)nram_argmax_int_w + c_limit;
nram_argmax_fp_w = (int32_t *)nram_argmax_fp_h + c_limit;
nram_atomic_add = (float *)nram_argmax_fp_w + c_limit;
nram_grads_image = (float *)nram_atomic_add + c_limit;
// This pipeline loads the data from GDRAM to NRAM.
__memcpy((T *)nram_grads + c_limit * high_precision,
(const T *)grads +
n * pooled_height * pooled_width * channels +
ph * pooled_width * channels + pw * channels +
loop_id * c_limit,
size * sizeof(T), GDRAM2NRAM);
if (high_precision) {
__bang_half2float((float *)nram_grads,
(half *)nram_grads + c_limit * high_precision,
c_limit);
}
__memcpy((int32_t *)nram_argmax,
(const int32_t *)argmax +
n * pooled_height * pooled_width * channels +
ph * pooled_width * channels + pw * channels +
loop_id * c_limit,
size * sizeof(int32_t), GDRAM2NRAM);
for (int hc = 0; hc < h_compute; ++hc) {
for (int wc = 0; wc < w_compute; ++wc) {
// This pipeline performs pooling operation on NRAM.
convertIndex(
nram_argmax, nram_argmax_fp, nram_argmax_fp_bk1,
nram_argmax_fp_bk2, nram_argmax_int, nram_argmax_int_h,
nram_argmax_int_w, nram_argmax_fp_h, nram_argmax_fp_w,
nram_atomic_add, nram_grads_image, width, height, wstart + wc,
hstart + hc, h_compute, w_compute, c_limit, size, 0, 0, 0);
__bang_maxpool_bp((float *)nram_grads_image, (float *)nram_grads,
(int32_t *)nram_argmax_int, c_limit, 1, 1, 1, 1,
1, 1);
if (high_precision) {
__bang_float2half_rd((half *)nram_grads_image,
(float *)nram_grads_image, c_limit);
}
// This pipeline stores the result on NRAM back to GDRAM.
T *dst = (T *)nram_atomic_add;
T *grads_image_n =
(T *)grads_image + roi_batch_ind * height * width * channels;
T *src1 = (T *)grads_image_n +
((hc + hstart) * width + (wc + wstart)) * channels +
loop_id * c_limit;
T *src2 = (T *)nram_grads_image;
__bang_atomic_add(dst, src1, src2, size);
}
}
ping_pong = 1 - ping_pong;
}
}
}
}
}
__mlu_global__ void MLUKernelRoiPoolBackward(
const void *grads, const void *rois, const int *argmax, void *grads_image,
int rois_num, int pooled_height, int pooled_width, int channels, int no,
int height, int width, const float spatial_scale,
const cnrtDataType_t k_dtype) {
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
switch (k_dtype) {
case CNRT_FLOAT16: {
// Using the float type '__bang_max_pool_bp' instruction to increase the
// bit width.
const int high_precision = 1;
MLUUnion1Roipool((const half *)rois, (const half *)grads,
(const int32_t *)argmax, (half *)grads_image, channels,
height, width, pooled_height, pooled_width, rois_num,
(const half)spatial_scale, high_precision);
}; break;
case CNRT_FLOAT32: {
const int high_precision = 0;
MLUUnion1Roipool((const float *)rois, (const float *)grads,
(const int32_t *)argmax, (float *)grads_image, channels,
height, width, pooled_height, pooled_width, rois_num,
(const float)spatial_scale, high_precision);
}; break;
default: {
break;
}
}
}
} // namespace backward
void KernelRoiPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *input_data, const void *input_rois,
const int batch, const int channels, const int height,
const int width, const int pooled_height,
const int pooled_width, const int rois_num,
const float spatial_scale, void *output_data,
int *argmax) {
forward::MLUKernelRoiPool<<<k_dim, k_type, queue>>>(
data_type, input_data, input_rois, batch, channels, height, width,
pooled_height, pooled_width, rois_num, spatial_scale, output_data,
argmax);
}
void KernelRoiPoolBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t k_dtype,
const void *grad_output_ptr, const void *rois_ptr,
const int *argmax_ptr, void *grad_input_ptr,
const int box_num, const int pooled_height,
const int pooled_width, const int channels,
const int batch, const int height, const int width,
const float spatial_scale) {
backward::MLUKernelRoiPoolBackward<<<k_dim, k_type, queue>>>(
grad_output_ptr, rois_ptr, argmax_ptr, grad_input_ptr, box_num,
pooled_height, pooled_width, channels, batch, height, width,
spatial_scale, k_dtype);
}

View File

@ -0,0 +1,275 @@
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelRoiPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *input_data, const void *input_rois,
const int batch, const int channels, const int height,
const int width, const int pooled_height,
const int pooled_width, const int rois_num,
const float spatial_scale, void *output_data,
int *argmax);
void KernelRoiPoolBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t k_dtype,
const void *grad_output_ptr, const void *rois_ptr,
const int *argmax_ptr, void *grad_input_ptr,
const int box_num, const int pooled_height,
const int pooled_width, const int channels,
const int batch, const int height, const int width,
const float spatial_scale);
// policy function for forward
static void policyFuncForward(const int bin_num, cnrtDim3_t *k_dim,
cnrtFunctionType_t *k_type) {
auto core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = core_num;
unsigned int use_cluster = bin_num / core_num + (bin_num % core_num > 0);
k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
k_dim->z = 1;
}
void ROIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height,
int pooled_width, float spatial_scale) {
// Check dtype.
TORCH_CHECK(
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
"input type should be Float or Half, got ", input.scalar_type());
TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
"rois should have the same type as input");
// Check dtype relationship.
TORCH_CHECK(
argmax.scalar_type() == at::kLong || argmax.scalar_type() == at::kInt,
"argmax type should be Int or Long, got ", argmax.scalar_type());
// Check shape.
TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
"D");
TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
"D");
TORCH_CHECK(argmax.dim() == 4, "argmax should be 4d tensor, got ",
argmax.dim(), "D");
TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
"spatial_scale should be within (0, 1], got ", spatial_scale);
// compute kernel params
auto batch = input.size(0);
auto height = input.size(2);
auto width = input.size(3);
auto channels = input.size(1);
auto rois_num = output.size(0);
if (output.numel() == 0) {
output = at::zeros({rois_num, channels, pooled_height, pooled_width},
input.options());
return;
}
if (argmax.numel() == 0) {
argmax = at::zeros({rois_num, channels, pooled_height, pooled_width},
argmax.options());
return;
}
// zero element check
if (input.numel() == 0 || rois.numel() == 0 || output.numel() == 0 ||
argmax.numel() == 0) {
return;
}
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
at::Tensor output_ =
at::empty({rois_num, channels, pooled_height, pooled_width},
input.options(), memory_format);
at::Tensor argmax_ =
at::empty({rois_num, channels, pooled_height, pooled_width},
argmax.options(), memory_format);
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncForward(rois_num * pooled_height * pooled_width, &k_dim, &k_type);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(input_);
auto input_ptr = input_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
auto rois_ptr = rois_impl->cnnlMalloc();
auto output_impl = torch_mlu::getMluTensorImpl(output_);
auto output_ptr = output_impl->cnnlMalloc();
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_);
auto argmax_ptr = argmax_impl->cnnlMalloc();
// get comput dtype of input
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input_.dtype());
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelRoiPoolForward<<<" << k_dim.x << ", "
<< k_dim.y << ", " << k_dim.z << ">>>";
KernelRoiPoolForward(k_dim, k_type, queue, data_type, input_ptr, rois_ptr,
batch, channels, height, width, pooled_height,
pooled_width, rois_num, spatial_scale, output_ptr,
(int *)argmax_ptr);
output.copy_(output_);
argmax.copy_(argmax_);
}
// policy function for backward
static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
k_dim->z = 1;
}
void ROIPoolBackwardMLUKernelLauncher(Tensor grad_output, Tensor rois,
Tensor argmax, Tensor grad_input,
int pooled_height, int pooled_width,
float spatial_scale) {
// Check dtype.
TORCH_CHECK(
argmax.scalar_type() == at::kLong || argmax.scalar_type() == at::kInt,
"argmax type should be Int or Long, got ", argmax.scalar_type());
TORCH_CHECK((grad_output.scalar_type() == at::kFloat ||
grad_output.scalar_type() == at::kHalf),
"grad_output type should be FLoat or Half, got ",
grad_output.scalar_type());
// Check dtype relationship.
TORCH_CHECK((rois.scalar_type() == grad_output.scalar_type()),
"rois should have the same type as grad_output");
// Check shape.
TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ",
grad_output.dim(), "D");
TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
"D");
TORCH_CHECK(argmax.dim() == 4, "argmax should be 4d tensor, got ",
argmax.dim(), "D");
TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
"spatial_scale should be within (0, 1], got ", spatial_scale);
// Check relationship between tensor.
// Check the relationship of n.
TORCH_CHECK(grad_output.size(0) == rois.size(0),
"grad_output.size(0) = ", grad_output.size(0),
", while rois.size(0) = ", rois.size(0),
". They should be the same.");
// Check the relationship of channels.
TORCH_CHECK(grad_output.size(1) == argmax.size(1),
"grad_output.size(1) = ", grad_output.size(1),
", while argmax.size(1) = ", argmax.size(1),
". They should be the same.");
// Check the relationship of height and width.
TORCH_CHECK(grad_output.size(2) == argmax.size(2),
"argmax.size(2) = ", argmax.size(2),
", while grad_output.size(2) = ", grad_output.size(2),
". They should be the same.");
TORCH_CHECK(grad_output.size(3) == argmax.size(3),
"argmax.size(3) = ", argmax.size(3),
", while grad_output.size(3) = ", grad_output.size(3),
". They should be the same.");
// Check zero element.
if (grad_output.numel() == 0 || rois.numel() == 0 || argmax.numel() == 0 ||
grad_input.numel() == 0) {
// return if zero-element
return;
}
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim());
auto grad_output_ =
torch_mlu::cnnl::ops::cnnl_contiguous(grad_output, memory_format);
auto argmax_ = torch_mlu::cnnl::ops::cnnl_contiguous(argmax, memory_format);
int boxes_num = grad_output.size(0);
int no = grad_input.size(0);
int channels = grad_input.size(1);
int height = grad_input.size(2);
int width = grad_input.size(3);
auto grad_input_ = at::empty({no, channels, height, width},
grad_input.options(), memory_format)
.zero_();
// get tensor impl
auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_);
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_);
auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get mlu ptr
auto grad_output_ptr = grad_output_impl->cnnlMalloc();
auto rois_ptr = rois_impl->cnnlMalloc();
auto argmax_ptr = argmax_impl->cnnlMalloc();
auto grad_input_ptr = grad_input_impl->cnnlMalloc();
// calculate task dimension
cnrtDataType_t k_dtype = torch_mlu::toCnrtDtype(grad_input.dtype());
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncBackward(&k_dim, &k_type);
CNLOG(INFO) << "Launch Kernel MLUKernelRoiPoolBackward<<<" << k_dim.x << ", "
<< k_dim.y << ", " << k_dim.z << ">>>";
KernelRoiPoolBackward(k_dim, k_type, queue, k_dtype, grad_output_ptr,
rois_ptr, (int *)argmax_ptr, grad_input_ptr, boxes_num,
pooled_height, pooled_width, channels, no, height,
width, spatial_scale);
grad_input.copy_(grad_input_);
}
void roi_pool_forward_mlu(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height, int pooled_width,
float spatial_scale) {
ROIPoolForwardMLUKernelLauncher(input, rois, output, argmax, pooled_height,
pooled_width, spatial_scale);
}
void roi_pool_backward_mlu(Tensor grad_output, Tensor rois, Tensor argmax,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale) {
ROIPoolBackwardMLUKernelLauncher(grad_output, rois, argmax, grad_input,
pooled_height, pooled_width, spatial_scale);
}
void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height, int pooled_width,
float spatial_scale);
void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale);
REGISTER_DEVICE_IMPL(roi_pool_forward_impl, MLU, roi_pool_forward_mlu);
REGISTER_DEVICE_IMPL(roi_pool_backward_impl, MLU, roi_pool_backward_mlu);

View File

@ -2,8 +2,11 @@
import os
import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
_USING_PARROTS = True
try:
from parrots.autograd import gradcheck
@ -54,9 +57,7 @@ class TestRoiPool:
else:
gradcheck(froipool, (x, rois), eps=1e-2, atol=1e-2)
def _test_roipool_allclose(self, dtype=torch.float):
if not torch.cuda.is_available():
return
def _test_roipool_allclose(self, device, dtype=torch.float):
from mmcv.ops import roi_pool
pool_h = 2
pool_w = 2
@ -69,15 +70,32 @@ class TestRoiPool:
np_grad = np.array(output[1])
x = torch.tensor(
np_input, dtype=dtype, device='cuda', requires_grad=True)
rois = torch.tensor(np_rois, dtype=dtype, device='cuda')
np_input, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor(np_rois, dtype=dtype, device=device)
output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
output.backward(torch.ones_like(output))
assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
def test_roipool_allclose(self):
self._test_roipool_allclose(torch.double)
self._test_roipool_allclose(torch.float)
self._test_roipool_allclose(torch.half)
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')),
torch.half
])
def test_roipool_allclose(self, device, dtype):
self._test_roipool_allclose(device, dtype)