From 126077df5ebd73f7d8f7233c9b4c1606d1489e9c Mon Sep 17 00:00:00 2001 From: mahxn0 <1262384588@qq.com> Date: Fri, 10 Mar 2023 11:21:43 +0800 Subject: [PATCH] [Feature] Make voxelization operator support the mlu290 platform (#2652) * [Feature]: Make voxelization operator support the mlu290 platform. * [Feature]: Make voxelization operator support the mlu290 platform. * [Feature]: Make voxelization operator support the mlu290 platform. --------- Co-authored-by: maxiangjun --- .../common/mlu/voxelization_mlu_kernel.mlu | 213 ++++++++++++++---- mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h | 2 +- 2 files changed, 166 insertions(+), 49 deletions(-) diff --git a/mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu index d7c57da4f..dd3846250 100644 --- a/mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu @@ -14,7 +14,17 @@ __nram__ char nram_buffer[MAX_NRAM_SIZE]; -#if __BANG_ARCH__ >= 322 +__mlu_func__ void floor(float *dst_ram, float *src_ram, int size) { +#if (__BANG_ARCH__ >= 322) + __bang_floor(dst_ram, src_ram, size); // This bang interface is for nan/inf + // temp. +#else + int16 *mid = (int16 *)(dst_ram + size / 2); + __bang_float2int16_dn(mid, (float *)src_ram, size, 0); + __bang_int162float((float *)dst_ram, mid, size, 0); +#endif +} + __mlu_func__ void computeDynamicVoxelize( char *points_x, char *points_y, char *points_z, char *auxiliary_a, char *auxiliary_b, char *auxiliary_c, const float coors_x_min, @@ -24,21 +34,25 @@ __mlu_func__ void computeDynamicVoxelize( // x - coors_x_min __bang_sub_scalar((float *)points_x, (float *)points_x, coors_x_min, deal_num); - // y - coors_y_min - __bang_sub_scalar((float *)points_y, (float *)points_y, coors_y_min, - deal_num); - // z - coors_z_min - __bang_sub_scalar((float *)points_z, (float *)points_z, coors_z_min, - deal_num); // (x - coors_x_min) / voxel_x __bang_mul_scalar((float *)points_x, (float *)points_x, 1.0 / voxel_x, deal_num); + + // y - coors_y_min + __bang_sub_scalar((float *)points_y, (float *)points_y, coors_y_min, + deal_num); // (y - coors_y_min) / voxel_y __bang_mul_scalar((float *)points_y, (float *)points_y, 1.0 / voxel_y, deal_num); + + // z - coors_z_min + __bang_sub_scalar((float *)points_z, (float *)points_z, coors_z_min, + deal_num); // (z - coors_z_min) / voxel_z __bang_mul_scalar((float *)points_z, (float *)points_z, 1.0 / voxel_z, deal_num); + +#if __BANG_ARCH__ >= 322 // c_x = floor((x - coors_x_min) / voxel_x) __bang_floor((float *)auxiliary_a, (float *)points_x, deal_num); __bang_float2int32((int32_t *)points_x, (float *)auxiliary_a, deal_num, 0); @@ -55,7 +69,7 @@ __mlu_func__ void computeDynamicVoxelize( __bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_x, grid_x, deal_num); // 0 <= c_x < grid_x - __bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_b, + __bang_and((int32_t *)auxiliary_a, (int32_t *)auxiliary_b, (int32_t *)auxiliary_c, deal_num); // c_y >= 0 __bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_y, (int32_t)0, @@ -64,10 +78,10 @@ __mlu_func__ void computeDynamicVoxelize( __bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_y, grid_y, deal_num); // 0 <= c_y < grid_y - __bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b, + __bang_and((int32_t *)auxiliary_b, (int32_t *)auxiliary_b, (int32_t *)auxiliary_c, deal_num); // c_x >= 0 && c_x < grid_x && c_y >= 0 && c_y < grid_y - __bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a, + __bang_and((int32_t *)auxiliary_a, (int32_t *)auxiliary_a, (int32_t *)auxiliary_b, deal_num); // c_z >= 0 __bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_z, (int32_t)0, @@ -76,10 +90,10 @@ __mlu_func__ void computeDynamicVoxelize( __bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_z, grid_z, deal_num); // 0 <= c_z < grid_z - __bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b, + __bang_and((int32_t *)auxiliary_b, (int32_t *)auxiliary_b, (int32_t *)auxiliary_c, deal_num); // 0 <= c_x < grid_x && 0 <= c_y < grid_y && 0 <= c_z < grid_z - __bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a, + __bang_and((int32_t *)auxiliary_a, (int32_t *)auxiliary_a, (int32_t *)auxiliary_b, deal_num); __bang_not((int32_t *)auxiliary_c, (int32_t *)auxiliary_a, deal_num); @@ -97,15 +111,78 @@ __mlu_func__ void computeDynamicVoxelize( deal_num); __bang_add((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_b, deal_num); +#else + // c_x >= 0 + __bang_ge_scalar((float *)auxiliary_b, (float *)points_x, (float)0, deal_num); + + // c_x < grid_x + __bang_write_value((float *)auxiliary_a, deal_num, (float)grid_x); + __bang_lt((float *)auxiliary_c, (float *)points_x, (float *)auxiliary_a, + deal_num); + // 0 <= c_x < grid_x + __bang_and((float *)auxiliary_a, (float *)auxiliary_b, (float *)auxiliary_c, + deal_num); + // c_y >= 0 + __bang_ge_scalar((float *)auxiliary_b, (float *)points_y, (float)0, deal_num); + // c_y < grid_y + __bang_write_value((float *)auxiliary_c, deal_num, (float)grid_y); + __bang_lt((float *)auxiliary_c, (float *)points_y, (float *)auxiliary_c, + deal_num); + // 0 <= c_y < grid_y + __bang_and((float *)auxiliary_b, (float *)auxiliary_b, (float *)auxiliary_c, + deal_num); + // c_x >= 0 && c_x < grid_x && c_y >= 0 && c_y < grid_y + __bang_and((float *)auxiliary_a, (float *)auxiliary_a, (float *)auxiliary_b, + deal_num); + // c_z >= 0 + __bang_ge_scalar((float *)auxiliary_b, (float *)points_z, (float)0, deal_num); + // c_z < grid_z + __bang_write_value((float *)auxiliary_c, deal_num, (float)grid_z); + __bang_lt((float *)auxiliary_c, (float *)points_z, (float *)auxiliary_c, + deal_num); + // 0 <= c_z < grid_z + __bang_and((float *)auxiliary_b, (float *)auxiliary_b, (float *)auxiliary_c, + deal_num); + // 0 <= c_x < grid_x && 0 <= c_y < grid_y && 0 <= c_z < grid_z + __bang_and((float *)auxiliary_a, (float *)auxiliary_a, (float *)auxiliary_b, + deal_num); + __bang_not((float *)auxiliary_c, (float *)auxiliary_a, deal_num); + + __bang_mul((float *)points_x, (float *)points_x, (float *)auxiliary_a, + deal_num); + __bang_mul_scalar((float *)auxiliary_b, (float *)auxiliary_c, (float)(-1.0), + deal_num); + __bang_add((float *)points_x, (float *)points_x, (float *)auxiliary_b, + deal_num); + __bang_mul((float *)points_y, (float *)points_y, (float *)auxiliary_a, + deal_num); + __bang_add((float *)points_y, (float *)points_y, (float *)auxiliary_b, + deal_num); + __bang_mul((float *)points_z, (float *)points_z, (float *)auxiliary_a, + deal_num); + __bang_add((float *)points_z, (float *)points_z, (float *)auxiliary_b, + deal_num); + + floor((float *)auxiliary_a, (float *)points_x, deal_num); + convertFloat2Int((int32_t *)points_x, (float *)auxiliary_b, + (float *)auxiliary_a, (float *)auxiliary_c, deal_num); + + floor((float *)auxiliary_a, (float *)points_y, deal_num); + convertFloat2Int((int32_t *)points_y, (float *)auxiliary_b, + (float *)auxiliary_a, (float *)auxiliary_c, deal_num); + + floor((float *)auxiliary_a, (float *)points_z, deal_num); + convertFloat2Int((int32_t *)points_z, (float *)auxiliary_b, + (float *)auxiliary_a, (float *)auxiliary_c, deal_num); +#endif } -__mlu_func__ void computePoint2Voxel(char *coors_x, char *coors_y, - char *coors_z, const int32_t c_x, - const int32_t c_y, const int32_t c_z, - const int32_t max_points, int32_t *num, - int32_t *first_point, - const int32_t deal_idx, - const int32_t deal_num) { +__mlu_func__ void computePoint2Voxel( + char *coors_x, char *coors_y, char *coors_z, char *src_addition, + char *dst_addition, char *dst, const int32_t c_x, const int32_t c_y, + const int32_t c_z, const int32_t max_points, int32_t *num, + int32_t *first_point, const int32_t deal_idx, const int32_t deal_num) { +#if __BANG_ARCH__ >= 322 __bang_eq_scalar((int32_t *)coors_x, (int32_t *)coors_x, c_x, deal_num); __bang_eq_scalar((int32_t *)coors_y, (int32_t *)coors_y, c_y, deal_num); __bang_eq_scalar((int32_t *)coors_z, (int32_t *)coors_z, c_z, deal_num); @@ -122,8 +199,35 @@ __mlu_func__ void computePoint2Voxel(char *coors_x, char *coors_y, } else { *num += (int32_t)__bang_count((float *)coors_x, deal_num); } -} +#else + convertInt2Float((float *)dst, (float *)dst_addition, (int32_t *)coors_x, + (float *)src_addition, deal_num); + __bang_write_value((float *)src_addition, deal_num, (float)c_x); + __bang_eq((float *)coors_x, (float *)dst, (float *)src_addition, deal_num); + + convertInt2Float((float *)dst, (float *)dst_addition, (int32_t *)coors_y, + (float *)src_addition, deal_num); + __bang_write_value((float *)src_addition, deal_num, (float)c_y); + __bang_eq((float *)coors_y, (float *)dst, (float *)src_addition, deal_num); + + convertInt2Float((float *)dst, (float *)dst_addition, (int32_t *)coors_z, + (float *)src_addition, deal_num); + __bang_write_value((float *)src_addition, deal_num, (float)c_z); + __bang_eq((float *)coors_z, (float *)dst, (float *)src_addition, deal_num); + + __bang_mul((float *)coors_x, (float *)coors_x, (float *)coors_y, deal_num); + __bang_mul((float *)coors_x, (float *)coors_x, (float *)coors_z, deal_num); + if (*num == 0) { + *num = (int32_t)__bang_count((float *)coors_x, deal_num); + if (*num > 0) { + *first_point = + (int32_t)__bang_findfirst1((float *)coors_x, deal_num) + deal_idx; + } + } else { + *num += (int32_t)__bang_count((float *)coors_x, deal_num); + } #endif +} __mlu_global__ void MLUUnion1KernelDynamicVoxelize( const float *points, int32_t *coors, const float voxel_x, @@ -132,7 +236,6 @@ __mlu_global__ void MLUUnion1KernelDynamicVoxelize( const float coors_y_max, const float coors_z_max, const int32_t grid_x, const int32_t grid_y, const int32_t grid_z, const int32_t num_points, const int32_t num_features) { -#if __BANG_ARCH__ >= 322 if (coreId == 0x80) { return; } @@ -143,14 +246,13 @@ __mlu_global__ void MLUUnion1KernelDynamicVoxelize( const int32_t points_start = taskId < points_rem ? taskId * points_per_core : taskId * points_per_core + points_rem; - const int32_t split_num = 9; const int32_t deal_num = PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(float), NFU_ALIGN_SIZE); const int32_t repeat = points_per_core / deal_num; const int32_t rem = points_per_core % deal_num; + const int32_t rem_align = CEIL_ALIGN(rem, NFU_ALIGN_SIZE); const int32_t ping_pong_gap = 3 * deal_num * sizeof(float); - char *points_x = nram_buffer; char *points_y = points_x + deal_num * sizeof(float); char *points_z = points_y + deal_num * sizeof(float); @@ -281,7 +383,7 @@ __mlu_global__ void MLUUnion1KernelDynamicVoxelize( points_z + (repeat % 2) * ping_pong_gap, auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min, coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z, grid_x, - grid_y, grid_z, rem); + grid_y, grid_z, rem_align); __asm__ volatile("sync;"); __memcpy_async(coors_x_start + repeat * deal_num, points_x + (repeat % 2) * ping_pong_gap, @@ -293,7 +395,6 @@ __mlu_global__ void MLUUnion1KernelDynamicVoxelize( points_z + (repeat % 2) * ping_pong_gap, rem * sizeof(int32_t), NRAM2GDRAM); } -#endif } __mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors, @@ -302,18 +403,26 @@ __mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors, const int32_t num_points, const int32_t max_points) { #if __BANG_ARCH__ >= 322 - if (coreId == 0x80) { - return; - } - const int32_t split_num = 6; +#else + const int32_t split_num = 9; // one temp space for computePoint2Voxel in + // mlu2xx +#endif const int32_t deal_num = PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(int32_t), NFU_ALIGN_SIZE); - const int32_t ping_pong_gap = 3 * deal_num * sizeof(int32_t); - char *coors_x = nram_buffer; char *coors_y = coors_x + deal_num * sizeof(int32_t); char *coors_z = coors_y + deal_num * sizeof(int32_t); + const int32_t ping_pong_gap = 3 * deal_num * sizeof(int32_t); +#if __BANG_ARCH__ >= 322 + char *src_addition = nullptr; + char *dst_addition = nullptr; + char *dst = nullptr; +#else + char *src_addition = coors_x + 2 * ping_pong_gap; + char *dst_addition = src_addition + deal_num * sizeof(int32_t); + char *dst = dst_addition + deal_num * sizeof(int32_t); +#endif int32_t *coors_z_start = coors; int32_t *coors_y_start = coors + num_points; @@ -332,11 +441,15 @@ __mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors, int32_t c_z = coors_z_start[point_idx]; int32_t deal_total_num = point_idx; - int32_t repeat = deal_total_num / deal_num; - int32_t rem = deal_total_num % deal_num; + const int32_t repeat = deal_total_num / deal_num; + const int32_t rem = deal_total_num % deal_num; + int32_t rem_align = CEIL_ALIGN(rem, NFU_ALIGN_SIZE); + +#if __BANG_ARCH__ >= 322 + rem_align = rem; +#endif int32_t num = 0; int32_t first_point = -1; - if (repeat > 0) { __memcpy_async(coors_x, coors_x_start, deal_num * sizeof(int32_t), GDRAM2NRAM); @@ -357,14 +470,22 @@ __mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors, __memcpy_async(coors_z + ((i + 1) % 2) * ping_pong_gap, coors_z_start + (i + 1) * deal_num, deal_num * sizeof(int32_t), GDRAM2NRAM); - computePoint2Voxel( - coors_x + (i % 2) * ping_pong_gap, coors_y + (i % 2) * ping_pong_gap, - coors_z + (i % 2) * ping_pong_gap, c_x, c_y, c_z, max_points, &num, - &first_point, i * deal_num, deal_num); + computePoint2Voxel(coors_x + (i % 2) * ping_pong_gap, + coors_y + (i % 2) * ping_pong_gap, + coors_z + (i % 2) * ping_pong_gap, src_addition, + dst_addition, dst, c_x, c_y, c_z, max_points, &num, + &first_point, i * deal_num, deal_num); __asm__ volatile("sync;"); } if (rem > 0) { + __bang_write_value((int32_t *)(coors_x + (repeat % 2) * ping_pong_gap), + rem_align, -1); + __bang_write_value((int32_t *)(coors_y + (repeat % 2) * ping_pong_gap), + rem_align, -1); + __bang_write_value((int32_t *)(coors_z + (repeat % 2) * ping_pong_gap), + rem_align, -1); + __memcpy_async(coors_x + (repeat % 2) * ping_pong_gap, coors_x_start + repeat * deal_num, rem * sizeof(int32_t), GDRAM2NRAM); @@ -378,8 +499,9 @@ __mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors, if (repeat > 0) { computePoint2Voxel(coors_x + ((repeat - 1) % 2) * ping_pong_gap, coors_y + ((repeat - 1) % 2) * ping_pong_gap, - coors_z + ((repeat - 1) % 2) * ping_pong_gap, c_x, c_y, - c_z, max_points, &num, &first_point, + coors_z + ((repeat - 1) % 2) * ping_pong_gap, + src_addition, dst_addition, dst, c_x, c_y, c_z, + max_points, &num, &first_point, (repeat - 1) * deal_num, deal_num); } __asm__ volatile("sync;"); @@ -387,9 +509,9 @@ __mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors, if (rem > 0) { computePoint2Voxel(coors_x + (repeat % 2) * ping_pong_gap, coors_y + (repeat % 2) * ping_pong_gap, - coors_z + (repeat % 2) * ping_pong_gap, c_x, c_y, c_z, - max_points, &num, &first_point, repeat * deal_num, - rem); + coors_z + (repeat % 2) * ping_pong_gap, src_addition, + dst_addition, dst, c_x, c_y, c_z, max_points, &num, + &first_point, repeat * deal_num, rem_align); __asm__ volatile("sync;"); } @@ -405,14 +527,12 @@ __mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors, point_to_voxelidx[point_idx] = -1; } } -#endif } __mlu_global__ void MLUUnion1KernelCalcPointsPerVoxel( int32_t *point_to_pointidx, int32_t *point_to_voxelidx, int32_t *coor_to_voxelidx, int32_t *num_points_per_voxel, int32_t *voxel_num, const int32_t max_voxels, const int32_t num_points) { -#if __BANG_ARCH__ >= 322 if (coreId == 0) { int32_t voxel_num_temp = 0; for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) { @@ -439,7 +559,6 @@ __mlu_global__ void MLUUnion1KernelCalcPointsPerVoxel( } *voxel_num = voxel_num_temp; } -#endif } __mlu_global__ void MLUUnion1KernelAssignVoxelsCoors( @@ -447,7 +566,6 @@ __mlu_global__ void MLUUnion1KernelAssignVoxelsCoors( int32_t *coor_to_voxelidx, float *voxels, int32_t *coors, const int32_t max_points, const int32_t num_points, const int32_t num_features) { -#if __BANG_ARCH__ >= 322 if (coreId == 0x80) { return; } @@ -479,7 +597,6 @@ __mlu_global__ void MLUUnion1KernelAssignVoxelsCoors( } } __asm__ volatile("sync;"); -#endif } void KernelDynamicVoxelize(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h index 321aa0df9..0d6a9aff4 100644 --- a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h +++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h @@ -19,7 +19,7 @@ #define MLUOP_MAJOR 0 #define MLUOP_MINOR 4 -#define MLUOP_PATCHLEVEL 1 +#define MLUOP_PATCHLEVEL 2 mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type); mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);