[Feature] Support Voxelization with cambricon MLU device (#2500)

* [Feature] Support hard_voxelize with cambricon MLU backend

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op

* [Feature](bangc-ops): add voxelization op
pull/2544/head^2
ZShaopeng 2023-01-10 19:43:45 +08:00 committed by GitHub
parent 64e739e002
commit 48ea88ab9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 840 additions and 5 deletions

View File

@ -58,5 +58,5 @@ We implement common ops used in detection, segmentation, etc.
| ThreeNN | | √ | √ | | |
| TINShift | | √ | √ | | |
| UpFirDn2d | | √ | | | |
| Voxelization | √ | √ | | | |
| Voxelization | √ | √ | √ | | |
| PrRoIPool | | √ | | | |

View File

@ -58,5 +58,5 @@ MMCV 提供了检测、分割等任务中常用的算子
| ThreeNN | | √ | √ | | |
| TINShift | | √ | √ | | |
| UpFirDn2d | | √ | | | |
| Voxelization | √ | √ | | | |
| Voxelization | √ | √ | √ | | |
| PrRoIPool | | √ | | | |

View File

@ -0,0 +1,532 @@
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
__nram__ char nram_buffer[MAX_NRAM_SIZE];
#if __BANG_ARCH__ >= 322
// Convert one tile of point coordinates into integer voxel-grid indices,
// in place on NRAM.
//
// Inputs points_x/points_y/points_z each hold deal_num float coordinates;
// on exit they hold int32 voxel indices c_x/c_y/c_z. A point whose index
// falls outside [0, grid_*) in any dimension gets -1 written to all three
// outputs. auxiliary_a/b/c are caller-provided scratch buffers of
// deal_num elements each.
__mlu_func__ void computeDynamicVoxelize(
    char *points_x, char *points_y, char *points_z, char *auxiliary_a,
    char *auxiliary_b, char *auxiliary_c, const float coors_x_min,
    const float coors_y_min, const float coors_z_min, const float voxel_x,
    const float voxel_y, const float voxel_z, const int32_t grid_x,
    const int32_t grid_y, const int32_t grid_z, const int32_t deal_num) {
  // x - coors_x_min
  __bang_sub_scalar((float *)points_x, (float *)points_x, coors_x_min,
                    deal_num);
  // y - coors_y_min
  __bang_sub_scalar((float *)points_y, (float *)points_y, coors_y_min,
                    deal_num);
  // z - coors_z_min
  __bang_sub_scalar((float *)points_z, (float *)points_z, coors_z_min,
                    deal_num);
  // (x - coors_x_min) / voxel_x  (multiply by reciprocal)
  __bang_mul_scalar((float *)points_x, (float *)points_x, 1.0 / voxel_x,
                    deal_num);
  // (y - coors_y_min) / voxel_y
  __bang_mul_scalar((float *)points_y, (float *)points_y, 1.0 / voxel_y,
                    deal_num);
  // (z - coors_z_min) / voxel_z
  __bang_mul_scalar((float *)points_z, (float *)points_z, 1.0 / voxel_z,
                    deal_num);
  // c_x = floor((x - coors_x_min) / voxel_x), then cast float -> int32
  __bang_floor((float *)auxiliary_a, (float *)points_x, deal_num);
  __bang_float2int32((int32_t *)points_x, (float *)auxiliary_a, deal_num, 0);
  // c_y = floor((y - coors_y_min) / voxel_y)
  __bang_floor((float *)auxiliary_a, (float *)points_y, deal_num);
  __bang_float2int32((int32_t *)points_y, (float *)auxiliary_a, deal_num, 0);
  // c_z = floor((z - coors_z_min) / voxel_z)
  __bang_floor((float *)auxiliary_a, (float *)points_z, deal_num);
  __bang_float2int32((int32_t *)points_z, (float *)auxiliary_a, deal_num, 0);
  // Build the per-point validity mask by AND-ing (via elementwise mul) the
  // six range tests below; auxiliary_a ends up 1 for in-range points, 0
  // otherwise.
  // c_x >= 0
  __bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_x, (int32_t)0,
                   deal_num);
  // c_x < grid_x
  __bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_x, grid_x,
                   deal_num);
  // 0 <= c_x < grid_x
  __bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_b,
             (int32_t *)auxiliary_c, deal_num);
  // c_y >= 0
  __bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_y, (int32_t)0,
                   deal_num);
  // c_y < grid_y
  __bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_y, grid_y,
                   deal_num);
  // 0 <= c_y < grid_y
  __bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b,
             (int32_t *)auxiliary_c, deal_num);
  // c_x >= 0 && c_x < grid_x && c_y >= 0 && c_y < grid_y
  __bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
             (int32_t *)auxiliary_b, deal_num);
  // c_z >= 0
  __bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_z, (int32_t)0,
                   deal_num);
  // c_z < grid_z
  __bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_z, grid_z,
                   deal_num);
  // 0 <= c_z < grid_z
  __bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b,
             (int32_t *)auxiliary_c, deal_num);
  // 0 <= c_x < grid_x && 0 <= c_y < grid_y && 0 <= c_z < grid_z
  __bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
             (int32_t *)auxiliary_b, deal_num);
  // auxiliary_c = !mask; apply c = c * mask + (-1) * !mask, so invalid
  // points become -1 in all three outputs while valid ones keep c_x/c_y/c_z.
  __bang_not((int32_t *)auxiliary_c, (int32_t *)auxiliary_a, deal_num);
  __bang_mul((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_a,
             deal_num);
  __bang_mul_scalar((int32_t *)auxiliary_b, (int32_t *)auxiliary_c,
                    (int32_t)(-1), deal_num);
  __bang_add((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_b,
             deal_num);
  __bang_mul((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_a,
             deal_num);
  __bang_add((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_b,
             deal_num);
  __bang_mul((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_a,
             deal_num);
  __bang_add((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_b,
             deal_num);
}
// Scan one tile of deal_num previously-computed voxel coordinates
// (coors_x/y/z, int32) for matches against the query coordinate
// (c_x, c_y, c_z) and accumulate across tiles:
//   *num         - running count of earlier points in the same voxel;
//   *first_point - global index of the first matching point (tile offset
//                  deal_idx added); recorded only on the first tile that
//                  produces any match (i.e. while *num was still 0).
// coors_x is overwritten with the 0/1 match mask; coors_y/coors_z are
// overwritten with intermediate equality masks.
// NOTE(review): __bang_count/__bang_findfirst1 are invoked with a float*
// cast over the int32 0/1 mask — presumably they count/locate nonzero
// lanes regardless of element interpretation; confirm against the BANG C
// intrinsic reference.
__mlu_func__ void computePoint2Voxel(char *coors_x, char *coors_y,
                                     char *coors_z, const int32_t c_x,
                                     const int32_t c_y, const int32_t c_z,
                                     const int32_t max_points, int32_t *num,
                                     int32_t *first_point,
                                     const int32_t deal_idx,
                                     const int32_t deal_num) {
  __bang_eq_scalar((int32_t *)coors_x, (int32_t *)coors_x, c_x, deal_num);
  __bang_eq_scalar((int32_t *)coors_y, (int32_t *)coors_y, c_y, deal_num);
  __bang_eq_scalar((int32_t *)coors_z, (int32_t *)coors_z, c_z, deal_num);
  // match = (x == c_x) && (y == c_y) && (z == c_z)
  __bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_y,
             deal_num);
  __bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_z,
             deal_num);
  if (*num == 0) {
    // First tile with a hit also records the index of the first match.
    *num = (int32_t)__bang_count((float *)coors_x, deal_num);
    if (*num > 0) {
      *first_point =
          (int32_t)__bang_findfirst1((float *)coors_x, deal_num) + deal_idx;
    }
  } else {
    *num += (int32_t)__bang_count((float *)coors_x, deal_num);
  }
}
#endif
// Stage 1 kernel: compute each point's voxel coordinate.
//
// points is laid out [num_points, num_features] with x, y, z in the first
// three features; coors is written as three planes of num_points int32
// values in the order [z | y | x] (see coors_*_start below). Work is split
// across taskDim cores; each core processes its slice in tiles of deal_num
// points using a software-pipelined (ping-pong) load / compute / store
// scheme with __memcpy_async and explicit sync barriers.
__mlu_global__ void MLUUnion1KernelDynamicVoxelize(
    const float *points, int32_t *coors, const float voxel_x,
    const float voxel_y, const float voxel_z, const float coors_x_min,
    const float coors_y_min, const float coors_z_min, const float coors_x_max,
    const float coors_y_max, const float coors_z_max, const int32_t grid_x,
    const int32_t grid_y, const int32_t grid_z, const int32_t num_points,
    const int32_t num_features) {
#if __BANG_ARCH__ >= 322
  // Skip the cluster's memory core; only compute cores participate.
  if (coreId == 0x80) {
    return;
  }
  // Distribute num_points across taskDim cores; the first points_rem cores
  // take one extra point each.
  const int32_t points_rem = num_points % taskDim;
  const int32_t points_per_core =
      taskId < points_rem ? num_points / taskDim + 1 : num_points / taskDim;
  const int32_t points_start = taskId < points_rem
                                   ? taskId * points_per_core
                                   : taskId * points_per_core + points_rem;
  // NRAM is split into 9 buffers: {x,y,z} ping, {x,y,z} pong, 3 scratch.
  const int32_t split_num = 9;
  const int32_t deal_num =
      PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(float), NFU_ALIGN_SIZE);
  const int32_t repeat = points_per_core / deal_num;
  const int32_t rem = points_per_core % deal_num;
  // Byte distance between the ping and pong copies of the x/y/z buffers.
  const int32_t ping_pong_gap = 3 * deal_num * sizeof(float);
  char *points_x = nram_buffer;
  char *points_y = points_x + deal_num * sizeof(float);
  char *points_z = points_y + deal_num * sizeof(float);
  char *auxiliary_a = points_x + 2 * ping_pong_gap;
  char *auxiliary_b = auxiliary_a + deal_num * sizeof(float);
  char *auxiliary_c = auxiliary_b + deal_num * sizeof(float);
  // Output planes: coors = [z plane | y plane | x plane].
  int32_t *coors_z_start = coors + points_start;
  int32_t *coors_y_start = coors + num_points + points_start;
  int32_t *coors_x_start = coors + num_points * 2 + points_start;
  // Prime the pipeline: load tile 0 into the ping buffers. The strided
  // copies gather one float per point (stride num_features floats).
  if (repeat > 0) {
    __memcpy_async(points_x, points + points_start * num_features,
                   sizeof(float), GDRAM2NRAM, sizeof(float),
                   num_features * sizeof(float), deal_num - 1);
    __memcpy_async(points_y, points + points_start * num_features + 1,
                   sizeof(float), GDRAM2NRAM, sizeof(float),
                   num_features * sizeof(float), deal_num - 1);
    __memcpy_async(points_z, points + points_start * num_features + 2,
                   sizeof(float), GDRAM2NRAM, sizeof(float),
                   num_features * sizeof(float), deal_num - 1);
    __asm__ volatile("sync;");
  }
  // Overlap: load tile 1 (pong) while computing tile 0 (ping).
  if (repeat > 1) {
    __memcpy_async(points_x + ping_pong_gap,
                   points + (points_start + deal_num) * num_features,
                   sizeof(float), GDRAM2NRAM, sizeof(float),
                   num_features * sizeof(float), deal_num - 1);
    __memcpy_async(points_y + ping_pong_gap,
                   points + (points_start + deal_num) * num_features + 1,
                   sizeof(float), GDRAM2NRAM, sizeof(float),
                   num_features * sizeof(float), deal_num - 1);
    __memcpy_async(points_z + ping_pong_gap,
                   points + (points_start + deal_num) * num_features + 2,
                   sizeof(float), GDRAM2NRAM, sizeof(float),
                   num_features * sizeof(float), deal_num - 1);
    computeDynamicVoxelize(points_x, points_y, points_z, auxiliary_a,
                           auxiliary_b, auxiliary_c, coors_x_min, coors_y_min,
                           coors_z_min, voxel_x, voxel_y, voxel_z, grid_x,
                           grid_y, grid_z, deal_num);
    __asm__ volatile("sync;");
  }
  // Steady state: store tile i, load tile i+2, compute tile i+1.
  for (int32_t i = 0; i < repeat - 2; ++i) {
    __memcpy_async(coors_x_start + i * deal_num,
                   points_x + (i % 2) * ping_pong_gap,
                   deal_num * sizeof(int32_t), NRAM2GDRAM);
    __memcpy_async(coors_y_start + i * deal_num,
                   points_y + (i % 2) * ping_pong_gap,
                   deal_num * sizeof(int32_t), NRAM2GDRAM);
    __memcpy_async(coors_z_start + i * deal_num,
                   points_z + (i % 2) * ping_pong_gap,
                   deal_num * sizeof(int32_t), NRAM2GDRAM);
    __memcpy_async(points_x + (i % 2) * ping_pong_gap,
                   points + (points_start + (i + 2) * deal_num) * num_features,
                   sizeof(float), GDRAM2NRAM, sizeof(float),
                   num_features * sizeof(float), deal_num - 1);
    __memcpy_async(
        points_y + (i % 2) * ping_pong_gap,
        points + (points_start + (i + 2) * deal_num) * num_features + 1,
        sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
        deal_num - 1);
    __memcpy_async(
        points_z + (i % 2) * ping_pong_gap,
        points + (points_start + (i + 2) * deal_num) * num_features + 2,
        sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
        deal_num - 1);
    computeDynamicVoxelize(points_x + ((i + 1) % 2) * ping_pong_gap,
                           points_y + ((i + 1) % 2) * ping_pong_gap,
                           points_z + ((i + 1) % 2) * ping_pong_gap,
                           auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min,
                           coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z,
                           grid_x, grid_y, grid_z, deal_num);
    __asm__ volatile("sync;");
  }
  // Drain: store tile repeat-2 (already computed above).
  if (repeat >= 2) {
    __memcpy_async(coors_x_start + (repeat - 2) * deal_num,
                   points_x + (repeat % 2) * ping_pong_gap,
                   deal_num * sizeof(int32_t), NRAM2GDRAM);
    __memcpy_async(coors_y_start + (repeat - 2) * deal_num,
                   points_y + (repeat % 2) * ping_pong_gap,
                   deal_num * sizeof(int32_t), NRAM2GDRAM);
    __memcpy_async(coors_z_start + (repeat - 2) * deal_num,
                   points_z + (repeat % 2) * ping_pong_gap,
                   deal_num * sizeof(int32_t), NRAM2GDRAM);
  }
  // Load the final partial tile (rem points), overlapped with the last
  // full-tile compute below.
  if (rem > 0) {
    __memcpy_async(points_x + (repeat % 2) * ping_pong_gap,
                   points + (points_start + repeat * deal_num) * num_features,
                   sizeof(float), GDRAM2NRAM, sizeof(float),
                   num_features * sizeof(float), rem - 1);
    __memcpy_async(
        points_y + (repeat % 2) * ping_pong_gap,
        points + (points_start + repeat * deal_num) * num_features + 1,
        sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
        rem - 1);
    __memcpy_async(
        points_z + (repeat % 2) * ping_pong_gap,
        points + (points_start + repeat * deal_num) * num_features + 2,
        sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float),
        rem - 1);
  }
  // Compute the last full tile (repeat-1).
  if (repeat > 0) {
    computeDynamicVoxelize(points_x + ((repeat - 1) % 2) * ping_pong_gap,
                           points_y + ((repeat - 1) % 2) * ping_pong_gap,
                           points_z + ((repeat - 1) % 2) * ping_pong_gap,
                           auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min,
                           coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z,
                           grid_x, grid_y, grid_z, deal_num);
  }
  __asm__ volatile("sync;");
  if (repeat > 0) {
    __memcpy_async(coors_x_start + (repeat - 1) * deal_num,
                   points_x + ((repeat - 1) % 2) * ping_pong_gap,
                   deal_num * sizeof(int32_t), NRAM2GDRAM);
    __memcpy_async(coors_y_start + (repeat - 1) * deal_num,
                   points_y + ((repeat - 1) % 2) * ping_pong_gap,
                   deal_num * sizeof(int32_t), NRAM2GDRAM);
    __memcpy_async(coors_z_start + (repeat - 1) * deal_num,
                   points_z + ((repeat - 1) % 2) * ping_pong_gap,
                   deal_num * sizeof(int32_t), NRAM2GDRAM);
  }
  // Compute and store the partial tail tile.
  if (rem > 0) {
    computeDynamicVoxelize(points_x + (repeat % 2) * ping_pong_gap,
                           points_y + (repeat % 2) * ping_pong_gap,
                           points_z + (repeat % 2) * ping_pong_gap, auxiliary_a,
                           auxiliary_b, auxiliary_c, coors_x_min, coors_y_min,
                           coors_z_min, voxel_x, voxel_y, voxel_z, grid_x,
                           grid_y, grid_z, rem);
    __asm__ volatile("sync;");
    __memcpy_async(coors_x_start + repeat * deal_num,
                   points_x + (repeat % 2) * ping_pong_gap,
                   rem * sizeof(int32_t), NRAM2GDRAM);
    __memcpy_async(coors_y_start + repeat * deal_num,
                   points_y + (repeat % 2) * ping_pong_gap,
                   rem * sizeof(int32_t), NRAM2GDRAM);
    __memcpy_async(coors_z_start + repeat * deal_num,
                   points_z + (repeat % 2) * ping_pong_gap,
                   rem * sizeof(int32_t), NRAM2GDRAM);
  }
#endif
}
// Stage 2 kernel: for every point, find the first earlier point that maps
// to the same voxel coordinate.
//
// coors holds three planes [z | y | x] of num_points int32 values written
// by stage 1 (-1 marks out-of-range points). Outputs:
//   point_to_pointidx[i] - index of the first point sharing point i's
//                          voxel (i itself when none earlier), or -1 for
//                          invalid points;
//   point_to_voxelidx[i] - point i's slot inside its voxel (count of
//                          earlier points in that voxel), or -1 when the
//                          point is invalid or the voxel is already full
//                          (>= max_points).
// Points are strided across cores (point_idx += taskDim); for each point
// the prefix [0, point_idx) of coors is rescanned in ping-pong tiles.
__mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors,
                                               int32_t *point_to_pointidx,
                                               int32_t *point_to_voxelidx,
                                               const int32_t num_points,
                                               const int32_t max_points) {
#if __BANG_ARCH__ >= 322
  // Skip the cluster's memory core.
  if (coreId == 0x80) {
    return;
  }
  // NRAM split: {x,y,z} ping + {x,y,z} pong = 6 buffers.
  const int32_t split_num = 6;
  const int32_t deal_num =
      PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(int32_t), NFU_ALIGN_SIZE);
  const int32_t ping_pong_gap = 3 * deal_num * sizeof(int32_t);
  char *coors_x = nram_buffer;
  char *coors_y = coors_x + deal_num * sizeof(int32_t);
  char *coors_z = coors_y + deal_num * sizeof(int32_t);
  int32_t *coors_z_start = coors;
  int32_t *coors_y_start = coors + num_points;
  int32_t *coors_x_start = coors + num_points * 2;
  for (int32_t point_idx = taskId; point_idx < num_points;
       point_idx += taskDim) {
    // Out-of-range point (stage 1 wrote -1): propagate -1 and move on.
    if (coors_x_start[point_idx] == -1) {
      point_to_pointidx[point_idx] = -1;
      point_to_voxelidx[point_idx] = -1;
      continue;
    }
    int32_t c_x = coors_x_start[point_idx];
    int32_t c_y = coors_y_start[point_idx];
    int32_t c_z = coors_z_start[point_idx];
    // Scan only the prefix before this point.
    int32_t deal_total_num = point_idx;
    int32_t repeat = deal_total_num / deal_num;
    int32_t rem = deal_total_num % deal_num;
    int32_t num = 0;
    int32_t first_point = -1;
    // Prime: load tile 0 into the ping buffers.
    if (repeat > 0) {
      __memcpy_async(coors_x, coors_x_start, deal_num * sizeof(int32_t),
                     GDRAM2NRAM);
      __memcpy_async(coors_y, coors_y_start, deal_num * sizeof(int32_t),
                     GDRAM2NRAM);
      __memcpy_async(coors_z, coors_z_start, deal_num * sizeof(int32_t),
                     GDRAM2NRAM);
      __asm__ volatile("sync;");
    }
    // Steady state: load tile i+1 while matching tile i.
    for (int32_t i = 0; i < repeat - 1; ++i) {
      __memcpy_async(coors_x + ((i + 1) % 2) * ping_pong_gap,
                     coors_x_start + (i + 1) * deal_num,
                     deal_num * sizeof(int32_t), GDRAM2NRAM);
      __memcpy_async(coors_y + ((i + 1) % 2) * ping_pong_gap,
                     coors_y_start + (i + 1) * deal_num,
                     deal_num * sizeof(int32_t), GDRAM2NRAM);
      __memcpy_async(coors_z + ((i + 1) % 2) * ping_pong_gap,
                     coors_z_start + (i + 1) * deal_num,
                     deal_num * sizeof(int32_t), GDRAM2NRAM);
      computePoint2Voxel(
          coors_x + (i % 2) * ping_pong_gap, coors_y + (i % 2) * ping_pong_gap,
          coors_z + (i % 2) * ping_pong_gap, c_x, c_y, c_z, max_points, &num,
          &first_point, i * deal_num, deal_num);
      __asm__ volatile("sync;");
    }
    // Load the partial tail tile, overlapped with the last full-tile match.
    if (rem > 0) {
      __memcpy_async(coors_x + (repeat % 2) * ping_pong_gap,
                     coors_x_start + repeat * deal_num, rem * sizeof(int32_t),
                     GDRAM2NRAM);
      __memcpy_async(coors_y + (repeat % 2) * ping_pong_gap,
                     coors_y_start + repeat * deal_num, rem * sizeof(int32_t),
                     GDRAM2NRAM);
      __memcpy_async(coors_z + (repeat % 2) * ping_pong_gap,
                     coors_z_start + repeat * deal_num, rem * sizeof(int32_t),
                     GDRAM2NRAM);
    }
    if (repeat > 0) {
      computePoint2Voxel(coors_x + ((repeat - 1) % 2) * ping_pong_gap,
                         coors_y + ((repeat - 1) % 2) * ping_pong_gap,
                         coors_z + ((repeat - 1) % 2) * ping_pong_gap, c_x, c_y,
                         c_z, max_points, &num, &first_point,
                         (repeat - 1) * deal_num, deal_num);
    }
    __asm__ volatile("sync;");
    // Match the partial tail tile.
    if (rem > 0) {
      computePoint2Voxel(coors_x + (repeat % 2) * ping_pong_gap,
                         coors_y + (repeat % 2) * ping_pong_gap,
                         coors_z + (repeat % 2) * ping_pong_gap, c_x, c_y, c_z,
                         max_points, &num, &first_point, repeat * deal_num,
                         rem);
      __asm__ volatile("sync;");
    }
    // num == 0: this point opens a new voxel, so it points at itself.
    if (num == 0) {
      point_to_pointidx[point_idx] = point_idx;
    } else if (num > 0) {
      point_to_pointidx[point_idx] = first_point;
    }
    // Slot index within the voxel, or -1 once the voxel is full.
    if (num < max_points) {
      point_to_voxelidx[point_idx] = num;
    } else {
      point_to_voxelidx[point_idx] = -1;
    }
  }
