mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Replace the implementation of roiaware_pool3d with mlu-ops. (#2699)
* [Feature] Replace the implementation of roiaware_pool3d with mlu-ops. * [Feature] Replace the implementation of roiaware_pool3d with mlu-ops. (pull/2731/head)
parent
a55f4b7f40
commit
bc727f7132
|
@ -1,747 +0,0 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) 2022 Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "common_mlu_helper.hpp"
|
||||
|
||||
#define ROI_OFFSET 7
|
||||
#define FLOAT_NRAM_BUFFER_NUM 14
|
||||
#define HALF_NRAM_BUFFER_NUM 25
|
||||
#define ALIGN_NUM 64
|
||||
|
||||
__nram__ char data_nram[MAX_NRAM_SIZE];
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUUnion1KernelPtsIdxOfVoxels(
|
||||
const int pool_method, const int boxes_num, const int pts_num,
|
||||
const int max_pts_each_voxel, const int out_x, const int out_y,
|
||||
const int out_z, const T *rois, const T *pts, int *pts_idx_of_voxels) {
|
||||
// params (T)rois: (boxes_num, 7)
|
||||
// params (T)pts: (3, pts_num)
|
||||
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
|
||||
// max_pts_each_voxel)
|
||||
|
||||
// make sure that memcore is not used
|
||||
if (coreId == 0x80) {
|
||||
return;
|
||||
}
|
||||
int nram_pts_num = 0;
|
||||
if (sizeof(T) == sizeof(float)) {
|
||||
nram_pts_num = PAD_DOWN(
|
||||
(MAX_NRAM_SIZE / sizeof(float) / FLOAT_NRAM_BUFFER_NUM), ALIGN_NUM);
|
||||
} else {
|
||||
nram_pts_num = PAD_DOWN(
|
||||
(MAX_NRAM_SIZE / sizeof(half) / HALF_NRAM_BUFFER_NUM), ALIGN_NUM);
|
||||
}
|
||||
|
||||
char *X = NULL;
|
||||
char *Y = NULL;
|
||||
char *Z = NULL;
|
||||
char *local_X = NULL;
|
||||
char *local_Y = NULL;
|
||||
char *local_Z = NULL;
|
||||
char *nram_pts_in_flag = NULL;
|
||||
float *temp_buffer1 = NULL;
|
||||
float *temp_buffer2 = NULL;
|
||||
float *temp_buffer3 = NULL;
|
||||
float *temp_buffer4 = NULL;
|
||||
float *temp_buffer5 = NULL;
|
||||
float *nram_voxel_offset = NULL;
|
||||
int *nram_pts_idx_seq = NULL;
|
||||
float *fp_local_X = NULL;
|
||||
float *fp_local_Y = NULL;
|
||||
float *fp_local_Z = NULL;
|
||||
float *fp_nram_pts_in_flag = NULL;
|
||||
if (sizeof(T) == sizeof(float)) {
|
||||
X = (char *)((float *)data_nram);
|
||||
Y = (char *)((float *)data_nram + nram_pts_num);
|
||||
Z = (char *)((float *)data_nram + nram_pts_num * 2);
|
||||
local_X = (char *)((float *)data_nram + nram_pts_num * 3);
|
||||
local_Y = (char *)((float *)data_nram + nram_pts_num * 4);
|
||||
local_Z = (char *)((float *)data_nram + nram_pts_num * 5);
|
||||
nram_pts_in_flag = (char *)((float *)data_nram + nram_pts_num * 6);
|
||||
temp_buffer1 = (float *)data_nram + nram_pts_num * 7;
|
||||
temp_buffer2 = (float *)data_nram + nram_pts_num * 8;
|
||||
temp_buffer3 = (float *)data_nram + nram_pts_num * 9;
|
||||
temp_buffer4 = (float *)data_nram + nram_pts_num * 10;
|
||||
temp_buffer5 = (float *)data_nram + nram_pts_num * 11;
|
||||
nram_voxel_offset = (float *)data_nram + nram_pts_num * 12;
|
||||
nram_pts_idx_seq = (int *)((float *)data_nram + nram_pts_num * 13);
|
||||
fp_local_X = (float *)local_X;
|
||||
fp_local_Y = (float *)local_Y;
|
||||
fp_local_Z = (float *)local_Z;
|
||||
fp_nram_pts_in_flag = (float *)nram_pts_in_flag;
|
||||
} else {
|
||||
X = (char *)((half *)data_nram);
|
||||
Y = (char *)((half *)data_nram + nram_pts_num);
|
||||
Z = (char *)((half *)data_nram + nram_pts_num * 2);
|
||||
local_X = (char *)((half *)data_nram + nram_pts_num * 4);
|
||||
local_Y = (char *)((half *)data_nram + nram_pts_num * 6);
|
||||
local_Z = (char *)((half *)data_nram + nram_pts_num * 8);
|
||||
nram_pts_in_flag = (char *)((half *)data_nram + nram_pts_num * 10);
|
||||
temp_buffer1 = (float *)((half *)data_nram + nram_pts_num * 11);
|
||||
temp_buffer2 = (float *)((half *)data_nram + nram_pts_num * 13);
|
||||
temp_buffer3 = (float *)((half *)data_nram + nram_pts_num * 15);
|
||||
temp_buffer4 = (float *)((half *)data_nram + nram_pts_num * 17);
|
||||
temp_buffer5 = (float *)((half *)data_nram + nram_pts_num * 19);
|
||||
nram_voxel_offset = (float *)((half *)data_nram + nram_pts_num * 21);
|
||||
nram_pts_idx_seq = (int *)((half *)data_nram + nram_pts_num * 23);
|
||||
fp_local_X = (float *)((half *)local_X - nram_pts_num);
|
||||
fp_local_Y = (float *)((half *)local_Y - nram_pts_num);
|
||||
fp_local_Z = (float *)((half *)local_Z - nram_pts_num);
|
||||
fp_nram_pts_in_flag = (float *)((half *)nram_pts_in_flag - nram_pts_num);
|
||||
}
|
||||
|
||||
for (int i = 0; i < nram_pts_num; i++) {
|
||||
nram_pts_idx_seq[i] = i;
|
||||
}
|
||||
|
||||
int nram_pts_loop_times = pts_num / nram_pts_num;
|
||||
int rem_nram_num = pts_num % nram_pts_num;
|
||||
|
||||
for (int roi_index = taskId; roi_index < boxes_num; roi_index += taskDim) {
|
||||
const T *cur_roi = rois + roi_index * ROI_OFFSET;
|
||||
T cx = cur_roi[0];
|
||||
T cy = cur_roi[1];
|
||||
T cz = cur_roi[2];
|
||||
T dx = cur_roi[3];
|
||||
T dy = cur_roi[4];
|
||||
T dz = cur_roi[5];
|
||||
T rz = cur_roi[6];
|
||||
|
||||
T dx_2 = dx / 2.0;
|
||||
T dy_2 = dy / 2.0;
|
||||
T dz_2 = dz / 2.0;
|
||||
|
||||
for (int loop_idx = 0; loop_idx <= nram_pts_loop_times; loop_idx++) {
|
||||
int load_pts_num =
|
||||
(loop_idx == nram_pts_loop_times) ? rem_nram_num : nram_pts_num;
|
||||
if (load_pts_num == 0) {
|
||||
break;
|
||||
}
|
||||
int pts_offset_cur_loop = nram_pts_num * loop_idx;
|
||||
int compute_pts_num = (loop_idx == nram_pts_loop_times)
|
||||
? PAD_UP(rem_nram_num, ALIGN_NUM)
|
||||
: nram_pts_num;
|
||||
// load pts
|
||||
__memcpy((void *)X, (T *)pts + pts_offset_cur_loop,
|
||||
load_pts_num * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy((void *)Y, (T *)pts + pts_num + pts_offset_cur_loop,
|
||||
load_pts_num * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy((void *)Z, (T *)pts + pts_num * 2 + pts_offset_cur_loop,
|
||||
load_pts_num * sizeof(T), GDRAM2NRAM);
|
||||
// fabs(local_z)
|
||||
__bang_sub_scalar((T *)local_Z, (T *)Z, (T)cz, compute_pts_num);
|
||||
__bang_sub_scalar((T *)temp_buffer1, (T *)Z, (T)(cz + dz_2),
|
||||
compute_pts_num);
|
||||
__bang_active_abs((T *)temp_buffer1, (T *)temp_buffer1, compute_pts_num);
|
||||
#if __BANG_ARCH__ >= 322
|
||||
__bang_le_scalar((T *)nram_pts_in_flag, (T *)temp_buffer1, (T)(dz_2),
|
||||
compute_pts_num);
|
||||
#else
|
||||
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dz_2));
|
||||
__bang_le((T *)nram_pts_in_flag, (T *)temp_buffer1, (T *)temp_buffer2,
|
||||
compute_pts_num);
|
||||
#endif
|
||||
T cosa = std::cos(-rz);
|
||||
T sina = std::sin(-rz);
|
||||
__bang_sub_scalar((T *)temp_buffer3, (T *)X, (T)cx, compute_pts_num);
|
||||
__bang_sub_scalar((T *)temp_buffer4, (T *)Y, (T)cy, compute_pts_num);
|
||||
__bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)cosa,
|
||||
compute_pts_num);
|
||||
__bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)sina,
|
||||
compute_pts_num);
|
||||
// local_x
|
||||
__bang_sub((T *)local_X, (T *)temp_buffer1, (T *)temp_buffer2,
|
||||
compute_pts_num);
|
||||
// fabs(local_x)
|
||||
__bang_active_abs((T *)temp_buffer1, (T *)local_X, compute_pts_num);
|
||||
// fabs(local_x) < dx/2 ? 1 : 0
|
||||
#if __BANG_ARCH__ >= 322
|
||||
__bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dx_2),
|
||||
compute_pts_num);
|
||||
#else
|
||||
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dx_2));
|
||||
__bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2,
|
||||
compute_pts_num);
|
||||
#endif
|
||||
__bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag,
|
||||
(T *)temp_buffer1,
|
||||
compute_pts_num); // flush res
|
||||
|
||||
__bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)sina,
|
||||
compute_pts_num);
|
||||
__bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)cosa,
|
||||
compute_pts_num);
|
||||
// local_y
|
||||
__bang_add((T *)local_Y, (T *)temp_buffer1, (T *)temp_buffer2,
|
||||
compute_pts_num);
|
||||
// fabs(local_y)
|
||||
__bang_active_abs((T *)temp_buffer1, (T *)local_Y, compute_pts_num);
|
||||
// fabs(local_y) < dy/2 ? 1 : 0
|
||||
#if __BANG_ARCH__ >= 322
|
||||
__bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dy_2),
|
||||
compute_pts_num);
|
||||
#else
|
||||
__bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dy_2));
|
||||
__bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2,
|
||||
compute_pts_num);
|
||||
#endif
|
||||
__bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag,
|
||||
(T *)temp_buffer1,
|
||||
compute_pts_num); // flush res
|
||||
T x_res = dx / out_x;
|
||||
T y_res = dy / out_y;
|
||||
T z_res = dz / out_z;
|
||||
__bang_add_scalar((T *)local_X, (T *)local_X, (T)(dx_2), compute_pts_num);
|
||||
__bang_add_scalar((T *)local_Y, (T *)local_Y, (T)(dy_2), compute_pts_num);
|
||||
// local_Z do not need to add dz/2.0
|
||||
|
||||
#if (__BANG_ARCH__ >= 322) && (__BANG_ARCH__ != 372)
|
||||
__bang_div((T *)local_X, (T *)local_X, (T)x_res, compute_pts_num);
|
||||
__bang_div((T *)local_Y, (T *)local_Y, (T)y_res, compute_pts_num);
|
||||
__bang_div((T *)local_Z, (T *)local_Z, (T)z_res, compute_pts_num);
|
||||
#else
|
||||
__bang_mul_scalar((T *)local_X, (T *)local_X, (T)(1 / x_res),
|
||||
compute_pts_num);
|
||||
__bang_mul_scalar((T *)local_Y, (T *)local_Y, (T)(1 / y_res),
|
||||
compute_pts_num);
|
||||
__bang_mul_scalar((T *)local_Z, (T *)local_Z, (T)(1 / z_res),
|
||||
compute_pts_num);
|
||||
#endif
|
||||
// float = float2int + int2float, half = half2int + int2float
|
||||
if (sizeof(T) == sizeof(float)) {
|
||||
#if __BANG_ARCH__ >= 322
|
||||
__bang_float2int32_tz((int *)temp_buffer1, (float *)local_X,
|
||||
compute_pts_num, 0);
|
||||
__bang_float2int32_tz((int *)temp_buffer2, (float *)local_Y,
|
||||
compute_pts_num, 0);
|
||||
__bang_float2int32_tz((int *)temp_buffer3, (float *)local_Z,
|
||||
compute_pts_num, 0);
|
||||
__bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1,
|
||||
compute_pts_num, 0);
|
||||
__bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2,
|
||||
compute_pts_num, 0);
|
||||
__bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3,
|
||||
compute_pts_num, 0);
|
||||
#else
|
||||
convertFloat2Int((int *)temp_buffer1, (float *)temp_buffer2,
|
||||
(float *)fp_local_X, (float *)temp_buffer3,
|
||||
compute_pts_num);
|
||||
convertFloat2Int((int *)temp_buffer2, (float *)temp_buffer3,
|
||||
(float *)fp_local_Y, (float *)temp_buffer4,
|
||||
compute_pts_num);
|
||||
convertFloat2Int((int *)temp_buffer3, (float *)temp_buffer4,
|
||||
(float *)fp_local_Z, (float *)temp_buffer5,
|
||||
compute_pts_num);
|
||||
convertInt2Float((float *)fp_local_X, (float *)temp_buffer4,
|
||||
(int *)temp_buffer1, (float *)temp_buffer5,
|
||||
compute_pts_num);
|
||||
convertInt2Float((float *)fp_local_Y, (float *)temp_buffer4,
|
||||
(int *)temp_buffer2, (float *)temp_buffer5,
|
||||
compute_pts_num);
|
||||
convertInt2Float((float *)fp_local_Z, (float *)temp_buffer4,
|
||||
(int *)temp_buffer3, (float *)temp_buffer5,
|
||||
compute_pts_num);
|
||||
#endif
|
||||
} else {
|
||||
__bang_half2float((float *)temp_buffer4, (half *)nram_pts_in_flag,
|
||||
compute_pts_num);
|
||||
__bang_move((void *)fp_nram_pts_in_flag, (void *)temp_buffer4,
|
||||
compute_pts_num * sizeof(float));
|
||||
#if __BANG_ARCH__ >= 322
|
||||
__bang_half2int32_tz((int *)temp_buffer1, (half *)local_X,
|
||||
compute_pts_num, 0);
|
||||
__bang_half2int32_tz((int *)temp_buffer2, (half *)local_Y,
|
||||
compute_pts_num, 0);
|
||||
__bang_half2int32_tz((int *)temp_buffer3, (half *)local_Z,
|
||||
compute_pts_num, 0);
|
||||
__bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1,
|
||||
compute_pts_num, 0);
|
||||
__bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2,
|
||||
compute_pts_num, 0);
|
||||
__bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3,
|
||||
compute_pts_num, 0);
|
||||
#else
|
||||
__bang_half2int16_tz((int16_t *)temp_buffer1, (half *)local_X,
|
||||
compute_pts_num, 0);
|
||||
__bang_half2int16_tz((int16_t *)temp_buffer2, (half *)local_Y,
|
||||
compute_pts_num, 0);
|
||||
__bang_half2int16_tz((int16_t *)temp_buffer3, (half *)local_Z,
|
||||
compute_pts_num, 0);
|
||||
__bang_int162float((float *)fp_local_X, (int16_t *)temp_buffer1,
|
||||
compute_pts_num, 0);
|
||||
__bang_int162float((float *)fp_local_Y, (int16_t *)temp_buffer2,
|
||||
compute_pts_num, 0);
|
||||
__bang_int162float((float *)fp_local_Z, (int16_t *)temp_buffer3,
|
||||
compute_pts_num, 0);
|
||||
#endif
|
||||
}
|
||||
// process index >= 0
|
||||
__bang_write_value((float *)temp_buffer4, compute_pts_num, (float)0.0f);
|
||||
__bang_maxequal((float *)fp_local_X, (float *)fp_local_X,
|
||||
(float *)temp_buffer4, compute_pts_num);
|
||||
__bang_maxequal((float *)fp_local_Y, (float *)fp_local_Y,
|
||||
(float *)temp_buffer4, compute_pts_num);
|
||||
__bang_maxequal((float *)fp_local_Z, (float *)fp_local_Z,
|
||||
(float *)temp_buffer4, compute_pts_num);
|
||||
// process index <= (out_x - 1)
|
||||
__bang_write_value((float *)temp_buffer5, compute_pts_num,
|
||||
(float)(out_x - 1));
|
||||
__bang_minequal((float *)fp_local_X, (float *)fp_local_X,
|
||||
(float *)temp_buffer5, compute_pts_num);
|
||||
__bang_write_value((float *)temp_buffer5, compute_pts_num,
|
||||
(float)(out_y - 1));
|
||||
__bang_minequal((float *)fp_local_Y, (float *)fp_local_Y,
|
||||
(float *)temp_buffer5, compute_pts_num);
|
||||
__bang_write_value((float *)temp_buffer5, compute_pts_num,
|
||||
(float)(out_z - 1));
|
||||
__bang_minequal((float *)fp_local_Z, (float *)fp_local_Z,
|
||||
(float *)temp_buffer5, compute_pts_num);
|
||||
__bang_mul_scalar((float *)temp_buffer1, (float *)fp_local_X,
|
||||
(float)(out_y * out_z), compute_pts_num);
|
||||
__bang_mul_scalar((float *)temp_buffer2, (float *)fp_local_Y,
|
||||
(float)out_z, compute_pts_num);
|
||||
__bang_mul_scalar((float *)temp_buffer3, (float *)fp_local_Z, (float)1.0,
|
||||
compute_pts_num);
|
||||
__bang_add((float *)nram_voxel_offset, (float *)temp_buffer1,
|
||||
(float *)temp_buffer2, compute_pts_num);
|
||||
__bang_add((float *)nram_voxel_offset, (float *)nram_voxel_offset,
|
||||
(float *)temp_buffer3, compute_pts_num);
|
||||
__bang_mul_scalar((float *)nram_voxel_offset, (float *)nram_voxel_offset,
|
||||
(float)max_pts_each_voxel, compute_pts_num);
|
||||
if (compute_pts_num != load_pts_num) {
|
||||
__memset_nram((float *)fp_nram_pts_in_flag + load_pts_num,
|
||||
compute_pts_num - load_pts_num, (float)0.0);
|
||||
}
|
||||
__bang_collect((float *)temp_buffer4, (float *)nram_pts_idx_seq,
|
||||
(float *)fp_nram_pts_in_flag, compute_pts_num);
|
||||
int pts_num_in_cur_roi =
|
||||
(int)__bang_count((float *)fp_nram_pts_in_flag, compute_pts_num);
|
||||
int *pts_idx_cur_voxels =
|
||||
(int *)pts_idx_of_voxels +
|
||||
roi_index * out_x * out_y * out_z * max_pts_each_voxel;
|
||||
for (int idx = 0; idx < pts_num_in_cur_roi; idx++) {
|
||||
int cur_pts_idx = *((int *)temp_buffer4 + idx);
|
||||
int offset = (int)(*((float *)nram_voxel_offset + cur_pts_idx));
|
||||
int cnt = pts_idx_cur_voxels[offset];
|
||||
if (cnt < max_pts_each_voxel - 1) {
|
||||
pts_idx_cur_voxels[offset + cnt + 1] =
|
||||
cur_pts_idx + loop_idx * nram_pts_num;
|
||||
pts_idx_cur_voxels[offset]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUUnion1KernelRoiawarePool3dForward(
|
||||
const int pool_method, const int boxes_num, const int pts_num,
|
||||
const int channels, const int max_pts_each_voxel, const int out_x,
|
||||
const int out_y, const int out_z, const T *pts_feature,
|
||||
const int *pts_idx_of_voxels, T *pooled_features, int *argmax) {
|
||||
// params (T)pts_feature: (channels, pts_num)
|
||||
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
|
||||
// max_pts_each_voxel) params (int)argmax: (boxes_num, out_x, out_y, out_z,
|
||||
// channels) params (T)pooled_features: (boxes_num, out_x, out_y, out_z,
|
||||
// channels)
|
||||
|
||||
// make sure that memcore is not used
|
||||
if (coreId == 0x80) {
|
||||
return;
|
||||
}
|
||||
int align_num = NFU_ALIGN_SIZE / sizeof(T);
|
||||
int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num);
|
||||
int nram_channels_limit =
|
||||
PAD_DOWN((MAX_NRAM_SIZE - 128 -
|
||||
align_max_pts_each_voxel * (sizeof(int) + sizeof(T))) /
|
||||
((align_max_pts_each_voxel + 1) * sizeof(T) + sizeof(int)),
|
||||
align_num);
|
||||
int *nram_pts_idx_cur_voxel = (int *)data_nram;
|
||||
// nram_pts_idx_cur_voxel [align_max_pts_each_voxel]
|
||||
T *nram_max_pts_feature_tmp =
|
||||
(T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel);
|
||||
// nram_max_pts_feature_tmp [align_max_pts_each_voxel]
|
||||
T *nram_pts_feature_in_voxel =
|
||||
((T *)nram_max_pts_feature_tmp + align_max_pts_each_voxel);
|
||||
// nram_pts_feature_in_voxel [nram_channels_limit, align_max_pts_each_voxel]
|
||||
T *nram_pooled_features_cur_voxel =
|
||||
((T *)nram_pts_feature_in_voxel +
|
||||
nram_channels_limit * align_max_pts_each_voxel);
|
||||
// nram_pooled_features_cur_voxel [nram_channels_limit]
|
||||
int *nram_argmax_cur_voxel =
|
||||
(int *)((T *)nram_pooled_features_cur_voxel + nram_channels_limit);
|
||||
// nram_argmax_cur_voxel [nram_channels_limit]
|
||||
char *one_pooled_feature =
|
||||
(char *)((int *)nram_argmax_cur_voxel + nram_channels_limit);
|
||||
// one_pooled_feature [128]
|
||||
int channels_loop_times = channels / nram_channels_limit;
|
||||
int rem_channels = channels % nram_channels_limit;
|
||||
for (int voxel_index = taskId;
|
||||
voxel_index < boxes_num * out_x * out_y * out_z;
|
||||
voxel_index += taskDim) {
|
||||
int *pts_idx_cur_voxels =
|
||||
(int *)pts_idx_of_voxels + voxel_index * max_pts_each_voxel;
|
||||
__memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxels,
|
||||
max_pts_each_voxel * sizeof(int), GDRAM2NRAM);
|
||||
int pts_num_cur_voxel = nram_pts_idx_cur_voxel[0];
|
||||
if (pts_num_cur_voxel == 0) {
|
||||
continue;
|
||||
}
|
||||
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
|
||||
channels_loop_idx++) {
|
||||
int actual_channels_num = (channels_loop_idx == channels_loop_times)
|
||||
? rem_channels
|
||||
: nram_channels_limit;
|
||||
if (actual_channels_num == 0) {
|
||||
break;
|
||||
}
|
||||
int channels_offset = nram_channels_limit * channels_loop_idx;
|
||||
|
||||
#if ((__BANG_ARCH__ >= 200) && (__BANG_ARCH__ < 300))
|
||||
int compute_channels_num = (channels_loop_idx == channels_loop_times)
|
||||
? PAD_UP(rem_channels, align_num)
|
||||
: nram_channels_limit;
|
||||
if (pool_method == 0) {
|
||||
__bang_write_value((void *)nram_pts_feature_in_voxel,
|
||||
compute_channels_num * align_max_pts_each_voxel,
|
||||
(T)-INFINITY);
|
||||
}
|
||||
#endif
|
||||
|
||||
T *pts_feature_cur_loop = (T *)pts_feature + channels_offset * pts_num;
|
||||
for (int idx = 0; idx < pts_num_cur_voxel; idx++) {
|
||||
__memcpy((T *)nram_pts_feature_in_voxel + idx,
|
||||
(T *)pts_feature_cur_loop + nram_pts_idx_cur_voxel[idx + 1],
|
||||
sizeof(T), GDRAM2NRAM, align_max_pts_each_voxel * sizeof(T),
|
||||
pts_num * sizeof(T), actual_channels_num - 1);
|
||||
}
|
||||
for (int channel_idx = 0; channel_idx < actual_channels_num;
|
||||
channel_idx++) {
|
||||
if (pool_method == 0) {
|
||||
#if __BANG_ARCH__ >= 322
|
||||
__bang_argmax((T *)one_pooled_feature,
|
||||
(T *)nram_pts_feature_in_voxel +
|
||||
channel_idx * align_max_pts_each_voxel,
|
||||
pts_num_cur_voxel);
|
||||
T max_val = ((T *)one_pooled_feature)[0];
|
||||
int max_idx = (int)(*(uint32_t *)((T *)one_pooled_feature + 1));
|
||||
nram_pooled_features_cur_voxel[channel_idx] =
|
||||
(max_val == -INFINITY) ? 0 : max_val;
|
||||
nram_argmax_cur_voxel[channel_idx] =
|
||||
(max_val == -INFINITY) ? -1 : nram_pts_idx_cur_voxel[max_idx + 1];
|
||||
#else
|
||||
// __bang_max need align num on mlu200 series
|
||||
if (sizeof(T) == sizeof(float)) {
|
||||
__bang_max((float *)one_pooled_feature,
|
||||
(float *)nram_pts_feature_in_voxel +
|
||||
channel_idx * align_max_pts_each_voxel,
|
||||
align_max_pts_each_voxel);
|
||||
float max_val = ((float *)one_pooled_feature)[0];
|
||||
__bang_write_value((void *)nram_max_pts_feature_tmp,
|
||||
align_max_pts_each_voxel, (float)max_val);
|
||||
__bang_eq((float *)nram_max_pts_feature_tmp,
|
||||
(float *)nram_pts_feature_in_voxel +
|
||||
channel_idx * align_max_pts_each_voxel,
|
||||
(float *)nram_max_pts_feature_tmp,
|
||||
align_max_pts_each_voxel);
|
||||
int max_idx = (int)__bang_findfirst1(
|
||||
(float *)nram_max_pts_feature_tmp, align_max_pts_each_voxel);
|
||||
nram_pooled_features_cur_voxel[channel_idx] =
|
||||
(max_val == -INFINITY) ? 0 : max_val;
|
||||
nram_argmax_cur_voxel[channel_idx] =
|
||||
(max_val == -INFINITY) ? -1
|
||||
: nram_pts_idx_cur_voxel[max_idx + 1];
|
||||
} else {
|
||||
int max_idx = -1;
|
||||
float max_val = -INFINITY;
|
||||
for (int k = 0; k < pts_num_cur_voxel; k++) {
|
||||
float pts_feature_cur_channel = __half2float_rd(
|
||||
*((half *)nram_pts_feature_in_voxel +
|
||||
channel_idx * align_max_pts_each_voxel + k));
|
||||
if (pts_feature_cur_channel > max_val) {
|
||||
max_val = pts_feature_cur_channel;
|
||||
max_idx = k;
|
||||
}
|
||||
}
|
||||
nram_pooled_features_cur_voxel[channel_idx] =
|
||||
(max_idx == -1) ? 0 : max_val;
|
||||
nram_argmax_cur_voxel[channel_idx] =
|
||||
(max_idx == -1) ? -1 : nram_pts_idx_cur_voxel[max_idx + 1];
|
||||
}
|
||||
#endif
|
||||
} else if (pool_method == 1) {
|
||||
float sum_val_cur_channel = 0;
|
||||
for (int k = 0; k < pts_num_cur_voxel; k++) {
|
||||
sum_val_cur_channel += static_cast<float>(
|
||||
((T *)nram_pts_feature_in_voxel)[channel_idx *
|
||||
align_max_pts_each_voxel +
|
||||
k]);
|
||||
}
|
||||
nram_pooled_features_cur_voxel[channel_idx] =
|
||||
(T)(sum_val_cur_channel / pts_num_cur_voxel);
|
||||
}
|
||||
}
|
||||
// store
|
||||
__memcpy((T *)pooled_features + voxel_index * channels + channels_offset,
|
||||
(void *)nram_pooled_features_cur_voxel,
|
||||
actual_channels_num * sizeof(T), NRAM2GDRAM);
|
||||
if (pool_method == 0) {
|
||||
__memcpy((int *)argmax + voxel_index * channels + channels_offset,
|
||||
(void *)nram_argmax_cur_voxel,
|
||||
actual_channels_num * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
|
||||
cnrtQueue_t queue, const cnrtDataType_t d_type,
|
||||
const int pool_method, const int boxes_num,
|
||||
const int pts_num, const int max_pts_each_voxel,
|
||||
const int out_x, const int out_y, const int out_z,
|
||||
const void *rois, const void *pts,
|
||||
int *pts_idx_of_voxels) {
|
||||
switch (d_type) {
|
||||
case CNRT_FLOAT32: {
|
||||
MLUUnion1KernelPtsIdxOfVoxels<float><<<k_dim, k_type, queue>>>(
|
||||
pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
|
||||
out_z, (float *)rois, (float *)pts, (int *)pts_idx_of_voxels);
|
||||
}; break;
|
||||
case CNRT_FLOAT16: {
|
||||
MLUUnion1KernelPtsIdxOfVoxels<half><<<k_dim, k_type, queue>>>(
|
||||
pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
|
||||
out_z, (half *)rois, (half *)pts, (int *)pts_idx_of_voxels);
|
||||
}; break;
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void KernelRoiawarePool3dForward(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
|
||||
const int pts_num, const int channels, const int max_pts_each_voxel,
|
||||
const int out_x, const int out_y, const int out_z, const void *pts_feature,
|
||||
const int *pts_idx_of_voxels, void *pooled_features, int *argmax) {
|
||||
switch (d_type) {
|
||||
case CNRT_FLOAT32: {
|
||||
MLUUnion1KernelRoiawarePool3dForward<float><<<k_dim, k_type, queue>>>(
|
||||
pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
|
||||
out_y, out_z, (float *)pts_feature, (int *)pts_idx_of_voxels,
|
||||
(float *)pooled_features, (int *)argmax);
|
||||
}; break;
|
||||
case CNRT_FLOAT16: {
|
||||
MLUUnion1KernelRoiawarePool3dForward<half><<<k_dim, k_type, queue>>>(
|
||||
pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
|
||||
out_y, out_z, (half *)pts_feature, (int *)pts_idx_of_voxels,
|
||||
(half *)pooled_features, (int *)argmax);
|
||||
}; break;
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUUnion1KernelRoiawareMaxPool3dBackward(
|
||||
const int boxes_num, const int out_x, const int out_y, const int out_z,
|
||||
const int channels, const int *argmax, const T *grad_out, T *grad_in) {
|
||||
// params (int)argmax: (boxes_num, out_x, out_y, out_z, channels)
|
||||
// params (T)grad_out: (boxes_num, out_x, out_y, out_z, channels)
|
||||
// params (T)grad_in: (pts_num, channels)
|
||||
|
||||
// make sure that memcore is not used
|
||||
if (coreId == 0x80) {
|
||||
return;
|
||||
}
|
||||
int nram_channels_limit =
|
||||
(MAX_NRAM_SIZE - sizeof(T) * 1) / (sizeof(T) + sizeof(int));
|
||||
int *nram_argmax_cur_loop = (int *)data_nram;
|
||||
// nram_argmax_cur_loop [nram_channels_limit]
|
||||
T *nram_grad_out_cur_loop =
|
||||
(T *)((int *)nram_argmax_cur_loop + nram_channels_limit);
|
||||
// nram_grad_out_cur_loop [nram_channels_limit]
|
||||
T *nram_grad_in_cur_channel =
|
||||
(T *)nram_grad_out_cur_loop + nram_channels_limit;
|
||||
// nram_grad_in_cur_channel [1]
|
||||
int channels_loop_times = channels / nram_channels_limit;
|
||||
int rem_channels = channels % nram_channels_limit;
|
||||
int voxels_num = boxes_num * out_x * out_y * out_z;
|
||||
|
||||
for (int voxel_index = taskId; voxel_index < voxels_num;
|
||||
voxel_index += taskDim) {
|
||||
const int *argmax_cur_voxel = argmax + voxel_index * channels;
|
||||
const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
|
||||
|
||||
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
|
||||
channels_loop_idx++) {
|
||||
int actual_channels_num = (channels_loop_idx == channels_loop_times)
|
||||
? rem_channels
|
||||
: nram_channels_limit;
|
||||
if (actual_channels_num == 0) {
|
||||
break;
|
||||
}
|
||||
const int *argmax_cur_loop =
|
||||
argmax_cur_voxel + nram_channels_limit * channels_loop_idx;
|
||||
const T *grad_out_cur_loop =
|
||||
grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
|
||||
__memcpy((void *)nram_argmax_cur_loop, (void *)argmax_cur_loop,
|
||||
actual_channels_num * sizeof(int), GDRAM2NRAM);
|
||||
__memcpy((void *)nram_grad_out_cur_loop, (void *)grad_out_cur_loop,
|
||||
actual_channels_num * sizeof(T), GDRAM2NRAM);
|
||||
|
||||
for (int channel_idx = 0; channel_idx < actual_channels_num;
|
||||
channel_idx++) {
|
||||
int *nram_argmax_cur_channel = nram_argmax_cur_loop + channel_idx;
|
||||
T *nram_grad_out_cur_channel = nram_grad_out_cur_loop + channel_idx;
|
||||
if (nram_argmax_cur_channel[0] == -1) {
|
||||
continue;
|
||||
}
|
||||
T *grad_in_cur_channel =
|
||||
grad_in + nram_argmax_cur_channel[0] * channels +
|
||||
nram_channels_limit * channels_loop_idx + channel_idx;
|
||||
__bang_atomic_add((T *)nram_grad_in_cur_channel,
|
||||
(T *)grad_in_cur_channel,
|
||||
(T *)(nram_grad_out_cur_channel), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUUnion1KernelRoiawareAvgPool3dBackward(
|
||||
const int boxes_num, const int out_x, const int out_y, const int out_z,
|
||||
const int channels, const int max_pts_each_voxel,
|
||||
const int *pts_idx_of_voxels, const T *grad_out, T *grad_in) {
|
||||
// params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
|
||||
// max_pts_each_voxel) params (T)grad_out: (boxes_num, out_x, out_y, out_z,
|
||||
// channels) params (T)grad_in: (pts_num, channels)
|
||||
|
||||
// make sure that memcore is not used
|
||||
if (coreId == 0x80) {
|
||||
return;
|
||||
}
|
||||
int align_num = NFU_ALIGN_SIZE / sizeof(T);
|
||||
int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num);
|
||||
int nram_channels_limit = PAD_DOWN(
|
||||
(MAX_NRAM_SIZE - align_max_pts_each_voxel * sizeof(int)) / 2 / sizeof(T),
|
||||
align_num);
|
||||
int *nram_pts_idx_cur_voxel = (int *)data_nram;
|
||||
// nram_pts_idx_cur_voxel [align_max_pts_each_voxel]
|
||||
T *nram_grad_out_cur_loop =
|
||||
(T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel);
|
||||
// nram_grad_out_cur_loop [nram_channels_limit]
|
||||
T *nram_grad_in_cur_loop = (T *)nram_grad_out_cur_loop + nram_channels_limit;
|
||||
// nram_grad_in_cur_loop [nram_channels_limit]
|
||||
int channels_loop_times = channels / nram_channels_limit;
|
||||
int rem_channels = channels % nram_channels_limit;
|
||||
int voxels_num = boxes_num * out_x * out_y * out_z;
|
||||
|
||||
for (int voxel_index = taskId; voxel_index < voxels_num;
|
||||
voxel_index += taskDim) {
|
||||
const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
|
||||
const int *pts_idx_cur_voxel =
|
||||
pts_idx_of_voxels + voxel_index * max_pts_each_voxel;
|
||||
__memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxel,
|
||||
max_pts_each_voxel * sizeof(int), GDRAM2NRAM);
|
||||
int total_pts_of_voxel = nram_pts_idx_cur_voxel[0];
|
||||
if (total_pts_of_voxel <= 0) {
|
||||
continue;
|
||||
}
|
||||
float cur_grad = 1.0 / ((float)total_pts_of_voxel);
|
||||
|
||||
for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
|
||||
channels_loop_idx++) {
|
||||
int actual_channels_num = (channels_loop_idx == channels_loop_times)
|
||||
? rem_channels
|
||||
: nram_channels_limit;
|
||||
if (actual_channels_num == 0) {
|
||||
break;
|
||||
}
|
||||
const T *grad_out_cur_loop =
|
||||
grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
|
||||
__memcpy((void *)nram_grad_in_cur_loop, (void *)grad_out_cur_loop,
|
||||
actual_channels_num * sizeof(T), GDRAM2NRAM);
|
||||
|
||||
int align_actual_channels_num = PAD_UP(actual_channels_num, align_num);
|
||||
|
||||
if (sizeof(T) == sizeof(half)) {
|
||||
__bang_half2float((float *)nram_grad_out_cur_loop,
|
||||
(half *)nram_grad_in_cur_loop,
|
||||
align_actual_channels_num);
|
||||
__bang_mul_scalar((float *)nram_grad_out_cur_loop,
|
||||
(float *)nram_grad_out_cur_loop, (float)cur_grad,
|
||||
align_actual_channels_num);
|
||||
convertFloat2half((half *)nram_grad_out_cur_loop,
|
||||
(float *)nram_grad_out_cur_loop,
|
||||
align_actual_channels_num);
|
||||
} else {
|
||||
__bang_mul_scalar((float *)nram_grad_out_cur_loop,
|
||||
(float *)nram_grad_in_cur_loop, (float)cur_grad,
|
||||
align_actual_channels_num);
|
||||
}
|
||||
for (int k = 1; k <= total_pts_of_voxel; k++) {
|
||||
T *grad_in_cur_loop = grad_in + nram_pts_idx_cur_voxel[k] * channels +
|
||||
nram_channels_limit * channels_loop_idx;
|
||||
__bang_atomic_add((T *)nram_grad_in_cur_loop, (T *)grad_in_cur_loop,
|
||||
(T *)nram_grad_out_cur_loop, actual_channels_num);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void KernelRoiawarePool3dBackward(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
|
||||
const int out_x, const int out_y, const int out_z, const int channels,
|
||||
const int max_pts_each_voxel, const int *pts_idx_of_voxels,
|
||||
const int *argmax, const void *grad_out, void *grad_in) {
|
||||
if (pool_method == 0) {
|
||||
switch (d_type) {
|
||||
case CNRT_FLOAT32: {
|
||||
MLUUnion1KernelRoiawareMaxPool3dBackward<float>
|
||||
<<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z, channels,
|
||||
(int *)argmax, (float *)grad_out,
|
||||
(float *)grad_in);
|
||||
}; break;
|
||||
case CNRT_FLOAT16: {
|
||||
MLUUnion1KernelRoiawareMaxPool3dBackward<half>
|
||||
<<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z, channels,
|
||||
(int *)argmax, (half *)grad_out,
|
||||
(half *)grad_in);
|
||||
}; break;
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch (d_type) {
|
||||
case CNRT_FLOAT32: {
|
||||
MLUUnion1KernelRoiawareAvgPool3dBackward<float>
|
||||
<<<k_dim, k_type, queue>>>(
|
||||
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
|
||||
(int *)pts_idx_of_voxels, (float *)grad_out, (float *)grad_in);
|
||||
}; break;
|
||||
case CNRT_FLOAT16: {
|
||||
MLUUnion1KernelRoiawareAvgPool3dBackward<half>
|
||||
<<<k_dim, k_type, queue>>>(
|
||||
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
|
||||
(int *)pts_idx_of_voxels, (half *)grad_out, (half *)grad_in);
|
||||
}; break;
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -9,49 +9,7 @@
|
|||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "pytorch_device_registry.hpp"
|
||||
#include "pytorch_mlu_helper.hpp"
|
||||
|
||||
void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
|
||||
cnrtQueue_t queue, const cnrtDataType_t d_type,
|
||||
const int pool_method, const int boxes_num,
|
||||
const int pts_num, const int max_pts_each_voxel,
|
||||
const int out_x, const int out_y, const int out_z,
|
||||
const void *rois, const void *pts,
|
||||
int *pts_idx_of_voxels);
|
||||
|
||||
void KernelRoiawarePool3dForward(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
|
||||
const int pts_num, const int channels, const int max_pts_each_voxel,
|
||||
const int out_x, const int out_y, const int out_z, const void *pts_feature,
|
||||
const int *pts_idx_of_voxels, void *pooled_features, int *argmax);
|
||||
|
||||
// policy function
|
||||
static void kernelPtsIdxOfVoxelsPolicyFunc(const int boxes_num,
|
||||
cnrtDim3_t *k_dim,
|
||||
cnrtFunctionType_t *k_type) {
|
||||
unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||
unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
|
||||
*k_type = CNRT_FUNC_TYPE_UNION1;
|
||||
k_dim->x = core_num;
|
||||
unsigned int use_cluster = (boxes_num + core_num - 1) / core_num;
|
||||
k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
|
||||
k_dim->z = 1;
|
||||
}
|
||||
|
||||
static void kernelRoiawarePool3dForwardPolicyFunc(
|
||||
const int boxes_num, const int out_x, const int out_y, const int out_z,
|
||||
cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
|
||||
unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||
unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
|
||||
*k_type = CNRT_FUNC_TYPE_UNION1;
|
||||
k_dim->x = core_num;
|
||||
const int voxels_num = boxes_num * out_x * out_y * out_z;
|
||||
unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
|
||||
k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
|
||||
k_dim->z = 1;
|
||||
}
|
||||
#include "mlu_common_helper.h"
|
||||
|
||||
void RoiawarePool3dForwardMLUKernelLauncher(
|
||||
const int pool_method, const int boxes_num, const int pts_num,
|
||||
|
@ -59,168 +17,65 @@ void RoiawarePool3dForwardMLUKernelLauncher(
|
|||
const int out_y, const int out_z, const Tensor rois, const Tensor pts,
|
||||
const Tensor pts_feature, Tensor pts_idx_of_voxels, Tensor pooled_features,
|
||||
Tensor argmax) {
|
||||
// check datatype
|
||||
TORCH_CHECK(((pts.scalar_type() == rois.scalar_type()) &&
|
||||
(pts_feature.scalar_type() == rois.scalar_type()) &&
|
||||
(pooled_features.scalar_type() == rois.scalar_type())),
|
||||
"data types of rois, rois, pts_feature and pooled_features "
|
||||
"should be the same, ",
|
||||
"but now rois type is ", rois.scalar_type(), ", pts type is ",
|
||||
pts.scalar_type(), ", pts_feature type is ",
|
||||
pts_feature.scalar_type(), ", pooled_features type is ",
|
||||
pooled_features.scalar_type(), ".");
|
||||
TORCH_CHECK(
|
||||
(rois.scalar_type() == at::kFloat || rois.scalar_type() == at::kHalf),
|
||||
"rois type should be Float or Half, got ", rois.scalar_type(), ".");
|
||||
TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
|
||||
"pts_idx_of_voxels type should be Int, got ",
|
||||
pts_idx_of_voxels.scalar_type(), ".");
|
||||
// check dim
|
||||
TORCH_CHECK(rois.dim() == 2, "rois should be a 2D tensor, got ", rois.dim(),
|
||||
"D.");
|
||||
TORCH_CHECK(pts.dim() == 2, "pts should be a 2D tensor, got ", pts.dim(),
|
||||
"D.");
|
||||
TORCH_CHECK(pts_feature.dim() == 2, "pts_feature should be a 2D tensor, got ",
|
||||
pts_feature.dim(), "D.");
|
||||
TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
|
||||
"pts_idx_of_voxels should be a 5D tensor, got ",
|
||||
pts_idx_of_voxels.dim(), "D.");
|
||||
TORCH_CHECK(pooled_features.dim() == 5,
|
||||
"pooled_features should be a 5D tensor, got ",
|
||||
pooled_features.dim(), "D.");
|
||||
// check shape
|
||||
TORCH_CHECK(((rois.size(0) == boxes_num) && (rois.size(1) == 7)),
|
||||
"the dimensions of rois should be (boxes_num, 7), ", "but got (",
|
||||
rois.size(0), ", ", rois.size(1), ") .");
|
||||
TORCH_CHECK(((pts.size(0) == pts_num) && (pts.size(1) == 3)),
|
||||
"the dimensions of pts should be (pts_num, 3), ", "but got (",
|
||||
pts.size(0), ",", pts.size(1), ").");
|
||||
TORCH_CHECK(
|
||||
((pts_feature.size(0) == pts_num) && (pts_feature.size(1) == channels)),
|
||||
"the dimensions of pts_feature should be (pts_num, channels), ",
|
||||
"but got (", pts_feature.size(0), ",", pts_feature.size(1), ").");
|
||||
TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
|
||||
(pts_idx_of_voxels.size(1) == out_x) &&
|
||||
(pts_idx_of_voxels.size(2) == out_y) &&
|
||||
(pts_idx_of_voxels.size(3) == out_z) &&
|
||||
(pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
|
||||
"the dimensions of pts_idx_of_voxels should be (boxes_num, "
|
||||
"out_x, out_y, out_z, max_pts_each_voxel), ",
|
||||
"but got (", pts_idx_of_voxels.size(0), ",",
|
||||
pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",",
|
||||
pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ").");
|
||||
TORCH_CHECK(((pooled_features.size(0) == boxes_num) &&
|
||||
(pooled_features.size(1) == out_x) &&
|
||||
(pooled_features.size(2) == out_y) &&
|
||||
(pooled_features.size(3) == out_z) &&
|
||||
(pooled_features.size(4) == channels)),
|
||||
"the dimensions of pooled_features should be (boxes_num, out_x, "
|
||||
"out_y, out_z, channels), ",
|
||||
"but got (", pooled_features.size(0), ",",
|
||||
pooled_features.size(1), ",", pooled_features.size(2), ",",
|
||||
pooled_features.size(3), ",", pooled_features.size(4), ").");
|
||||
// check other params : pool_mothod
|
||||
TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
|
||||
"the num of pool_method should be 0(max) or 1(avg), ", "but got ",
|
||||
pool_method, ".");
|
||||
// check large tensor
|
||||
const size_t max_input_size = 2147483648;
|
||||
TORCH_CHECK(rois.numel() < max_input_size,
|
||||
"rois element num should be less than 2^31, got ", rois.numel(),
|
||||
".");
|
||||
TORCH_CHECK(pts.numel() < max_input_size,
|
||||
"pts element num should be less than 2^31, got ", pts.numel(),
|
||||
".");
|
||||
TORCH_CHECK(pts_feature.numel() < max_input_size,
|
||||
"pts_feature element num should be less than 2^31, got ",
|
||||
pts_feature.numel(), ".");
|
||||
TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
|
||||
"pts_idx_of_voxels element num should be less than 2^31, got ",
|
||||
pts_idx_of_voxels.numel(), ".");
|
||||
TORCH_CHECK(pooled_features.numel() < max_input_size,
|
||||
"pooled_features element num should be less than 2^31, got ",
|
||||
pooled_features.numel(), ".");
|
||||
// check zero element
|
||||
TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
|
||||
rois.numel());
|
||||
TORCH_CHECK(pts.numel() != 0, "pts.numel() should not be zero, got ",
|
||||
pts.numel());
|
||||
TORCH_CHECK(pts_feature.numel() != 0,
|
||||
"pts_feature.numel() should not be zero, got ",
|
||||
pts_feature.numel());
|
||||
TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
|
||||
"pts_idx_of_voxels.numel() should not be zero, got ",
|
||||
pts_idx_of_voxels.numel());
|
||||
TORCH_CHECK(pooled_features.numel() != 0,
|
||||
"pooled_features.numel() should not be zero, got ",
|
||||
pooled_features.numel());
|
||||
if (pool_method == 0) {
|
||||
// check datatype
|
||||
TORCH_CHECK((argmax.scalar_type() == at::kInt),
|
||||
"argmax type should be Int, got ", argmax.scalar_type(), ".");
|
||||
// check dim
|
||||
TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
|
||||
argmax.dim(), "D.");
|
||||
// check shape
|
||||
TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
|
||||
(argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
|
||||
(argmax.size(4) == channels)),
|
||||
"the dimensions of argmax should be (boxes_num, out_x, out_y, "
|
||||
"out_z, channels), ",
|
||||
"but got (", argmax.size(0), ",", argmax.size(1), ",",
|
||||
argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ").");
|
||||
// check large tensor
|
||||
TORCH_CHECK(argmax.numel() < max_input_size,
|
||||
"argmax element num should be less than 2^31, got ",
|
||||
argmax.numel(), ".");
|
||||
// check zero element
|
||||
TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ",
|
||||
argmax.numel());
|
||||
// when pool_method is 0, which is max pool, init argmax data value to -1
|
||||
argmax.fill_(static_cast<int>(-1));
|
||||
}
|
||||
// calculate task one dimension
|
||||
cnrtDim3_t k1_dim;
|
||||
cnrtFunctionType_t k1_type;
|
||||
kernelPtsIdxOfVoxelsPolicyFunc(boxes_num, &k1_dim, &k1_type);
|
||||
cnrtDim3_t k2_dim;
|
||||
cnrtFunctionType_t k2_type;
|
||||
kernelRoiawarePool3dForwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k2_dim,
|
||||
&k2_type);
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
// get ptr of tensors
|
||||
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
|
||||
// get compute handle
|
||||
auto handle = mluOpGetCurrentHandle();
|
||||
|
||||
auto rois_contiguous =
|
||||
torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
|
||||
auto pts_contiguous =
|
||||
torch_mlu::cnnl::ops::cnnl_contiguous(pts, pts.suggest_memory_format());
|
||||
auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
pts_feature, pts_feature.suggest_memory_format());
|
||||
auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
argmax, argmax.suggest_memory_format());
|
||||
auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
|
||||
auto pooled_features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
pooled_features, pooled_features.suggest_memory_format());
|
||||
|
||||
MluOpTensorDescriptor rois_desc, pts_desc, pts_feature_desc, argmax_desc,
|
||||
pts_idx_of_voxels_desc, pooled_features_desc;
|
||||
rois_desc.set(rois_contiguous);
|
||||
pts_desc.set(pts_contiguous);
|
||||
pts_feature_desc.set(pts_feature_contiguous);
|
||||
argmax_desc.set(argmax_contiguous);
|
||||
pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
|
||||
pooled_features_desc.set(pooled_features_contiguous);
|
||||
|
||||
// allocate extra space for workspace
|
||||
size_t workspace_size = 0;
|
||||
mluOpGetRoiawarePool3dForwardWorkspaceSize(
|
||||
handle, rois_desc.desc(), pts_desc.desc(), pts_feature_desc.desc(),
|
||||
&workspace_size);
|
||||
|
||||
auto workspace = at::empty(workspace_size, rois.options().dtype(at::kByte));
|
||||
auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
|
||||
auto workspace_ptr = workspace_impl->cnnlMalloc();
|
||||
|
||||
auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
|
||||
auto pts_impl = torch_mlu::getMluTensorImpl(pts_contiguous);
|
||||
auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_contiguous);
|
||||
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
|
||||
auto pts_idx_of_voxels_impl =
|
||||
torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
|
||||
auto pooled_features_impl =
|
||||
torch_mlu::getMluTensorImpl(pooled_features_contiguous);
|
||||
|
||||
auto rois_ptr = rois_impl->cnnlMalloc();
|
||||
// transpose points [pts_num, 3] -> [3, pts_num]
|
||||
auto pts_ = pts.permute({1, 0}).contiguous();
|
||||
auto pts_impl = torch_mlu::getMluTensorImpl(pts_);
|
||||
auto pts_ptr = pts_impl->cnnlMalloc();
|
||||
// transpose points_features [pts_num, channels] -> [channels, pts_num]
|
||||
auto pts_feature_ = pts_feature.permute({1, 0}).contiguous();
|
||||
auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_);
|
||||
auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
|
||||
auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
|
||||
auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
|
||||
auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features);
|
||||
auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
|
||||
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
|
||||
auto argmax_ptr = argmax_impl->cnnlMalloc();
|
||||
// get compute dtype of input
|
||||
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(rois.dtype());
|
||||
// launch kernel PtsIdxOfVoxels
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernel PtsIdxOfVoxels<<<" << k1_dim.x << ", "
|
||||
<< k1_dim.y << ", " << k1_dim.z << ">>>";
|
||||
KernelPtsIdxOfVoxels(k1_dim, k1_type, queue, data_type, pool_method,
|
||||
boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
|
||||
out_z, rois_ptr, pts_ptr, (int *)pts_idx_of_voxels_ptr);
|
||||
// launch kernel RoiawarePool3dForward
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dForward<<<" << k2_dim.x
|
||||
<< ", " << k2_dim.y << ", " << k2_dim.z << ">>>";
|
||||
KernelRoiawarePool3dForward(
|
||||
k2_dim, k2_type, queue, data_type, pool_method, boxes_num, pts_num,
|
||||
channels, max_pts_each_voxel, out_x, out_y, out_z, pts_feature_ptr,
|
||||
(int *)pts_idx_of_voxels_ptr, pooled_features_ptr, (int *)argmax_ptr);
|
||||
auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
|
||||
auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
|
||||
|
||||
CNLOG(INFO) << "Call mluOpRoiawarePool3dForward().";
|
||||
mluOpRoiawarePool3dForward(
|
||||
handle, pool_method, boxes_num, pts_num, channels, rois_desc.desc(),
|
||||
rois_ptr, pts_desc.desc(), pts_ptr, pts_feature_desc.desc(),
|
||||
pts_feature_ptr, workspace_ptr, workspace_size, max_pts_each_voxel, out_x,
|
||||
out_y, out_z, argmax_desc.desc(), argmax_ptr,
|
||||
pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
|
||||
pooled_features_desc.desc(), pooled_features_ptr);
|
||||
}
|
||||
|
||||
void roiaware_pool3d_forward_mlu(int boxes_num, int pts_num, int channels,
|
||||
|
@ -245,136 +100,46 @@ void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels,
|
|||
REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MLU,
|
||||
roiaware_pool3d_forward_mlu);
|
||||
|
||||
void KernelRoiawarePool3dBackward(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
|
||||
const int out_x, const int out_y, const int out_z, const int channels,
|
||||
const int max_pts_each_voxel, const int *pts_idx_of_voxels,
|
||||
const int *argmax, const void *grad_out, void *grad_in);
|
||||
|
||||
static void kernelRoiawarePool3dBackwardPolicyFunc(
|
||||
const int boxes_num, const int out_x, const int out_y, const int out_z,
|
||||
cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
|
||||
unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||
unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
|
||||
*k_type = CNRT_FUNC_TYPE_UNION1;
|
||||
k_dim->x = core_num;
|
||||
const int voxels_num = boxes_num * out_x * out_y * out_z;
|
||||
unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
|
||||
k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
|
||||
k_dim->z = 1;
|
||||
}
|
||||
|
||||
void RoiawarePool3dBackwardMLUKernelLauncher(
|
||||
int pool_method, int boxes_num, int out_x, int out_y, int out_z,
|
||||
int channels, int max_pts_each_voxel, const Tensor pts_idx_of_voxels,
|
||||
const Tensor argmax, const Tensor grad_out, Tensor grad_in) {
|
||||
// check datatype
|
||||
TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
|
||||
"pts_idx_of_voxels type should be Int, got ",
|
||||
pts_idx_of_voxels.scalar_type(), ".");
|
||||
TORCH_CHECK((argmax.scalar_type() == at::kInt),
|
||||
"argmax type should be Int, got ", argmax.scalar_type(), ".");
|
||||
TORCH_CHECK((grad_out.scalar_type() == at::kFloat ||
|
||||
grad_out.scalar_type() == at::kHalf),
|
||||
"grad_out type should be Float or Half, got ",
|
||||
grad_out.scalar_type(), ".");
|
||||
TORCH_CHECK((grad_out.scalar_type() == grad_in.scalar_type()),
|
||||
"data types of grad_out, grad_in, should be the same, ",
|
||||
"but now grad_out type is ", grad_out.scalar_type(),
|
||||
", grad_in type is ", grad_in.scalar_type(), ".");
|
||||
// check dim
|
||||
TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
|
||||
"pts_idx_of_voxels should be a 5D tensor, got ",
|
||||
pts_idx_of_voxels.dim(), "D.");
|
||||
TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
|
||||
argmax.dim(), "D.");
|
||||
TORCH_CHECK(grad_out.dim() == 5, "grad_out should be a 5D tensor, got ",
|
||||
grad_out.dim(), "D.");
|
||||
TORCH_CHECK(grad_in.dim() == 2, "grad_in should be a 2D tensor, got ",
|
||||
grad_in.dim(), "D.");
|
||||
// check shape
|
||||
TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
|
||||
(pts_idx_of_voxels.size(1) == out_x) &&
|
||||
(pts_idx_of_voxels.size(2) == out_y) &&
|
||||
(pts_idx_of_voxels.size(3) == out_z) &&
|
||||
(pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
|
||||
"the dimensions of pts_idx_of_voxels should be (boxes_num, "
|
||||
"out_x, out_y, out_z, max_pts_each_voxel), ",
|
||||
"but got (", pts_idx_of_voxels.size(0), ",",
|
||||
pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",",
|
||||
pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ").");
|
||||
TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
|
||||
(argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
|
||||
(argmax.size(4) == channels)),
|
||||
"the dimensions of argmax should be (boxes_num, out_x, out_y, "
|
||||
"out_z, channels), ",
|
||||
"but got (", argmax.size(0), ",", argmax.size(1), ",",
|
||||
argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ").");
|
||||
TORCH_CHECK(((grad_out.size(0) == boxes_num) && (grad_out.size(1) == out_x) &&
|
||||
(grad_out.size(2) == out_y) && (grad_out.size(3) == out_z) &&
|
||||
(grad_out.size(4) == channels)),
|
||||
"the dimensions of grad_out should be (boxes_num, out_x, "
|
||||
"out_y, out_z, channels), ",
|
||||
"but got (", grad_out.size(0), ",", grad_out.size(1), ",",
|
||||
grad_out.size(2), ",", grad_out.size(3), ",", grad_out.size(4),
|
||||
").");
|
||||
TORCH_CHECK((grad_in.size(1) == channels),
|
||||
"the 1st dimensions of grad_in should be channels, ", "but got ",
|
||||
grad_in.size(1), ".");
|
||||
// check other params : pool_mothod
|
||||
TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
|
||||
"the num of pool_method should be 0(max) or 1(avg), ", "but got ",
|
||||
pool_method, ".");
|
||||
// check large tensor
|
||||
const size_t max_input_size = 2147483648;
|
||||
TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
|
||||
"pts_idx_of_voxels element num should be less than 2^31, got ",
|
||||
pts_idx_of_voxels.numel(), ".");
|
||||
TORCH_CHECK(argmax.numel() < max_input_size,
|
||||
"argmax element num should be less than 2^31, got ",
|
||||
argmax.numel(), ".");
|
||||
TORCH_CHECK(grad_out.numel() < max_input_size,
|
||||
"grad_out element num should be less than 2^31, got ",
|
||||
grad_out.numel(), ".");
|
||||
TORCH_CHECK(grad_in.numel() < max_input_size,
|
||||
"grad_in element num should be less than 2^31, got ",
|
||||
grad_in.numel(), ".");
|
||||
// check zero element
|
||||
TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
|
||||
"pts_idx_of_voxels.numel() should not be zero, got ",
|
||||
pts_idx_of_voxels.numel());
|
||||
TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ",
|
||||
argmax.numel());
|
||||
TORCH_CHECK(grad_out.numel() != 0,
|
||||
"grad_out.numel() should not be zero, got ", grad_out.numel());
|
||||
TORCH_CHECK(grad_in.numel() != 0, "grad_in.numel() should not be zero, got ",
|
||||
grad_in.numel());
|
||||
// calculate task one dimension
|
||||
cnrtDim3_t k_dim;
|
||||
cnrtFunctionType_t k_type;
|
||||
kernelRoiawarePool3dBackwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k_dim,
|
||||
&k_type);
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
// transpose points_features [pts_num, channels] -> [channels, pts_num]
|
||||
auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
|
||||
// get compute handle
|
||||
auto handle = mluOpGetCurrentHandle();
|
||||
auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
|
||||
auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
argmax, argmax.suggest_memory_format());
|
||||
auto grad_out_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
grad_out, grad_out.suggest_memory_format());
|
||||
auto grad_in_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
grad_in, grad_in.suggest_memory_format());
|
||||
|
||||
MluOpTensorDescriptor pts_idx_of_voxels_desc, argmax_desc, grad_out_desc,
|
||||
grad_in_desc;
|
||||
|
||||
pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
|
||||
argmax_desc.set(argmax_contiguous);
|
||||
grad_out_desc.set(grad_out_contiguous);
|
||||
grad_in_desc.set(grad_in_contiguous);
|
||||
|
||||
auto pts_idx_of_voxels_impl =
|
||||
torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
|
||||
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
|
||||
auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out_contiguous);
|
||||
auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in_contiguous);
|
||||
|
||||
auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
|
||||
auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
|
||||
auto argmax_ptr = argmax_impl->cnnlMalloc();
|
||||
auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out);
|
||||
auto grad_out_ptr = grad_out_impl->cnnlMalloc();
|
||||
auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in);
|
||||
auto grad_in_ptr = grad_in_impl->cnnlMalloc();
|
||||
// get compute dtype of input
|
||||
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(grad_out.dtype());
|
||||
// launch kernel RoiawarePool3dForward
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dBackward<<<" << k_dim.x
|
||||
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
KernelRoiawarePool3dBackward(k_dim, k_type, queue, data_type, pool_method,
|
||||
boxes_num, out_x, out_y, out_z, channels,
|
||||
max_pts_each_voxel, (int *)pts_idx_of_voxels_ptr,
|
||||
(int *)argmax_ptr, grad_out_ptr, grad_in_ptr);
|
||||
|
||||
CNLOG(INFO) << "Call mluOpRoiawarePool3dBackward().";
|
||||
mluOpRoiawarePool3dBackward(
|
||||
handle, pool_method, boxes_num, out_x, out_y, out_z, channels,
|
||||
max_pts_each_voxel, pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
|
||||
argmax_desc.desc(), argmax_ptr, grad_out_desc.desc(), grad_out_ptr,
|
||||
grad_in_desc.desc(), grad_in_ptr);
|
||||
}
|
||||
|
||||
void roiaware_pool3d_backward_mlu(int boxes_num, int out_x, int out_y,
|
||||
|
|
Loading…
Reference in New Issue