[Enhancement] Replace the implementation of roiaware_pool3d with mlu-ops. (#2699)

* [Feature] Replace the implementation of roiaware_pool3d with mlu-ops.

* [Feature] Replace the implementation of roiaware_pool3d with mlu-ops.
Zhang 2023-04-03 23:35:07 +08:00 committed by GitHub
parent a55f4b7f40
commit bc727f7132
2 changed files with 88 additions and 1070 deletions
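
At a glance: the hand-written BANG kernels (first file below) are deleted, and the PyTorch launcher (second file) is rewired to call the mlu-ops library instead. A minimal sketch of the new forward call path, using only the torch_mlu helpers and mluOp entry points visible in this diff; the contig/ptr lambdas are illustrative shorthand, and the backward path follows the same handle/descriptor/pointer pattern, just without a workspace:

#include "mlu_common_helper.h"

// Sketch only: argument order mirrors the mluOpRoiawarePool3dForward call
// added in this commit; tensor checks are left to mlu-ops.
void forward_with_mlu_ops(int pool_method, int boxes_num, int pts_num,
                          int channels, int max_pts_each_voxel, int out_x,
                          int out_y, int out_z, const at::Tensor &rois,
                          const at::Tensor &pts, const at::Tensor &pts_feature,
                          at::Tensor &argmax, at::Tensor &pts_idx_of_voxels,
                          at::Tensor &pooled_features) {
  auto handle = mluOpGetCurrentHandle();
  // mlu-ops consumes contiguous tensors, each wrapped in a descriptor.
  auto contig = [](const at::Tensor &t) {
    return torch_mlu::cnnl::ops::cnnl_contiguous(t, t.suggest_memory_format());
  };
  auto ptr = [](const at::Tensor &t) {
    return torch_mlu::getMluTensorImpl(t)->cnnlMalloc();
  };
  auto rois_c = contig(rois), pts_c = contig(pts), feat_c = contig(pts_feature);
  auto argmax_c = contig(argmax), idx_c = contig(pts_idx_of_voxels);
  auto out_c = contig(pooled_features);
  MluOpTensorDescriptor rois_d, pts_d, feat_d, argmax_d, idx_d, out_d;
  rois_d.set(rois_c); pts_d.set(pts_c); feat_d.set(feat_c);
  argmax_d.set(argmax_c); idx_d.set(idx_c); out_d.set(out_c);
  // query the scratch space the op needs, back it with a device byte tensor
  size_t workspace_size = 0;
  mluOpGetRoiawarePool3dForwardWorkspaceSize(
      handle, rois_d.desc(), pts_d.desc(), feat_d.desc(), &workspace_size);
  auto workspace = at::empty(workspace_size, rois.options().dtype(at::kByte));
  mluOpRoiawarePool3dForward(
      handle, pool_method, boxes_num, pts_num, channels, rois_d.desc(),
      ptr(rois_c), pts_d.desc(), ptr(pts_c), feat_d.desc(), ptr(feat_c),
      ptr(workspace), workspace_size, max_pts_each_voxel, out_x, out_y, out_z,
      argmax_d.desc(), ptr(argmax_c), idx_d.desc(), ptr(idx_c), out_d.desc(),
      ptr(out_c));
}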


@@ -1,747 +0,0 @@
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#define ROI_OFFSET 7
#define FLOAT_NRAM_BUFFER_NUM 14
#define HALF_NRAM_BUFFER_NUM 25
#define ALIGN_NUM 64
__nram__ char data_nram[MAX_NRAM_SIZE];
template <typename T>
__mlu_global__ void MLUUnion1KernelPtsIdxOfVoxels(
const int pool_method, const int boxes_num, const int pts_num,
const int max_pts_each_voxel, const int out_x, const int out_y,
const int out_z, const T *rois, const T *pts, int *pts_idx_of_voxels) {
// params (T)rois: (boxes_num, 7)
// params (T)pts: (3, pts_num)
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z, max_pts_each_voxel)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
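// NRAM is carved into nram_pts_num-element buffers: the float path needs 14
// of them (X/Y/Z, local_X/Y/Z, the in-box flag, five temporaries, the voxel
// offsets and the index sequence), matching FLOAT_NRAM_BUFFER_NUM; the half
// path budgets 25 half-sized slots (HALF_NRAM_BUFFER_NUM) because every
// float-typed buffer spans two half slots. The count is padded down to a
// multiple of ALIGN_NUM for the vector primitives.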
int nram_pts_num = 0;
if (sizeof(T) == sizeof(float)) {
nram_pts_num = PAD_DOWN(
(MAX_NRAM_SIZE / sizeof(float) / FLOAT_NRAM_BUFFER_NUM), ALIGN_NUM);
} else {
nram_pts_num = PAD_DOWN(
(MAX_NRAM_SIZE / sizeof(half) / HALF_NRAM_BUFFER_NUM), ALIGN_NUM);
}
char *X = NULL;
char *Y = NULL;
char *Z = NULL;
char *local_X = NULL;
char *local_Y = NULL;
char *local_Z = NULL;
char *nram_pts_in_flag = NULL;
float *temp_buffer1 = NULL;
float *temp_buffer2 = NULL;
float *temp_buffer3 = NULL;
float *temp_buffer4 = NULL;
float *temp_buffer5 = NULL;
float *nram_voxel_offset = NULL;
int *nram_pts_idx_seq = NULL;
float *fp_local_X = NULL;
float *fp_local_Y = NULL;
float *fp_local_Z = NULL;
float *fp_nram_pts_in_flag = NULL;
if (sizeof(T) == sizeof(float)) {
X = (char *)((float *)data_nram);
Y = (char *)((float *)data_nram + nram_pts_num);
Z = (char *)((float *)data_nram + nram_pts_num * 2);
local_X = (char *)((float *)data_nram + nram_pts_num * 3);
local_Y = (char *)((float *)data_nram + nram_pts_num * 4);
local_Z = (char *)((float *)data_nram + nram_pts_num * 5);
nram_pts_in_flag = (char *)((float *)data_nram + nram_pts_num * 6);
temp_buffer1 = (float *)data_nram + nram_pts_num * 7;
temp_buffer2 = (float *)data_nram + nram_pts_num * 8;
temp_buffer3 = (float *)data_nram + nram_pts_num * 9;
temp_buffer4 = (float *)data_nram + nram_pts_num * 10;
temp_buffer5 = (float *)data_nram + nram_pts_num * 11;
nram_voxel_offset = (float *)data_nram + nram_pts_num * 12;
nram_pts_idx_seq = (int *)((float *)data_nram + nram_pts_num * 13);
fp_local_X = (float *)local_X;
fp_local_Y = (float *)local_Y;
fp_local_Z = (float *)local_Z;
fp_nram_pts_in_flag = (float *)nram_pts_in_flag;
} else {
X = (char *)((half *)data_nram);
Y = (char *)((half *)data_nram + nram_pts_num);
Z = (char *)((half *)data_nram + nram_pts_num * 2);
local_X = (char *)((half *)data_nram + nram_pts_num * 4);
local_Y = (char *)((half *)data_nram + nram_pts_num * 6);
local_Z = (char *)((half *)data_nram + nram_pts_num * 8);
nram_pts_in_flag = (char *)((half *)data_nram + nram_pts_num * 10);
temp_buffer1 = (float *)((half *)data_nram + nram_pts_num * 11);
temp_buffer2 = (float *)((half *)data_nram + nram_pts_num * 13);
temp_buffer3 = (float *)((half *)data_nram + nram_pts_num * 15);
temp_buffer4 = (float *)((half *)data_nram + nram_pts_num * 17);
temp_buffer5 = (float *)((half *)data_nram + nram_pts_num * 19);
nram_voxel_offset = (float *)((half *)data_nram + nram_pts_num * 21);
nram_pts_idx_seq = (int *)((half *)data_nram + nram_pts_num * 23);
fp_local_X = (float *)((half *)local_X - nram_pts_num);
fp_local_Y = (float *)((half *)local_Y - nram_pts_num);
fp_local_Z = (float *)((half *)local_Z - nram_pts_num);
fp_nram_pts_in_flag = (float *)((half *)nram_pts_in_flag - nram_pts_num);
}
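// Layout note for the half path: X/Y/Z stay half precision (one slot each),
// while the float-typed temporaries and index sequence each span two half
// slots, hence the offsets 11, 13, ..., 23 above. Each fp_local_* view starts
// one slot below its half buffer, so the widened float results only overwrite
// the half data after it has been consumed.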
for (int i = 0; i < nram_pts_num; i++) {
nram_pts_idx_seq[i] = i;
}
int nram_pts_loop_times = pts_num / nram_pts_num;
int rem_nram_num = pts_num % nram_pts_num;
for (int roi_index = taskId; roi_index < boxes_num; roi_index += taskDim) {
const T *cur_roi = rois + roi_index * ROI_OFFSET;
T cx = cur_roi[0];
T cy = cur_roi[1];
T cz = cur_roi[2];
T dx = cur_roi[3];
T dy = cur_roi[4];
T dz = cur_roi[5];
T rz = cur_roi[6];
T dx_2 = dx / 2.0;
T dy_2 = dy / 2.0;
T dz_2 = dz / 2.0;
for (int loop_idx = 0; loop_idx <= nram_pts_loop_times; loop_idx++) {
int load_pts_num =
(loop_idx == nram_pts_loop_times) ? rem_nram_num : nram_pts_num;
if (load_pts_num == 0) {
break;
}
int pts_offset_cur_loop = nram_pts_num * loop_idx;
int compute_pts_num = (loop_idx == nram_pts_loop_times)
? PAD_UP(rem_nram_num, ALIGN_NUM)
: nram_pts_num;
// load pts
__memcpy((void *)X, (T *)pts + pts_offset_cur_loop,
load_pts_num * sizeof(T), GDRAM2NRAM);
__memcpy((void *)Y, (T *)pts + pts_num + pts_offset_cur_loop,
load_pts_num * sizeof(T), GDRAM2NRAM);
__memcpy((void *)Z, (T *)pts + pts_num * 2 + pts_offset_cur_loop,
load_pts_num * sizeof(T), GDRAM2NRAM);
// fabs(local_z)
__bang_sub_scalar((T *)local_Z, (T *)Z, (T)cz, compute_pts_num);
__bang_sub_scalar((T *)temp_buffer1, (T *)Z, (T)(cz + dz_2),
compute_pts_num);
__bang_active_abs((T *)temp_buffer1, (T *)temp_buffer1, compute_pts_num);
#if __BANG_ARCH__ >= 322
__bang_le_scalar((T *)nram_pts_in_flag, (T *)temp_buffer1, (T)(dz_2),
compute_pts_num);
#else
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dz_2));
__bang_le((T *)nram_pts_in_flag, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
#endif
T cosa = std::cos(-rz);
T sina = std::sin(-rz);
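// rotate the offsets (x - cx, y - cy) by -rz into the box frame:
// local_x = (x - cx) * cos(-rz) - (y - cy) * sin(-rz)
// local_y = (x - cx) * sin(-rz) + (y - cy) * cos(-rz)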
__bang_sub_scalar((T *)temp_buffer3, (T *)X, (T)cx, compute_pts_num);
__bang_sub_scalar((T *)temp_buffer4, (T *)Y, (T)cy, compute_pts_num);
__bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)cosa,
compute_pts_num);
__bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)sina,
compute_pts_num);
// local_x
__bang_sub((T *)local_X, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
// fabs(local_x)
__bang_active_abs((T *)temp_buffer1, (T *)local_X, compute_pts_num);
// fabs(local_x) < dx/2 ? 1 : 0
#if __BANG_ARCH__ >= 322
__bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dx_2),
compute_pts_num);
#else
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dx_2));
__bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
#endif
__bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag,
(T *)temp_buffer1,
compute_pts_num);  // fold the x check into the in-box flag
__bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)sina,
compute_pts_num);
__bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)cosa,
compute_pts_num);
// local_y
__bang_add((T *)local_Y, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
// fabs(local_y)
__bang_active_abs((T *)temp_buffer1, (T *)local_Y, compute_pts_num);
// fabs(local_y) < dy/2 ? 1 : 0
#if __BANG_ARCH__ >= 322
__bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dy_2),
compute_pts_num);
#else
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dy_2));
__bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2,
compute_pts_num);
#endif
__bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag,
(T *)temp_buffer1,
compute_pts_num);  // fold the y check into the in-box flag
T x_res = dx / out_x;
T y_res = dy / out_y;
T z_res = dz / out_z;
__bang_add_scalar((T *)local_X, (T *)local_X, (T)(dx_2), compute_pts_num);
__bang_add_scalar((T *)local_Y, (T *)local_Y, (T)(dy_2), compute_pts_num);
// local_Z does not need dz/2 added: z is measured from the box bottom
#if (__BANG_ARCH__ >= 322) && (__BANG_ARCH__ != 372)
__bang_div((T *)local_X, (T *)local_X, (T)x_res, compute_pts_num);
__bang_div((T *)local_Y, (T *)local_Y, (T)y_res, compute_pts_num);
__bang_div((T *)local_Z, (T *)local_Z, (T)z_res, compute_pts_num);
#else
__bang_mul_scalar((T *)local_X, (T *)local_X, (T)(1 / x_res),
compute_pts_num);
__bang_mul_scalar((T *)local_Y, (T *)local_Y, (T)(1 / y_res),
compute_pts_num);
__bang_mul_scalar((T *)local_Z, (T *)local_Z, (T)(1 / z_res),
compute_pts_num);
#endif
// truncate to voxel index and back: float = float2int + int2float, half = half2int + int2float
if (sizeof(T) == sizeof(float)) {
#if __BANG_ARCH__ >= 322
__bang_float2int32_tz((int *)temp_buffer1, (float *)local_X,
compute_pts_num, 0);
__bang_float2int32_tz((int *)temp_buffer2, (float *)local_Y,
compute_pts_num, 0);
__bang_float2int32_tz((int *)temp_buffer3, (float *)local_Z,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3,
compute_pts_num, 0);
#else
convertFloat2Int((int *)temp_buffer1, (float *)temp_buffer2,
(float *)fp_local_X, (float *)temp_buffer3,
compute_pts_num);
convertFloat2Int((int *)temp_buffer2, (float *)temp_buffer3,
(float *)fp_local_Y, (float *)temp_buffer4,
compute_pts_num);
convertFloat2Int((int *)temp_buffer3, (float *)temp_buffer4,
(float *)fp_local_Z, (float *)temp_buffer5,
compute_pts_num);
convertInt2Float((float *)fp_local_X, (float *)temp_buffer4,
(int *)temp_buffer1, (float *)temp_buffer5,
compute_pts_num);
convertInt2Float((float *)fp_local_Y, (float *)temp_buffer4,
(int *)temp_buffer2, (float *)temp_buffer5,
compute_pts_num);
convertInt2Float((float *)fp_local_Z, (float *)temp_buffer4,
(int *)temp_buffer3, (float *)temp_buffer5,
compute_pts_num);
#endif
} else {
__bang_half2float((float *)temp_buffer4, (half *)nram_pts_in_flag,
compute_pts_num);
__bang_move((void *)fp_nram_pts_in_flag, (void *)temp_buffer4,
compute_pts_num * sizeof(float));
#if __BANG_ARCH__ >= 322
__bang_half2int32_tz((int *)temp_buffer1, (half *)local_X,
compute_pts_num, 0);
__bang_half2int32_tz((int *)temp_buffer2, (half *)local_Y,
compute_pts_num, 0);
__bang_half2int32_tz((int *)temp_buffer3, (half *)local_Z,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2,
compute_pts_num, 0);
__bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3,
compute_pts_num, 0);
#else
__bang_half2int16_tz((int16_t *)temp_buffer1, (half *)local_X,
compute_pts_num, 0);
__bang_half2int16_tz((int16_t *)temp_buffer2, (half *)local_Y,
compute_pts_num, 0);
__bang_half2int16_tz((int16_t *)temp_buffer3, (half *)local_Z,
compute_pts_num, 0);
__bang_int162float((float *)fp_local_X, (int16_t *)temp_buffer1,
compute_pts_num, 0);
__bang_int162float((float *)fp_local_Y, (int16_t *)temp_buffer2,
compute_pts_num, 0);
__bang_int162float((float *)fp_local_Z, (int16_t *)temp_buffer3,
compute_pts_num, 0);
#endif
}
// process index >= 0
__bang_write_value((float *)temp_buffer4, compute_pts_num, (float)0.0f);
__bang_maxequal((float *)fp_local_X, (float *)fp_local_X,
(float *)temp_buffer4, compute_pts_num);
__bang_maxequal((float *)fp_local_Y, (float *)fp_local_Y,
(float *)temp_buffer4, compute_pts_num);
__bang_maxequal((float *)fp_local_Z, (float *)fp_local_Z,
(float *)temp_buffer4, compute_pts_num);
// process index <= out_x/out_y/out_z - 1
__bang_write_value((float *)temp_buffer5, compute_pts_num,
(float)(out_x - 1));
__bang_minequal((float *)fp_local_X, (float *)fp_local_X,
(float *)temp_buffer5, compute_pts_num);
__bang_write_value((float *)temp_buffer5, compute_pts_num,
(float)(out_y - 1));
__bang_minequal((float *)fp_local_Y, (float *)fp_local_Y,
(float *)temp_buffer5, compute_pts_num);
__bang_write_value((float *)temp_buffer5, compute_pts_num,
(float)(out_z - 1));
__bang_minequal((float *)fp_local_Z, (float *)fp_local_Z,
(float *)temp_buffer5, compute_pts_num);
__bang_mul_scalar((float *)temp_buffer1, (float *)fp_local_X,
(float)(out_y * out_z), compute_pts_num);
__bang_mul_scalar((float *)temp_buffer2, (float *)fp_local_Y,
(float)out_z, compute_pts_num);
__bang_mul_scalar((float *)temp_buffer3, (float *)fp_local_Z, (float)1.0,
compute_pts_num);
__bang_add((float *)nram_voxel_offset, (float *)temp_buffer1,
(float *)temp_buffer2, compute_pts_num);
__bang_add((float *)nram_voxel_offset, (float *)nram_voxel_offset,
(float *)temp_buffer3, compute_pts_num);
__bang_mul_scalar((float *)nram_voxel_offset, (float *)nram_voxel_offset,
(float)max_pts_each_voxel, compute_pts_num);
if (compute_pts_num != load_pts_num) {
__memset_nram((float *)fp_nram_pts_in_flag + load_pts_num,
compute_pts_num - load_pts_num, (float)0.0);
}
__bang_collect((float *)temp_buffer4, (float *)nram_pts_idx_seq,
(float *)fp_nram_pts_in_flag, compute_pts_num);
int pts_num_in_cur_roi =
(int)__bang_count((float *)fp_nram_pts_in_flag, compute_pts_num);
int *pts_idx_cur_voxels =
(int *)pts_idx_of_voxels +
roi_index * out_x * out_y * out_z * max_pts_each_voxel;
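// Each voxel owns max_pts_each_voxel ints: slot 0 holds the running point
// count and slots 1..max_pts_each_voxel-1 hold point indices, so a voxel
// keeps at most max_pts_each_voxel - 1 points.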
for (int idx = 0; idx < pts_num_in_cur_roi; idx++) {
int cur_pts_idx = *((int *)temp_buffer4 + idx);
int offset = (int)(*((float *)nram_voxel_offset + cur_pts_idx));
int cnt = pts_idx_cur_voxels[offset];
if (cnt < max_pts_each_voxel - 1) {
pts_idx_cur_voxels[offset + cnt + 1] =
cur_pts_idx + loop_idx * nram_pts_num;
pts_idx_cur_voxels[offset]++;
}
}
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelRoiawarePool3dForward(
const int pool_method, const int boxes_num, const int pts_num,
const int channels, const int max_pts_each_voxel, const int out_x,
const int out_y, const int out_z, const T *pts_feature,
const int *pts_idx_of_voxels, T *pooled_features, int *argmax) {
// params (T)pts_feature: (channels, pts_num)
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z, max_pts_each_voxel)
// params (int)argmax: (boxes_num, out_x, out_y, out_z, channels)
// params (T)pooled_features: (boxes_num, out_x, out_y, out_z, channels)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int align_num = NFU_ALIGN_SIZE / sizeof(T);
int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num);
int nram_channels_limit =
PAD_DOWN((MAX_NRAM_SIZE - 128 -
align_max_pts_each_voxel * (sizeof(int) + sizeof(T))) /
((align_max_pts_each_voxel + 1) * sizeof(T) + sizeof(int)),
align_num);
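// The limit above reserves 128 bytes for one_pooled_feature plus one
// align_max_pts_each_voxel-sized int buffer (point indices) and T buffer
// (max tmp); the remainder is divided by the per-channel footprint: one
// feature row of align_max_pts_each_voxel Ts, one pooled output T and one
// argmax int.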
int *nram_pts_idx_cur_voxel = (int *)data_nram;
// nram_pts_idx_cur_voxel [align_max_pts_each_voxel]
T *nram_max_pts_feature_tmp =
(T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel);
// nram_max_pts_feature_tmp [align_max_pts_each_voxel]
T *nram_pts_feature_in_voxel =
((T *)nram_max_pts_feature_tmp + align_max_pts_each_voxel);
// nram_pts_feature_in_voxel [nram_channels_limit, align_max_pts_each_voxel]
T *nram_pooled_features_cur_voxel =
((T *)nram_pts_feature_in_voxel +
nram_channels_limit * align_max_pts_each_voxel);
// nram_pooled_features_cur_voxel [nram_channels_limit]
int *nram_argmax_cur_voxel =
(int *)((T *)nram_pooled_features_cur_voxel + nram_channels_limit);
// nram_argmax_cur_voxel [nram_channels_limit]
char *one_pooled_feature =
(char *)((int *)nram_argmax_cur_voxel + nram_channels_limit);
// one_pooled_feature [128]
int channels_loop_times = channels / nram_channels_limit;
int rem_channels = channels % nram_channels_limit;
for (int voxel_index = taskId;
voxel_index < boxes_num * out_x * out_y * out_z;
voxel_index += taskDim) {
int *pts_idx_cur_voxels =
(int *)pts_idx_of_voxels + voxel_index * max_pts_each_voxel;
__memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxels,
max_pts_each_voxel * sizeof(int), GDRAM2NRAM);
int pts_num_cur_voxel = nram_pts_idx_cur_voxel[0];
if (pts_num_cur_voxel == 0) {
continue;
}
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
channels_loop_idx++) {
int actual_channels_num = (channels_loop_idx == channels_loop_times)
? rem_channels
: nram_channels_limit;
if (actual_channels_num == 0) {
break;
}
int channels_offset = nram_channels_limit * channels_loop_idx;
#if ((__BANG_ARCH__ >= 200) && (__BANG_ARCH__ < 300))
int compute_channels_num = (channels_loop_idx == channels_loop_times)
? PAD_UP(rem_channels, align_num)
: nram_channels_limit;
if (pool_method == 0) {
__bang_write_value((void *)nram_pts_feature_in_voxel,
compute_channels_num * align_max_pts_each_voxel,
(T)-INFINITY);
}
#endif
T *pts_feature_cur_loop = (T *)pts_feature + channels_offset * pts_num;
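// Gather this voxel's points channel by channel: for each point, copy one T
// per channel, stepping pts_num elements through pts_feature in GDRAM and
// align_max_pts_each_voxel elements between channel rows in NRAM.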
for (int idx = 0; idx < pts_num_cur_voxel; idx++) {
__memcpy((T *)nram_pts_feature_in_voxel + idx,
(T *)pts_feature_cur_loop + nram_pts_idx_cur_voxel[idx + 1],
sizeof(T), GDRAM2NRAM, align_max_pts_each_voxel * sizeof(T),
pts_num * sizeof(T), actual_channels_num - 1);
}
for (int channel_idx = 0; channel_idx < actual_channels_num;
channel_idx++) {
if (pool_method == 0) {
#if __BANG_ARCH__ >= 322
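// __bang_argmax packs its result as [max value, winner index]; the index is
// recovered below by reinterpreting element 1 as a uint32.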
__bang_argmax((T *)one_pooled_feature,
(T *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel,
pts_num_cur_voxel);
T max_val = ((T *)one_pooled_feature)[0];
int max_idx = (int)(*(uint32_t *)((T *)one_pooled_feature + 1));
nram_pooled_features_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? 0 : max_val;
nram_argmax_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? -1 : nram_pts_idx_cur_voxel[max_idx + 1];
#else
// __bang_max needs an aligned element count on MLU200 series
if (sizeof(T) == sizeof(float)) {
__bang_max((float *)one_pooled_feature,
(float *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel,
align_max_pts_each_voxel);
float max_val = ((float *)one_pooled_feature)[0];
__bang_write_value((void *)nram_max_pts_feature_tmp,
align_max_pts_each_voxel, (float)max_val);
__bang_eq((float *)nram_max_pts_feature_tmp,
(float *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel,
(float *)nram_max_pts_feature_tmp,
align_max_pts_each_voxel);
int max_idx = (int)__bang_findfirst1(
(float *)nram_max_pts_feature_tmp, align_max_pts_each_voxel);
nram_pooled_features_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? 0 : max_val;
nram_argmax_cur_voxel[channel_idx] =
(max_val == -INFINITY) ? -1
: nram_pts_idx_cur_voxel[max_idx + 1];
} else {
int max_idx = -1;
float max_val = -INFINITY;
for (int k = 0; k < pts_num_cur_voxel; k++) {
float pts_feature_cur_channel = __half2float_rd(
*((half *)nram_pts_feature_in_voxel +
channel_idx * align_max_pts_each_voxel + k));
if (pts_feature_cur_channel > max_val) {
max_val = pts_feature_cur_channel;
max_idx = k;
}
}
nram_pooled_features_cur_voxel[channel_idx] =
(max_idx == -1) ? 0 : max_val;
nram_argmax_cur_voxel[channel_idx] =
(max_idx == -1) ? -1 : nram_pts_idx_cur_voxel[max_idx + 1];
}
#endif
} else if (pool_method == 1) {
float sum_val_cur_channel = 0;
for (int k = 0; k < pts_num_cur_voxel; k++) {
sum_val_cur_channel += static_cast<float>(
((T *)nram_pts_feature_in_voxel)[channel_idx *
align_max_pts_each_voxel +
k]);
}
nram_pooled_features_cur_voxel[channel_idx] =
(T)(sum_val_cur_channel / pts_num_cur_voxel);
}
}
// store
__memcpy((T *)pooled_features + voxel_index * channels + channels_offset,
(void *)nram_pooled_features_cur_voxel,
actual_channels_num * sizeof(T), NRAM2GDRAM);
if (pool_method == 0) {
__memcpy((int *)argmax + voxel_index * channels + channels_offset,
(void *)nram_argmax_cur_voxel,
actual_channels_num * sizeof(int), NRAM2GDRAM);
}
}
}
}
void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const int pool_method, const int boxes_num,
const int pts_num, const int max_pts_each_voxel,
const int out_x, const int out_y, const int out_z,
const void *rois, const void *pts,
int *pts_idx_of_voxels) {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelPtsIdxOfVoxels<float><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
out_z, (float *)rois, (float *)pts, (int *)pts_idx_of_voxels);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelPtsIdxOfVoxels<half><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
out_z, (half *)rois, (half *)pts, (int *)pts_idx_of_voxels);
}; break;
default: {
break;
}
}
}
void KernelRoiawarePool3dForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
const int pts_num, const int channels, const int max_pts_each_voxel,
const int out_x, const int out_y, const int out_z, const void *pts_feature,
const int *pts_idx_of_voxels, void *pooled_features, int *argmax) {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelRoiawarePool3dForward<float><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
out_y, out_z, (float *)pts_feature, (int *)pts_idx_of_voxels,
(float *)pooled_features, (int *)argmax);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelRoiawarePool3dForward<half><<<k_dim, k_type, queue>>>(
pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
out_y, out_z, (half *)pts_feature, (int *)pts_idx_of_voxels,
(half *)pooled_features, (int *)argmax);
}; break;
default: {
break;
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelRoiawareMaxPool3dBackward(
const int boxes_num, const int out_x, const int out_y, const int out_z,
const int channels, const int *argmax, const T *grad_out, T *grad_in) {
// params (int)argmax: (boxes_num, out_x, out_y, out_z, channels)
// params (T)grad_out: (boxes_num, out_x, out_y, out_z, channels)
// params (T)grad_in: (pts_num, channels)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int nram_channels_limit =
(MAX_NRAM_SIZE - sizeof(T) * 1) / (sizeof(T) + sizeof(int));
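// Per-channel NRAM cost is one int (argmax) plus one T (grad_out); a single
// extra T is set aside as the scratch slot that __bang_atomic_add uses below.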
int *nram_argmax_cur_loop = (int *)data_nram;
// nram_argmax_cur_loop [nram_channels_limit]
T *nram_grad_out_cur_loop =
(T *)((int *)nram_argmax_cur_loop + nram_channels_limit);
// nram_grad_out_cur_loop [nram_channels_limit]
T *nram_grad_in_cur_channel =
(T *)nram_grad_out_cur_loop + nram_channels_limit;
// nram_grad_in_cur_channel [1]
int channels_loop_times = channels / nram_channels_limit;
int rem_channels = channels % nram_channels_limit;
int voxels_num = boxes_num * out_x * out_y * out_z;
for (int voxel_index = taskId; voxel_index < voxels_num;
voxel_index += taskDim) {
const int *argmax_cur_voxel = argmax + voxel_index * channels;
const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
channels_loop_idx++) {
int actual_channels_num = (channels_loop_idx == channels_loop_times)
? rem_channels
: nram_channels_limit;
if (actual_channels_num == 0) {
break;
}
const int *argmax_cur_loop =
argmax_cur_voxel + nram_channels_limit * channels_loop_idx;
const T *grad_out_cur_loop =
grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
__memcpy((void *)nram_argmax_cur_loop, (void *)argmax_cur_loop,
actual_channels_num * sizeof(int), GDRAM2NRAM);
__memcpy((void *)nram_grad_out_cur_loop, (void *)grad_out_cur_loop,
actual_channels_num * sizeof(T), GDRAM2NRAM);
for (int channel_idx = 0; channel_idx < actual_channels_num;
channel_idx++) {
int *nram_argmax_cur_channel = nram_argmax_cur_loop + channel_idx;
T *nram_grad_out_cur_channel = nram_grad_out_cur_loop + channel_idx;
if (nram_argmax_cur_channel[0] == -1) {
continue;
}
T *grad_in_cur_channel =
grad_in + nram_argmax_cur_channel[0] * channels +
nram_channels_limit * channels_loop_idx + channel_idx;
__bang_atomic_add((T *)nram_grad_in_cur_channel,
(T *)grad_in_cur_channel,
(T *)(nram_grad_out_cur_channel), 1);
}
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelRoiawareAvgPool3dBackward(
const int boxes_num, const int out_x, const int out_y, const int out_z,
const int channels, const int max_pts_each_voxel,
const int *pts_idx_of_voxels, const T *grad_out, T *grad_in) {
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z, max_pts_each_voxel)
// params (T)grad_out: (boxes_num, out_x, out_y, out_z, channels)
// params (T)grad_in: (pts_num, channels)
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
int align_num = NFU_ALIGN_SIZE / sizeof(T);
int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num);
int nram_channels_limit = PAD_DOWN(
(MAX_NRAM_SIZE - align_max_pts_each_voxel * sizeof(int)) / 2 / sizeof(T),
align_num);
int *nram_pts_idx_cur_voxel = (int *)data_nram;
// nram_pts_idx_cur_voxel [align_max_pts_each_voxel]
T *nram_grad_out_cur_loop =
(T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel);
// nram_grad_out_cur_loop [nram_channels_limit]
T *nram_grad_in_cur_loop = (T *)nram_grad_out_cur_loop + nram_channels_limit;
// nram_grad_in_cur_loop [nram_channels_limit]
int channels_loop_times = channels / nram_channels_limit;
int rem_channels = channels % nram_channels_limit;
int voxels_num = boxes_num * out_x * out_y * out_z;
for (int voxel_index = taskId; voxel_index < voxels_num;
voxel_index += taskDim) {
const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
const int *pts_idx_cur_voxel =
pts_idx_of_voxels + voxel_index * max_pts_each_voxel;
__memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxel,
max_pts_each_voxel * sizeof(int), GDRAM2NRAM);
int total_pts_of_voxel = nram_pts_idx_cur_voxel[0];
if (total_pts_of_voxel <= 0) {
continue;
}
float cur_grad = 1.0 / ((float)total_pts_of_voxel);
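// average pooling backward: each of the total_pts_of_voxel contributing
// points receives grad_out / total_pts_of_voxel per channel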
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
channels_loop_idx++) {
int actual_channels_num = (channels_loop_idx == channels_loop_times)
? rem_channels
: nram_channels_limit;
if (actual_channels_num == 0) {
break;
}
const T *grad_out_cur_loop =
grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
__memcpy((void *)nram_grad_in_cur_loop, (void *)grad_out_cur_loop,
actual_channels_num * sizeof(T), GDRAM2NRAM);
int align_actual_channels_num = PAD_UP(actual_channels_num, align_num);
if (sizeof(T) == sizeof(half)) {
__bang_half2float((float *)nram_grad_out_cur_loop,
(half *)nram_grad_in_cur_loop,
align_actual_channels_num);
__bang_mul_scalar((float *)nram_grad_out_cur_loop,
(float *)nram_grad_out_cur_loop, (float)cur_grad,
align_actual_channels_num);
convertFloat2half((half *)nram_grad_out_cur_loop,
(float *)nram_grad_out_cur_loop,
align_actual_channels_num);
} else {
__bang_mul_scalar((float *)nram_grad_out_cur_loop,
(float *)nram_grad_in_cur_loop, (float)cur_grad,
align_actual_channels_num);
}
for (int k = 1; k <= total_pts_of_voxel; k++) {
T *grad_in_cur_loop = grad_in + nram_pts_idx_cur_voxel[k] * channels +
nram_channels_limit * channels_loop_idx;
__bang_atomic_add((T *)nram_grad_in_cur_loop, (T *)grad_in_cur_loop,
(T *)nram_grad_out_cur_loop, actual_channels_num);
}
}
}
}
void KernelRoiawarePool3dBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
const int out_x, const int out_y, const int out_z, const int channels,
const int max_pts_each_voxel, const int *pts_idx_of_voxels,
const int *argmax, const void *grad_out, void *grad_in) {
if (pool_method == 0) {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelRoiawareMaxPool3dBackward<float>
<<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z, channels,
(int *)argmax, (float *)grad_out,
(float *)grad_in);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelRoiawareMaxPool3dBackward<half>
<<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z, channels,
(int *)argmax, (half *)grad_out,
(half *)grad_in);
}; break;
default: {
break;
}
}
} else {
switch (d_type) {
case CNRT_FLOAT32: {
MLUUnion1KernelRoiawareAvgPool3dBackward<float>
<<<k_dim, k_type, queue>>>(
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
(int *)pts_idx_of_voxels, (float *)grad_out, (float *)grad_in);
}; break;
case CNRT_FLOAT16: {
MLUUnion1KernelRoiawareAvgPool3dBackward<half>
<<<k_dim, k_type, queue>>>(
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
(int *)pts_idx_of_voxels, (half *)grad_out, (half *)grad_in);
}; break;
default: {
break;
}
}
}
}


@@ -9,49 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const int pool_method, const int boxes_num,
const int pts_num, const int max_pts_each_voxel,
const int out_x, const int out_y, const int out_z,
const void *rois, const void *pts,
int *pts_idx_of_voxels);
void KernelRoiawarePool3dForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
const int pts_num, const int channels, const int max_pts_each_voxel,
const int out_x, const int out_y, const int out_z, const void *pts_feature,
const int *pts_idx_of_voxels, void *pooled_features, int *argmax);
// policy function
static void kernelPtsIdxOfVoxelsPolicyFunc(const int boxes_num,
cnrtDim3_t *k_dim,
cnrtFunctionType_t *k_type) {
unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = core_num;
unsigned int use_cluster = (boxes_num + core_num - 1) / core_num;
k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
k_dim->z = 1;
}
static void kernelRoiawarePool3dForwardPolicyFunc(
const int boxes_num, const int out_x, const int out_y, const int out_z,
cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = core_num;
const int voxels_num = boxes_num * out_x * out_y * out_z;
unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
k_dim->z = 1;
}
#include "mlu_common_helper.h"
void RoiawarePool3dForwardMLUKernelLauncher(
const int pool_method, const int boxes_num, const int pts_num,
@@ -59,168 +17,65 @@ void RoiawarePool3dForwardMLUKernelLauncher(
const int out_y, const int out_z, const Tensor rois, const Tensor pts,
const Tensor pts_feature, Tensor pts_idx_of_voxels, Tensor pooled_features,
Tensor argmax) {
// check datatype
TORCH_CHECK(((pts.scalar_type() == rois.scalar_type()) &&
(pts_feature.scalar_type() == rois.scalar_type()) &&
(pooled_features.scalar_type() == rois.scalar_type())),
"data types of rois, rois, pts_feature and pooled_features "
"should be the same, ",
"but now rois type is ", rois.scalar_type(), ", pts type is ",
pts.scalar_type(), ", pts_feature type is ",
pts_feature.scalar_type(), ", pooled_features type is ",
pooled_features.scalar_type(), ".");
TORCH_CHECK(
(rois.scalar_type() == at::kFloat || rois.scalar_type() == at::kHalf),
"rois type should be Float or Half, got ", rois.scalar_type(), ".");
TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
"pts_idx_of_voxels type should be Int, got ",
pts_idx_of_voxels.scalar_type(), ".");
// check dim
TORCH_CHECK(rois.dim() == 2, "rois should be a 2D tensor, got ", rois.dim(),
"D.");
TORCH_CHECK(pts.dim() == 2, "pts should be a 2D tensor, got ", pts.dim(),
"D.");
TORCH_CHECK(pts_feature.dim() == 2, "pts_feature should be a 2D tensor, got ",
pts_feature.dim(), "D.");
TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
"pts_idx_of_voxels should be a 5D tensor, got ",
pts_idx_of_voxels.dim(), "D.");
TORCH_CHECK(pooled_features.dim() == 5,
"pooled_features should be a 5D tensor, got ",
pooled_features.dim(), "D.");
// check shape
TORCH_CHECK(((rois.size(0) == boxes_num) && (rois.size(1) == 7)),
"the dimensions of rois should be (boxes_num, 7), ", "but got (",
rois.size(0), ", ", rois.size(1), ").");
TORCH_CHECK(((pts.size(0) == pts_num) && (pts.size(1) == 3)),
"the dimensions of pts should be (pts_num, 3), ", "but got (",
pts.size(0), ",", pts.size(1), ").");
TORCH_CHECK(
((pts_feature.size(0) == pts_num) && (pts_feature.size(1) == channels)),
"the dimensions of pts_feature should be (pts_num, channels), ",
"but got (", pts_feature.size(0), ",", pts_feature.size(1), ").");
TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
(pts_idx_of_voxels.size(1) == out_x) &&
(pts_idx_of_voxels.size(2) == out_y) &&
(pts_idx_of_voxels.size(3) == out_z) &&
(pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
"the dimensions of pts_idx_of_voxels should be (boxes_num, "
"out_x, out_y, out_z, max_pts_each_voxel), ",
"but got (", pts_idx_of_voxels.size(0), ",",
pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",",
pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ").");
TORCH_CHECK(((pooled_features.size(0) == boxes_num) &&
(pooled_features.size(1) == out_x) &&
(pooled_features.size(2) == out_y) &&
(pooled_features.size(3) == out_z) &&
(pooled_features.size(4) == channels)),
"the dimensions of pooled_features should be (boxes_num, out_x, "
"out_y, out_z, channels), ",
"but got (", pooled_features.size(0), ",",
pooled_features.size(1), ",", pooled_features.size(2), ",",
pooled_features.size(3), ",", pooled_features.size(4), ").");
// check other params: pool_method
TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
"pool_method should be 0 (max) or 1 (avg), ", "but got ",
pool_method, ".");
// check large tensor
const size_t max_input_size = 2147483648;
TORCH_CHECK(rois.numel() < max_input_size,
"rois element num should be less than 2^31, got ", rois.numel(),
".");
TORCH_CHECK(pts.numel() < max_input_size,
"pts element num should be less than 2^31, got ", pts.numel(),
".");
TORCH_CHECK(pts_feature.numel() < max_input_size,
"pts_feature element num should be less than 2^31, got ",
pts_feature.numel(), ".");
TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
"pts_idx_of_voxels element num should be less than 2^31, got ",
pts_idx_of_voxels.numel(), ".");
TORCH_CHECK(pooled_features.numel() < max_input_size,
"pooled_features element num should be less than 2^31, got ",
pooled_features.numel(), ".");
// check zero element
TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
rois.numel());
TORCH_CHECK(pts.numel() != 0, "pts.numel() should not be zero, got ",
pts.numel());
TORCH_CHECK(pts_feature.numel() != 0,
"pts_feature.numel() should not be zero, got ",
pts_feature.numel());
TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
"pts_idx_of_voxels.numel() should not be zero, got ",
pts_idx_of_voxels.numel());
TORCH_CHECK(pooled_features.numel() != 0,
"pooled_features.numel() should not be zero, got ",
pooled_features.numel());
if (pool_method == 0) {
// check datatype
TORCH_CHECK((argmax.scalar_type() == at::kInt),
"argmax type should be Int, got ", argmax.scalar_type(), ".");
// check dim
TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
argmax.dim(), "D.");
// check shape
TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
(argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
(argmax.size(4) == channels)),
"the dimensions of argmax should be (boxes_num, out_x, out_y, "
"out_z, channels), ",
"but got (", argmax.size(0), ",", argmax.size(1), ",",
argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ").");
// check large tensor
TORCH_CHECK(argmax.numel() < max_input_size,
"argmax element num should be less than 2^31, got ",
argmax.numel(), ".");
// check zero element
TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ",
argmax.numel());
// when pool_method is 0 (max pool), initialize argmax values to -1
argmax.fill_(static_cast<int>(-1));
}
// calculate task dimensions
cnrtDim3_t k1_dim;
cnrtFunctionType_t k1_type;
kernelPtsIdxOfVoxelsPolicyFunc(boxes_num, &k1_dim, &k1_type);
cnrtDim3_t k2_dim;
cnrtFunctionType_t k2_type;
kernelRoiawarePool3dForwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k2_dim,
&k2_type);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get ptr of tensors
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
// get compute handle
auto handle = mluOpGetCurrentHandle();
auto rois_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
auto pts_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(pts, pts.suggest_memory_format());
auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
pts_feature, pts_feature.suggest_memory_format());
auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
argmax, argmax.suggest_memory_format());
auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
auto pooled_features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
pooled_features, pooled_features.suggest_memory_format());
MluOpTensorDescriptor rois_desc, pts_desc, pts_feature_desc, argmax_desc,
pts_idx_of_voxels_desc, pooled_features_desc;
rois_desc.set(rois_contiguous);
pts_desc.set(pts_contiguous);
pts_feature_desc.set(pts_feature_contiguous);
argmax_desc.set(argmax_contiguous);
pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
pooled_features_desc.set(pooled_features_contiguous);
// allocate extra space for workspace
size_t workspace_size = 0;
mluOpGetRoiawarePool3dForwardWorkspaceSize(
handle, rois_desc.desc(), pts_desc.desc(), pts_feature_desc.desc(),
&workspace_size);
auto workspace = at::empty(workspace_size, rois.options().dtype(at::kByte));
auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
auto workspace_ptr = workspace_impl->cnnlMalloc();
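// mlu-ops workspace pattern: the query above reports how many scratch bytes
// the op needs for these shapes, a throwaway byte tensor backs them, and
// pointer plus size are handed to the op call below.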
auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
auto pts_impl = torch_mlu::getMluTensorImpl(pts_contiguous);
auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_contiguous);
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
auto pts_idx_of_voxels_impl =
torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
auto pooled_features_impl =
torch_mlu::getMluTensorImpl(pooled_features_contiguous);
auto rois_ptr = rois_impl->cnnlMalloc();
// transpose points [pts_num, 3] -> [3, pts_num]
auto pts_ = pts.permute({1, 0}).contiguous();
auto pts_impl = torch_mlu::getMluTensorImpl(pts_);
auto pts_ptr = pts_impl->cnnlMalloc();
// transpose points_features [pts_num, channels] -> [channels, pts_num]
auto pts_feature_ = pts_feature.permute({1, 0}).contiguous();
auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_);
auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features);
auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
auto argmax_ptr = argmax_impl->cnnlMalloc();
// get compute dtype of input
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(rois.dtype());
// launch kernel PtsIdxOfVoxels
CNLOG(INFO) << "Launch Kernel MLUKernel PtsIdxOfVoxels<<<" << k1_dim.x << ", "
<< k1_dim.y << ", " << k1_dim.z << ">>>";
KernelPtsIdxOfVoxels(k1_dim, k1_type, queue, data_type, pool_method,
boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
out_z, rois_ptr, pts_ptr, (int *)pts_idx_of_voxels_ptr);
// launch kernel RoiawarePool3dForward
CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dForward<<<" << k2_dim.x
<< ", " << k2_dim.y << ", " << k2_dim.z << ">>>";
KernelRoiawarePool3dForward(
k2_dim, k2_type, queue, data_type, pool_method, boxes_num, pts_num,
channels, max_pts_each_voxel, out_x, out_y, out_z, pts_feature_ptr,
(int *)pts_idx_of_voxels_ptr, pooled_features_ptr, (int *)argmax_ptr);
auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
CNLOG(INFO) << "Call mluOpRoiawarePool3dForward().";
mluOpRoiawarePool3dForward(
handle, pool_method, boxes_num, pts_num, channels, rois_desc.desc(),
rois_ptr, pts_desc.desc(), pts_ptr, pts_feature_desc.desc(),
pts_feature_ptr, workspace_ptr, workspace_size, max_pts_each_voxel, out_x,
out_y, out_z, argmax_desc.desc(), argmax_ptr,
pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
pooled_features_desc.desc(), pooled_features_ptr);
}
void roiaware_pool3d_forward_mlu(int boxes_num, int pts_num, int channels,
@@ -245,136 +100,46 @@ void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels,
REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MLU,
roiaware_pool3d_forward_mlu);
void KernelRoiawarePool3dBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
const int out_x, const int out_y, const int out_z, const int channels,
const int max_pts_each_voxel, const int *pts_idx_of_voxels,
const int *argmax, const void *grad_out, void *grad_in);
static void kernelRoiawarePool3dBackwardPolicyFunc(
const int boxes_num, const int out_x, const int out_y, const int out_z,
cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = core_num;
const int voxels_num = boxes_num * out_x * out_y * out_z;
unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
k_dim->z = 1;
}
void RoiawarePool3dBackwardMLUKernelLauncher(
int pool_method, int boxes_num, int out_x, int out_y, int out_z,
int channels, int max_pts_each_voxel, const Tensor pts_idx_of_voxels,
const Tensor argmax, const Tensor grad_out, Tensor grad_in) {
// check datatype
TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
"pts_idx_of_voxels type should be Int, got ",
pts_idx_of_voxels.scalar_type(), ".");
TORCH_CHECK((argmax.scalar_type() == at::kInt),
"argmax type should be Int, got ", argmax.scalar_type(), ".");
TORCH_CHECK((grad_out.scalar_type() == at::kFloat ||
grad_out.scalar_type() == at::kHalf),
"grad_out type should be Float or Half, got ",
grad_out.scalar_type(), ".");
TORCH_CHECK((grad_out.scalar_type() == grad_in.scalar_type()),
"data types of grad_out, grad_in, should be the same, ",
"but now grad_out type is ", grad_out.scalar_type(),
", grad_in type is ", grad_in.scalar_type(), ".");
// check dim
TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
"pts_idx_of_voxels should be a 5D tensor, got ",
pts_idx_of_voxels.dim(), "D.");
TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
argmax.dim(), "D.");
TORCH_CHECK(grad_out.dim() == 5, "grad_out should be a 5D tensor, got ",
grad_out.dim(), "D.");
TORCH_CHECK(grad_in.dim() == 2, "grad_in should be a 2D tensor, got ",
grad_in.dim(), "D.");
// check shape
TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
(pts_idx_of_voxels.size(1) == out_x) &&
(pts_idx_of_voxels.size(2) == out_y) &&
(pts_idx_of_voxels.size(3) == out_z) &&
(pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
"the dimensions of pts_idx_of_voxels should be (boxes_num, "
"out_x, out_y, out_z, max_pts_each_voxel), ",
"but got (", pts_idx_of_voxels.size(0), ",",
pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",",
pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ").");
TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
(argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
(argmax.size(4) == channels)),
"the dimensions of argmax should be (boxes_num, out_x, out_y, "
"out_z, channels), ",
"but got (", argmax.size(0), ",", argmax.size(1), ",",
argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ").");
TORCH_CHECK(((grad_out.size(0) == boxes_num) && (grad_out.size(1) == out_x) &&
(grad_out.size(2) == out_y) && (grad_out.size(3) == out_z) &&
(grad_out.size(4) == channels)),
"the dimensions of grad_out should be (boxes_num, out_x, "
"out_y, out_z, channels), ",
"but got (", grad_out.size(0), ",", grad_out.size(1), ",",
grad_out.size(2), ",", grad_out.size(3), ",", grad_out.size(4),
").");
TORCH_CHECK((grad_in.size(1) == channels),
"dim 1 of grad_in should be channels, ", "but got ",
grad_in.size(1), ".");
// check other params: pool_method
TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
"pool_method should be 0 (max) or 1 (avg), ", "but got ",
pool_method, ".");
// check large tensor
const size_t max_input_size = 2147483648;
TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
"pts_idx_of_voxels element num should be less than 2^31, got ",
pts_idx_of_voxels.numel(), ".");
TORCH_CHECK(argmax.numel() < max_input_size,
"argmax element num should be less than 2^31, got ",
argmax.numel(), ".");
TORCH_CHECK(grad_out.numel() < max_input_size,
"grad_out element num should be less than 2^31, got ",
grad_out.numel(), ".");
TORCH_CHECK(grad_in.numel() < max_input_size,
"grad_in element num should be less than 2^31, got ",
grad_in.numel(), ".");
// check zero element
TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
"pts_idx_of_voxels.numel() should not be zero, got ",
pts_idx_of_voxels.numel());
TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ",
argmax.numel());
TORCH_CHECK(grad_out.numel() != 0,
"grad_out.numel() should not be zero, got ", grad_out.numel());
TORCH_CHECK(grad_in.numel() != 0, "grad_in.numel() should not be zero, got ",
grad_in.numel());
// calculate task dimensions
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
kernelRoiawarePool3dBackwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k_dim,
&k_type);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get tensor impl of pts_idx_of_voxels
auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
// get compute handle
auto handle = mluOpGetCurrentHandle();
auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
argmax, argmax.suggest_memory_format());
auto grad_out_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_out, grad_out.suggest_memory_format());
auto grad_in_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_in, grad_in.suggest_memory_format());
MluOpTensorDescriptor pts_idx_of_voxels_desc, argmax_desc, grad_out_desc,
grad_in_desc;
pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
argmax_desc.set(argmax_contiguous);
grad_out_desc.set(grad_out_contiguous);
grad_in_desc.set(grad_in_contiguous);
auto pts_idx_of_voxels_impl =
torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out_contiguous);
auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in_contiguous);
auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
auto argmax_ptr = argmax_impl->cnnlMalloc();
auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out);
auto grad_out_ptr = grad_out_impl->cnnlMalloc();
auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in);
auto grad_in_ptr = grad_in_impl->cnnlMalloc();
// get compute dtype of input
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(grad_out.dtype());
// launch kernel RoiawarePool3dBackward
CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dBackward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelRoiawarePool3dBackward(k_dim, k_type, queue, data_type, pool_method,
boxes_num, out_x, out_y, out_z, channels,
max_pts_each_voxel, (int *)pts_idx_of_voxels_ptr,
(int *)argmax_ptr, grad_out_ptr, grad_in_ptr);
CNLOG(INFO) << "Call mluOpRoiawarePool3dBackward().";
mluOpRoiawarePool3dBackward(
handle, pool_method, boxes_num, out_x, out_y, out_z, channels,
max_pts_each_voxel, pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
argmax_desc.desc(), argmax_ptr, grad_out_desc.desc(), grad_out_ptr,
grad_in_desc.desc(), grad_in_ptr);
}
void roiaware_pool3d_backward_mlu(int boxes_num, int out_x, int out_y,