#endif
}
// Stage 3 kernel: assign voxel ids sequentially and count points per voxel.
//
// Runs on a single core (coreId == 0) because the assignment is an ordered,
// serial scan: a point with slot 0 opens a new voxel (up to max_voxels);
// later points inherit the voxel id of the first point in their voxel via
// point_to_pointidx. Outputs coor_to_voxelidx (per-point voxel id or -1),
// num_points_per_voxel, and the scalar *voxel_num.
__mlu_global__ void MLUUnion1KernelCalcPointsPerVoxel(
    int32_t *point_to_pointidx, int32_t *point_to_voxelidx,
    int32_t *coor_to_voxelidx, int32_t *num_points_per_voxel,
    int32_t *voxel_num, const int32_t max_voxels, const int32_t num_points) {
#if __BANG_ARCH__ >= 322
  if (coreId == 0) {
    int32_t voxel_num_temp = 0;
    for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
      int32_t point_pos_in_voxel = point_to_voxelidx[point_idx];
      coor_to_voxelidx[point_idx] = -1;
      if (point_pos_in_voxel == -1) {
        // Invalid point or overfull voxel: stays unassigned.
        continue;
      } else if (point_pos_in_voxel == 0) {
        // First point of a voxel: allocate a new voxel id, unless the
        // voxel budget is exhausted.
        int32_t voxel_idx = voxel_num_temp;
        if (voxel_num_temp >= max_voxels) {
          continue;
        }
        voxel_num_temp += 1;
        coor_to_voxelidx[point_idx] = voxel_idx;
        num_points_per_voxel[voxel_idx] = 1;
      } else {
        // Follower point: look up the voxel id assigned to the voxel's
        // first point; -1 means that voxel was never allocated.
        int32_t point_idx_temp = point_to_pointidx[point_idx];
        int32_t voxel_idx = coor_to_voxelidx[point_idx_temp];
        if (voxel_idx != -1) {
          coor_to_voxelidx[point_idx] = voxel_idx;
          num_points_per_voxel[voxel_idx] += 1;
        }
      }
    }
    *voxel_num = voxel_num_temp;
  }
#endif
}
// Stage 4 kernel: scatter point features and voxel coordinates to outputs.
//
// For each point with a valid voxel id and slot, copies its num_features
// floats into voxels[voxel_idx][slot][:] (GDRAM-to-GDRAM). The slot-0
// point of each voxel also writes the voxel's coordinate triple into
// coors[voxel_idx][0..2], gathered from temp_coors with stride num_points
// (i.e. in the [z | y | x] plane order produced by stage 1).
__mlu_global__ void MLUUnion1KernelAssignVoxelsCoors(
    const float *points, int32_t *temp_coors, int32_t *point_to_voxelidx,
    int32_t *coor_to_voxelidx, float *voxels, int32_t *coors,
    const int32_t max_points, const int32_t num_points,
    const int32_t num_features) {
#if __BANG_ARCH__ >= 322
  // Skip the cluster's memory core.
  if (coreId == 0x80) {
    return;
  }
  // Contiguous split of points across cores; the first points_rem cores
  // take one extra point.
  int32_t points_per_core = num_points / taskDim;
  int32_t points_rem = num_points % taskDim;
  int32_t points_start = taskId < points_rem
                             ? taskId * (points_per_core + 1)
                             : taskId * points_per_core + points_rem;
  int32_t points_end = taskId < points_rem ? points_start + points_per_core + 1
                                           : points_start + points_per_core;
  for (int32_t point_idx = points_start; point_idx < points_end; ++point_idx) {
    int32_t num = point_to_voxelidx[point_idx];
    int32_t voxel_idx = coor_to_voxelidx[point_idx];
    if (num > -1 && voxel_idx > -1) {
      float *voxels_offset =
          voxels + voxel_idx * max_points * num_features + num * num_features;
      const float *points_offset = points + point_idx * num_features;
      __memcpy_async(voxels_offset, points_offset, num_features * sizeof(float),
                     GDRAM2GDRAM);
      // Only the voxel's first point writes the coordinate triple.
      if (num == 0) {
        int32_t *coors_offset = coors + voxel_idx * 3;
        __memcpy_async(coors_offset, temp_coors + point_idx, sizeof(int32_t),
                       GDRAM2GDRAM, sizeof(int32_t),
                       num_points * sizeof(int32_t), 2);
      }
    }
  }
  // Wait for all async GDRAM copies issued above to complete.
  __asm__ volatile("sync;");
#endif
}
// Host-side launcher for stage 1 (dynamic voxelize) on the given queue.
// coors receives three num_points-long int32 planes in [z | y | x] order.
void KernelDynamicVoxelize(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                           cnrtQueue_t queue, const void *points, void *coors,
                           const float voxel_x, const float voxel_y,
                           const float voxel_z, const float coors_x_min,
                           const float coors_y_min, const float coors_z_min,
                           const float coors_x_max, const float coors_y_max,
                           const float coors_z_max, const int32_t grid_x,
                           const int32_t grid_y, const int32_t grid_z,
                           const int32_t num_points,
                           const int32_t num_features) {
  MLUUnion1KernelDynamicVoxelize<<<k_dim, k_type, queue>>>(
      (float *)points, (int32_t *)coors, voxel_x, voxel_y, voxel_z, coors_x_min,
      coors_y_min, coors_z_min, coors_x_max, coors_y_max, coors_z_max, grid_x,
      grid_y, grid_z, num_points, num_features);
}
// Host-side launcher for stage 2 (map each point to the first point that
// shares its voxel coordinate).
void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                       cnrtQueue_t queue, void *coors, void *point_to_pointidx,
                       void *point_to_voxelidx, const int32_t num_points,
                       const int32_t max_points) {
  MLUUnion1KernelPoint2Voxel<<<k_dim, k_type, queue>>>(
      (int32_t *)coors, (int32_t *)point_to_pointidx,
      (int32_t *)point_to_voxelidx, num_points, max_points);
}
// Host-side launcher for stage 3 (serial voxel-id assignment and per-voxel
// point counting; the device kernel runs on a single core).
void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                              cnrtQueue_t queue, void *point_to_pointidx,
                              void *point_to_voxelidx, void *coor_to_voxelidx,
                              void *num_points_per_voxel, void *voxel_num,
                              const int32_t max_voxels,
                              const int32_t num_points) {
  MLUUnion1KernelCalcPointsPerVoxel<<<k_dim, k_type, queue>>>(
      (int32_t *)point_to_pointidx, (int32_t *)point_to_voxelidx,
      (int32_t *)coor_to_voxelidx, (int32_t *)num_points_per_voxel,
      (int32_t *)voxel_num, max_voxels, num_points);
}
// Host-side launcher for stage 4 (scatter point features into voxels and
// voxel coordinate triples into coors).
void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                             cnrtQueue_t queue, const void *points,
                             void *temp_coors, void *point_to_voxelidx,
                             void *coor_to_voxelidx, void *voxels, void *coors,
                             const int32_t max_points, const int32_t num_points,
                             const int32_t num_features) {
  MLUUnion1KernelAssignVoxelsCoors<<<k_dim, k_type, queue>>>(
      (float *)points, (int32_t *)temp_coors, (int32_t *)point_to_voxelidx,
      (int32_t *)coor_to_voxelidx, (float *)voxels, (int32_t *)coors,
      max_points, num_points, num_features);
}

View File

@ -0,0 +1,268 @@
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
void KernelDynamicVoxelize(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *points, void *coors, const float voxel_x, const float voxel_y,
const float voxel_z, const float coors_x_min, const float coors_y_min,
const float coors_z_min, const float coors_x_max, const float coors_y_max,
const float coors_z_max, const int32_t grid_x, const int32_t grid_y,
const int32_t grid_z, const int32_t num_points, const int32_t num_features);
void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *coors, void *point_to_pointidx,
void *point_to_voxelidx, const int32_t num_points,
const int32_t max_points);
void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, void *point_to_pointidx,
void *point_to_voxelidx, void *coor_to_voxelidx,
void *num_points_per_voxel, void *voxel_num,
const int32_t max_voxels,
const int32_t num_points);
void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const void *points,
void *temp_coors, void *point_to_voxelidx,
void *coor_to_voxelidx, void *voxels, void *coors,
const int32_t max_points, const int32_t num_points,
const int32_t num_features);
// policy function
// Default UNION1 launch policy: one cluster-width of cores in x, and as
// many clusters in y as needed to cover num_points (capped at the device's
// cluster count).
static void policyFuncDefault(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
                              const int num_points) {
  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim->y = MIN((num_points + k_dim->x - 1) / k_dim->x,
                 torch_mlu::getDeviceAttr(cnrtAttrClusterCount));
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
}
// policy function
// Launch policy for the serial voxel-counting stage: a single BLOCK task
// (1x1x1), since MLUUnion1KernelCalcPointsPerVoxel must run on one core.
// num_points is unused but kept for signature symmetry with the other
// policy function.
static void policyFuncCalcPointsPerVoxel(cnrtDim3_t *k_dim,
                                         cnrtFunctionType_t *k_type,
                                         const int num_points) {
  k_dim->x = 1;
  k_dim->y = 1;
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_BLOCK;
}
// Run hard voxelization on the MLU device.
//
// Groups the input point cloud into fixed-capacity voxels in four stages:
//   1. compute each point's integer voxel coordinate (dynamic voxelize);
//   2. map every point to the first point sharing its voxel coordinate;
//   3. assign voxel ids serially and count points per voxel;
//   4. scatter point features and voxel coordinates to the outputs.
//
// points:               (num_points, num_features) float input.
// voxels:               (max_voxels, max_points, num_features) float output.
// coors:                (max_voxels, 3) int output voxel coordinates.
// num_points_per_voxel: (max_voxels,) int output.
// Returns the number of non-empty voxels produced (<= max_voxels), or 0
// when max_points/max_voxels is zero.
int HardVoxelizeForwardMLUKernelLauncher(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3) {
  // check datatype
  TORCH_CHECK(points.scalar_type() == at::kFloat,
              "points type should be Float, got ", points.scalar_type(), ".");
  TORCH_CHECK(voxels.scalar_type() == at::kFloat,
              "voxels type should be Float, got ", voxels.scalar_type(), ".");
  TORCH_CHECK(coors.scalar_type() == at::kInt,
              "coors type should be Int, got ", coors.scalar_type(), ".");
  TORCH_CHECK(num_points_per_voxel.scalar_type() == at::kInt,
              "num_points_per_voxel type should be Int, got ",
              num_points_per_voxel.scalar_type(), ".");
  // check shape
  TORCH_CHECK(points.dim() == 2, "points should be a 2d tensor, got ",
              points.dim(), "D.");
  TORCH_CHECK(voxels.dim() == 3, "voxels should be a 3d tensor, got ",
              voxels.dim(), "D.");
  TORCH_CHECK(coors.dim() == 2, "coors should be a 2d tensor, got ",
              coors.dim(), "D.");
  TORCH_CHECK(num_points_per_voxel.dim() == 1,
              "num_points_per_voxel should be a 1d tensor, got ",
              num_points_per_voxel.dim(), "D.");
  const int num_points = points.size(0);
  const int num_features = points.size(1);
  TORCH_CHECK(voxels.size(0) == max_voxels,
              "the 1st dimensions of voxels should be max_voxels, got ",
              voxels.size(0), ".");
  TORCH_CHECK(voxels.size(1) == max_points,
              "the 2nd dimensions of voxels should be max_points, got ",
              voxels.size(1), ".");
  TORCH_CHECK(voxels.size(2) == num_features,
              "the 3rd dimensions of voxels should be num_features, got ",
              voxels.size(2), ".");
  TORCH_CHECK(coors.size(0) == max_voxels,
              "the 1st dimensions of coors should be max_voxels, got ",
              coors.size(0), ".");
  TORCH_CHECK(coors.size(1) == 3,
              "the 2nd dimensions of coors should be 3, got ", coors.size(1),
              ".");
  TORCH_CHECK(
      num_points_per_voxel.size(0) == max_voxels,
      "the 1st dimensions of num_points_per_voxel should be max_voxels, got ",
      num_points_per_voxel.size(0), ".");
  // large tensor check: element counts must fit an int32 (2^31).
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(points.numel() < max_input_size,
              "points element num should be less than 2^31, got ",
              points.numel(), ".");
  TORCH_CHECK(voxels.numel() < max_input_size,
              "voxels element num should be less than 2^31, got ",
              voxels.numel(), ".");
  TORCH_CHECK(coors.numel() < max_input_size,
              "coors element num should be less than 2^31, got ", coors.numel(),
              ".");
  // check zero element: nothing to voxelize.
  if (max_points == 0 || max_voxels == 0) {
    return 0;
  }
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get device pointers of (contiguous copies of) the tensors
  auto points_ = points.contiguous();
  auto points_impl = torch_mlu::getMluTensorImpl(points_);
  auto points_ptr = points_impl->cnnlMalloc();
  auto voxels_ = voxels.contiguous();
  auto voxels_impl = torch_mlu::getMluTensorImpl(voxels_);
  auto voxels_ptr = voxels_impl->cnnlMalloc();
  auto coors_ = coors.contiguous();
  auto coors_impl = torch_mlu::getMluTensorImpl(coors_);
  auto coors_ptr = coors_impl->cnnlMalloc();
  auto num_points_per_voxel_ = num_points_per_voxel.contiguous();
  auto num_points_per_voxel_impl =
      torch_mlu::getMluTensorImpl(num_points_per_voxel_);
  auto num_points_per_voxel_ptr = num_points_per_voxel_impl->cnnlMalloc();
  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFuncDefault(&k_dim, &k_type, num_points);
  // 1. link point to corresponding voxel coors
  const float voxel_x = voxel_size[0];
  const float voxel_y = voxel_size[1];
  const float voxel_z = voxel_size[2];
  const float coors_x_min = coors_range[0];
  const float coors_y_min = coors_range[1];
  const float coors_z_min = coors_range[2];
  const float coors_x_max = coors_range[3];
  const float coors_y_max = coors_range[4];
  const float coors_z_max = coors_range[5];
  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
  // Per-point voxel coordinates, stored as NDim planes of num_points ints.
  auto temp_coors =
      at::zeros({NDim, num_points}, points.options().dtype(at::kInt))
          .contiguous();
  auto temp_coors_impl = torch_mlu::getMluTensorImpl(temp_coors);
  auto temp_coors_ptr = temp_coors_impl->cnnlMalloc();
  KernelDynamicVoxelize(k_dim, k_type, queue, points_ptr, temp_coors_ptr,
                        voxel_x, voxel_y, voxel_z, coors_x_min, coors_y_min,
                        coors_z_min, coors_x_max, coors_y_max, coors_z_max,
                        grid_x, grid_y, grid_z, num_points, num_features);
  // 2. map point to the idx of the corresponding voxel, find duplicate coor
  auto point_to_pointidx = at::zeros(
                               {
                                   num_points,
                               },
                               points.options().dtype(at::kInt))
                               .contiguous();
  auto point_to_pointidx_impl = torch_mlu::getMluTensorImpl(point_to_pointidx);
  auto point_to_pointidx_ptr = point_to_pointidx_impl->cnnlMalloc();
  auto point_to_voxelidx = at::zeros(
                               {
                                   num_points,
                               },
                               points.options().dtype(at::kInt))
                               .contiguous();
  auto point_to_voxelidx_impl = torch_mlu::getMluTensorImpl(point_to_voxelidx);
  auto point_to_voxelidx_ptr = point_to_voxelidx_impl->cnnlMalloc();
  KernelPoint2Voxel(k_dim, k_type, queue, temp_coors_ptr, point_to_pointidx_ptr,
                    point_to_voxelidx_ptr, num_points, max_points);
  // calculate task dimension for the serial stage (single BLOCK task)
  cnrtDim3_t k_dim_calc_points_per_voxel;
  cnrtFunctionType_t k_type_calc_points_per_voxel;
  policyFuncCalcPointsPerVoxel(&k_dim_calc_points_per_voxel,
                               &k_type_calc_points_per_voxel, num_points);
  // 3. determine voxel num and voxel's coor index
  auto coor_to_voxelidx = at::zeros(
                              {
                                  num_points,
                              },
                              points.options().dtype(at::kInt))
                              .contiguous();
  auto coor_to_voxelidx_impl = torch_mlu::getMluTensorImpl(coor_to_voxelidx);
  auto coor_to_voxelidx_ptr = coor_to_voxelidx_impl->cnnlMalloc();
  auto voxel_num = at::zeros(
                       {
                           1,
                       },
                       points.options().dtype(at::kInt))
                       .contiguous();
  auto voxel_num_impl = torch_mlu::getMluTensorImpl(voxel_num);
  auto voxel_num_ptr = voxel_num_impl->cnnlMalloc();
  KernelCalcPointsPerVoxel(
      k_dim_calc_points_per_voxel, k_type_calc_points_per_voxel, queue,
      point_to_pointidx_ptr, point_to_voxelidx_ptr, coor_to_voxelidx_ptr,
      num_points_per_voxel_ptr, voxel_num_ptr, max_voxels, num_points);
  // 4. copy point features and coors of each voxels to voxels
  KernelAssignVoxelsCoors(k_dim, k_type, queue, points_ptr, temp_coors_ptr,
                          point_to_voxelidx_ptr, coor_to_voxelidx_ptr,
                          voxels_ptr, coors_ptr, max_points, num_points,
                          num_features);
  // Copy the voxel count back to host; the D2H copy also synchronizes with
  // the queued kernels.
  auto voxel_num_cpu = voxel_num.to(at::kCPU);
  int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
  return voxel_num_int;
}
// Device-registry adapter: forwards the MLU dispatch call to the kernel
// launcher. Returns the number of non-empty voxels produced.
int hard_voxelize_forward_mlu(const at::Tensor &points, at::Tensor &voxels,
                              at::Tensor &coors,
                              at::Tensor &num_points_per_voxel,
                              const std::vector<float> voxel_size,
                              const std::vector<float> coors_range,
                              const int max_points, const int max_voxels,
                              const int NDim) {
  return HardVoxelizeForwardMLUKernelLauncher(
      points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
      max_points, max_voxels, NDim);
}
// Declaration of the dispatch entry point (defined in the common op layer);
// the registration below binds the MLU implementation to it.
int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
                               at::Tensor &coors,
                               at::Tensor &num_points_per_voxel,
                               const std::vector<float> voxel_size,
                               const std::vector<float> coors_range,
                               const int max_points, const int max_voxels,
                               const int NDim);

REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, MLU,
                     hard_voxelize_forward_mlu);

View File

@ -4,6 +4,7 @@ import pytest
import torch
from mmcv.ops import Voxelization
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
def _get_voxel_points_indices(points, coors, voxel):
@ -16,7 +17,7 @@ def _get_voxel_points_indices(points, coors, voxel):
pytest.param(
'cuda:0',
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
])
def test_voxelization(device_type):
voxel_size = [0.5, 0.5, 0.5]
@ -62,8 +63,7 @@ def test_voxelization(device_type):
assert num_points_current_voxel == expected_num_points_per_voxel[i]
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_voxelization_nondeterministic():
voxel_size = [0.5, 0.5, 0.5]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
@ -137,3 +137,38 @@ def test_voxelization_nondeterministic():
coors_all_set = {tuple(c) for c in coors_all}
assert len(coors_set) == len(coors) == len(coors_all_set)
@pytest.mark.parametrize('device_type', [
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_voxelization_mlu(device_type):
    """Check hard voxelization on MLU against precomputed reference data."""
    voxel_size = [0.5, 0.5, 0.5]
    point_cloud_range = [0, -40, -3, 70.4, 40, 1]
    max_num_points = 1000

    # Reference inputs/outputs saved from the CPU implementation.
    gt = np.load(
        'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
    expected_coors = gt['coors']
    expected_voxels = gt['voxels']
    expected_num_points_per_voxel = gt['num_points_per_voxel']

    hard_voxelization = Voxelization(voxel_size, point_cloud_range,
                                     max_num_points)

    # test hard_voxelization on mlu
    device = torch.device(device_type)
    points = torch.tensor(gt['points']).contiguous().to(device)
    coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)

    coors = coors.cpu().detach().numpy()
    voxels = voxels.cpu().detach().numpy()
    num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy()
    assert np.all(coors == expected_coors)
    assert np.all(voxels == expected_voxels)
    assert np.all(num_points_per_voxel == expected_num_points_per_voxel)