mirror of https://github.com/open-mmlab/mmcv.git
[Refactor] Replace DIVUP with GET_BLOCKS (#1586)
* [Improve] migrating DIVUP to GET_BLOCKS * [Fix] use GET_BLOCKS only for block alloc and del useless statements * [Fix] add kernel loop for nms and del useless statementspull/1515/merge
parent
cf754db983
commit
b586cc2f6a
|
@ -22,8 +22,7 @@ __global__ void assign_score_withk_forward_cuda_kernel(
|
|||
const int O, const int aggregate, const T* points, const T* centers,
|
||||
const T* scores, const int64_t* knn_idx, T* output) {
|
||||
// ----- parallel loop for B, N1, K and O ---------
|
||||
long i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= B * N1 * K * O) return;
|
||||
CUDA_1D_KERNEL_LOOP(i, B * O * N1 * K) {
|
||||
// ------- loop for M ----------
|
||||
const int b = (int)(i / (O * N1 * K));
|
||||
const int o = (int)(i % (O * N1 * K) / (N1 * K));
|
||||
|
@ -51,6 +50,7 @@ __global__ void assign_score_withk_forward_cuda_kernel(
|
|||
}
|
||||
output[out_idx] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void assign_score_withk_points_backward_cuda_kernel(
|
||||
|
@ -58,8 +58,7 @@ __global__ void assign_score_withk_points_backward_cuda_kernel(
|
|||
const int O, const int aggregate, const T* grad_out, const T* scores,
|
||||
const int64_t* knn_idx, T* grad_points, T* grad_centers) {
|
||||
// ----- parallel loop for B, M, O ---------
|
||||
long i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= B * M * O) return;
|
||||
CUDA_1D_KERNEL_LOOP(i, B * M * O) {
|
||||
int b = (int)(i / (M * O));
|
||||
int m = (int)(i % (M * O) / O);
|
||||
int o = (int)(i % O);
|
||||
|
@ -69,8 +68,8 @@ __global__ void assign_score_withk_points_backward_cuda_kernel(
|
|||
for (int k = 0; k < K; k++) {
|
||||
int kn = knn_idx[b * N * K + n * K + k];
|
||||
int cn = knn_idx[b * N * K + n * K + 0];
|
||||
if (kn >= N0 ||
|
||||
kn < 0) { // if index overflows, it is out of the neighborhood range
|
||||
if (kn >= N0 || kn < 0) { // if index overflows, it is out of the
|
||||
// neighborhood range
|
||||
continue;
|
||||
}
|
||||
atomicAdd(grad_points + b * N0 * M * O + kn * M * O + m * O + o,
|
||||
|
@ -82,6 +81,7 @@ __global__ void assign_score_withk_points_backward_cuda_kernel(
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void assign_score_withk_scores_backward_cuda_kernel(
|
||||
|
@ -89,8 +89,7 @@ __global__ void assign_score_withk_scores_backward_cuda_kernel(
|
|||
const int O, const int aggregate, const T* grad_out, const T* points,
|
||||
const T* centers, const int64_t* knn_idx, T* grad_scores) {
|
||||
// ----- parallel loop for B, N, K, M ---------
|
||||
long i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= B * N * K * M) return;
|
||||
CUDA_1D_KERNEL_LOOP(i, B * N * K * M) {
|
||||
const int b = (int)(i / (N * M * K));
|
||||
const int n = (int)(i % (N * M * K) / M / K);
|
||||
const int k = (int)(i % (M * K) / M);
|
||||
|
@ -112,5 +111,6 @@ __global__ void assign_score_withk_scores_backward_cuda_kernel(
|
|||
}
|
||||
grad_scores[out_idx] = val;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // ASSIGN_SCORE_WITHK_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -21,8 +21,8 @@ __global__ void ball_query_forward_cuda_kernel(int b, int n, int m,
|
|||
// output:
|
||||
// idx: (B, M, nsample)
|
||||
int bs_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (bs_idx >= b || pt_idx >= m) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, m) {
|
||||
if (bs_idx >= b) return;
|
||||
|
||||
new_xyz += bs_idx * m * 3 + pt_idx * 3;
|
||||
xyz += bs_idx * n * 3;
|
||||
|
@ -53,5 +53,6 @@ __global__ void ball_query_forward_cuda_kernel(int b, int n, int m,
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // BALL_QUERY_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -7,12 +7,20 @@
|
|||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||
i += blockDim.x * gridDim.x) \
|
||||
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \
|
||||
j += blockDim.y * gridDim.y)
|
||||
|
||||
#define CUDA_2D_KERNEL_BLOCK_LOOP(i, n, j, m) \
|
||||
for (size_t i = blockIdx.x; i < (n); i += gridDim.x) \
|
||||
for (size_t j = blockIdx.y; j < (m); j += gridDim.y)
|
||||
|
||||
#define THREADS_PER_BLOCK 512
|
||||
|
||||
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
|
||||
|
||||
inline int GET_BLOCKS(const int N) {
|
||||
int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
|
||||
inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) {
|
||||
int optimal_block_num = (N + num_threads - 1) / num_threads;
|
||||
int max_block_num = 4096;
|
||||
return min(optimal_block_num, max_block_num);
|
||||
}
|
||||
|
|
|
@ -22,14 +22,15 @@ __global__ void gather_points_forward_cuda_kernel(int b, int c, int n, int m,
|
|||
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, m) {
|
||||
if (bs_idx >= b || c_idx >= c) return;
|
||||
|
||||
out += bs_idx * c * m + c_idx * m + pt_idx;
|
||||
idx += bs_idx * m + pt_idx;
|
||||
points += bs_idx * c * n + c_idx * n;
|
||||
out[0] = points[idx[0]];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void gather_points_backward_cuda_kernel(int b, int c, int n, int m,
|
||||
|
@ -43,8 +44,8 @@ __global__ void gather_points_backward_cuda_kernel(int b, int c, int n, int m,
|
|||
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, m) {
|
||||
if (bs_idx >= b || c_idx >= c) return;
|
||||
|
||||
grad_out += bs_idx * c * m + c_idx * m + pt_idx;
|
||||
idx += bs_idx * m + pt_idx;
|
||||
|
@ -52,5 +53,6 @@ __global__ void gather_points_backward_cuda_kernel(int b, int c, int n, int m,
|
|||
|
||||
atomicAdd(grad_points + idx[0], grad_out[0]);
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GATHER_POINTS_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -22,10 +22,10 @@ __global__ void group_points_forward_cuda_kernel(int b, int c, int n,
|
|||
// out: (B, C, npoints, nsample)
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int pt_idx = index / nsample;
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
|
||||
CUDA_1D_KERNEL_LOOP(index, npoints * nsample) {
|
||||
if (bs_idx >= b || c_idx >= c) return;
|
||||
|
||||
int pt_idx = index / nsample;
|
||||
int sample_idx = index % nsample;
|
||||
|
||||
idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
|
||||
|
@ -35,6 +35,7 @@ __global__ void group_points_forward_cuda_kernel(int b, int c, int n,
|
|||
|
||||
out[out_idx] = points[in_idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void group_points_backward_cuda_kernel(int b, int c, int n,
|
||||
|
@ -48,9 +49,9 @@ __global__ void group_points_backward_cuda_kernel(int b, int c, int n,
|
|||
// grad_points: (B, C, N)
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
CUDA_1D_KERNEL_LOOP(index, npoints * nsample) {
|
||||
int pt_idx = index / nsample;
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
|
||||
if (bs_idx >= b || c_idx >= c) return;
|
||||
|
||||
int sample_idx = index % nsample;
|
||||
grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
|
||||
|
@ -59,5 +60,6 @@ __global__ void group_points_backward_cuda_kernel(int b, int c, int n,
|
|||
|
||||
atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GROUP_POINTS_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -220,9 +220,7 @@ __device__ inline float iou_bev(const float *box_a, const float *box_b) {
|
|||
__global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel(
|
||||
const int num_a, const float *boxes_a, const int num_b,
|
||||
const float *boxes_b, float *ans_overlap) {
|
||||
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
|
||||
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
|
||||
|
||||
CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) {
|
||||
if (a_idx >= num_a || b_idx >= num_b) {
|
||||
return;
|
||||
}
|
||||
|
@ -231,15 +229,14 @@ __global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel(
|
|||
float s_overlap = box_overlap(cur_box_a, cur_box_b);
|
||||
ans_overlap[a_idx * num_b + b_idx] = s_overlap;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a,
|
||||
const float *boxes_a,
|
||||
const int num_b,
|
||||
const float *boxes_b,
|
||||
float *ans_iou) {
|
||||
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
|
||||
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
|
||||
|
||||
CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) {
|
||||
if (a_idx >= num_a || b_idx >= num_b) {
|
||||
return;
|
||||
}
|
||||
|
@ -249,6 +246,7 @@ __global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a,
|
|||
float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);
|
||||
ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void nms_forward_cuda_kernel(const int boxes_num,
|
||||
const float nms_overlap_thresh,
|
||||
|
@ -256,10 +254,9 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
|
|||
unsigned long long *mask) {
|
||||
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
|
||||
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
|
||||
|
||||
const int row_start = blockIdx.y;
|
||||
const int col_start = blockIdx.x;
|
||||
|
||||
const int blocks =
|
||||
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
|
||||
CUDA_2D_KERNEL_BLOCK_LOOP(col_start, blocks, row_start, blocks) {
|
||||
// if (row_start > col_start) return;
|
||||
|
||||
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
|
||||
|
@ -298,10 +295,12 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
|
|||
t |= 1ULL << i;
|
||||
}
|
||||
}
|
||||
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
|
||||
const int col_blocks =
|
||||
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
|
||||
mask[cur_box_idx * col_blocks + col_start] = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline float iou_normal(float const *const a, float const *const b) {
|
||||
float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
|
||||
|
@ -320,9 +319,9 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
|
|||
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
|
||||
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
|
||||
|
||||
const int row_start = blockIdx.y;
|
||||
const int col_start = blockIdx.x;
|
||||
|
||||
const int blocks =
|
||||
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
|
||||
CUDA_2D_KERNEL_BLOCK_LOOP(col_start, blocks, row_start, blocks) {
|
||||
// if (row_start > col_start) return;
|
||||
|
||||
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
|
||||
|
@ -361,9 +360,11 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
|
|||
t |= 1ULL << i;
|
||||
}
|
||||
}
|
||||
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
|
||||
const int col_blocks =
|
||||
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
|
||||
mask[cur_box_idx * col_blocks + col_start] = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // IOU3D_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -51,8 +51,8 @@ __global__ void knn_forward_cuda_kernel(int b, int n, int m, int nsample,
|
|||
const T *xyz, const T *new_xyz,
|
||||
int *__restrict__ idx, T *dist2) {
|
||||
int bs_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (bs_idx >= b || pt_idx >= m) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, m) {
|
||||
if (bs_idx >= b) return;
|
||||
|
||||
new_xyz += bs_idx * m * 3 + pt_idx * 3;
|
||||
xyz += bs_idx * n * 3;
|
||||
|
@ -87,5 +87,6 @@ __global__ void knn_forward_cuda_kernel(int b, int n, int m, int nsample,
|
|||
dist2[i] = best_dist[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // KNN_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -15,9 +15,6 @@
|
|||
#include "pytorch_cuda_helper.hpp"
|
||||
|
||||
const int CUDA_NUM_THREADS = 1024;
|
||||
inline int GET_BLOCKS(const int N, const int num_threads) {
|
||||
return (N + num_threads - 1) / num_threads;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ scalar_t ms_deform_attn_im2col_bilinear(
|
||||
|
|
|
@ -30,8 +30,8 @@ __device__ inline bool devIoU(float const *const a, float const *const b,
|
|||
__global__ void nms_cuda(const int n_boxes, const float iou_threshold,
|
||||
const int offset, const float *dev_boxes,
|
||||
unsigned long long *dev_mask) {
|
||||
const int row_start = blockIdx.y;
|
||||
const int col_start = blockIdx.x;
|
||||
int blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock;
|
||||
CUDA_2D_KERNEL_BLOCK_LOOP(col_start, blocks, row_start, blocks) {
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
if (row_start > col_start) return;
|
||||
|
@ -71,4 +71,5 @@ __global__ void nms_cuda(const int n_boxes, const float iou_threshold,
|
|||
dev_mask[cur_box_idx * gridDim.y + col_start] = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // NMS_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -45,8 +45,8 @@ __global__ void points_in_boxes_part_forward_cuda_kernel(
|
|||
// (B, npoints), default -1
|
||||
|
||||
int bs_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (bs_idx >= batch_size || pt_idx >= pts_num) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, pts_num) {
|
||||
if (bs_idx >= batch_size) return;
|
||||
|
||||
boxes += bs_idx * boxes_num * 7;
|
||||
pts += bs_idx * pts_num * 3 + pt_idx * 3;
|
||||
|
@ -62,6 +62,7 @@ __global__ void points_in_boxes_part_forward_cuda_kernel(
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void points_in_boxes_all_forward_cuda_kernel(
|
||||
|
@ -73,8 +74,8 @@ __global__ void points_in_boxes_all_forward_cuda_kernel(
|
|||
// (B, npoints), default -1
|
||||
|
||||
int bs_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (bs_idx >= batch_size || pt_idx >= pts_num) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, pts_num) {
|
||||
if (bs_idx >= batch_size) return;
|
||||
|
||||
boxes += bs_idx * boxes_num * 7;
|
||||
pts += bs_idx * pts_num * 3 + pt_idx * 3;
|
||||
|
@ -89,5 +90,6 @@ __global__ void points_in_boxes_all_forward_cuda_kernel(
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // POINT_IN_BOXES_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -44,9 +44,9 @@ __global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
|
|||
// coordinate params pts: (npoints, 3) [x, y, z] params pts_mask: (N,
|
||||
// npoints): -1 means point does not in this box, otherwise: encode (x_idxs,
|
||||
// y_idxs, z_idxs) by binary bit
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int box_idx = blockIdx.y;
|
||||
if (pt_idx >= pts_num || box_idx >= boxes_num) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, pts_num) {
|
||||
if (box_idx >= boxes_num) return;
|
||||
|
||||
pts += pt_idx * 3;
|
||||
rois += box_idx * 7;
|
||||
|
@ -77,6 +77,7 @@ __global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
|
|||
pts_mask[0] = idx_encoding;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
|
||||
|
@ -86,10 +87,7 @@ __global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
|
|||
T *pts_idx_of_voxels) {
|
||||
// params pts_mask: (N, npoints) 0 or 1
|
||||
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
|
||||
|
||||
int box_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (box_idx >= boxes_num) return;
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(box_idx, boxes_num) {
|
||||
int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter
|
||||
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel;
|
||||
|
||||
|
@ -110,6 +108,7 @@ __global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
|
||||
|
@ -124,14 +123,11 @@ __global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
|
|||
|
||||
int box_idx = blockIdx.z;
|
||||
int channel_idx = blockIdx.y;
|
||||
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
|
||||
int x_idx = voxel_idx_flat / (out_y * out_z);
|
||||
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
|
||||
int z_idx = voxel_idx_flat % out_z;
|
||||
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
|
||||
y_idx >= out_y || z_idx >= out_z)
|
||||
return;
|
||||
if (box_idx >= boxes_num || channel_idx >= channels) return;
|
||||
|
||||
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
|
||||
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
|
||||
|
@ -147,7 +143,8 @@ __global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
|
|||
int total_pts = pts_idx_of_voxels[0];
|
||||
|
||||
for (int k = 1; k <= total_pts; k++) {
|
||||
if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] > max_val) {
|
||||
if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] >
|
||||
max_val) {
|
||||
max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
|
||||
argmax_idx = pts_idx_of_voxels[k];
|
||||
}
|
||||
|
@ -158,6 +155,7 @@ __global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
|
|||
}
|
||||
argmax[0] = argmax_idx;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
|
||||
|
@ -172,14 +170,11 @@ __global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
|
|||
|
||||
int box_idx = blockIdx.z;
|
||||
int channel_idx = blockIdx.y;
|
||||
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
|
||||
int x_idx = voxel_idx_flat / (out_y * out_z);
|
||||
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
|
||||
int z_idx = voxel_idx_flat % out_z;
|
||||
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
|
||||
y_idx >= out_y || z_idx >= out_z)
|
||||
return;
|
||||
if (box_idx >= boxes_num || channel_idx >= channels) return;
|
||||
|
||||
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
|
||||
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
|
||||
|
@ -198,6 +193,7 @@ __global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
|
|||
pooled_features[0] = sum_val / total_pts;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
|
||||
|
@ -210,14 +206,11 @@ __global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
|
|||
|
||||
int box_idx = blockIdx.z;
|
||||
int channel_idx = blockIdx.y;
|
||||
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
|
||||
int x_idx = voxel_idx_flat / (out_y * out_z);
|
||||
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
|
||||
int z_idx = voxel_idx_flat % out_z;
|
||||
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
|
||||
y_idx >= out_y || z_idx >= out_z)
|
||||
return;
|
||||
if (box_idx >= boxes_num || channel_idx >= channels) return;
|
||||
|
||||
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
|
||||
argmax += box_idx * out_x * out_y * out_z * channels +
|
||||
|
@ -229,6 +222,7 @@ __global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
|
|||
|
||||
atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
|
||||
|
@ -242,14 +236,11 @@ __global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
|
|||
|
||||
int box_idx = blockIdx.z;
|
||||
int channel_idx = blockIdx.y;
|
||||
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
|
||||
int x_idx = voxel_idx_flat / (out_y * out_z);
|
||||
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
|
||||
int z_idx = voxel_idx_flat % out_z;
|
||||
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
|
||||
y_idx >= out_y || z_idx >= out_z)
|
||||
return;
|
||||
if (box_idx >= boxes_num || channel_idx >= channels) return;
|
||||
|
||||
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
|
||||
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
|
||||
|
@ -264,5 +255,6 @@ __global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
|
|||
grad_out[0] * cur_grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // ROIAWARE_POOL3D_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -42,14 +42,13 @@ __global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num,
|
|||
// params boxes3d: (B, M, 7)
|
||||
// params pts_assign: (B, N, M): idx of the corresponding box3d, -1 means
|
||||
// background points
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int box_idx = blockIdx.y;
|
||||
int bs_idx = blockIdx.z;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, pts_num) {
|
||||
if (box_idx >= boxes_num || bs_idx >= batch_size) return;
|
||||
|
||||
if (pt_idx >= pts_num || box_idx >= boxes_num || bs_idx >= batch_size) {
|
||||
return;
|
||||
}
|
||||
int assign_idx = bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;
|
||||
int assign_idx =
|
||||
bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;
|
||||
pts_assign[assign_idx] = 0;
|
||||
|
||||
int box_offset = bs_idx * boxes_num * 7 + box_idx * 7;
|
||||
|
@ -60,6 +59,7 @@ __global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num,
|
|||
local_x, local_y);
|
||||
pts_assign[assign_idx] = cur_in_flag;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
|
||||
int sampled_pts_num, const int *pts_assign,
|
||||
|
@ -69,17 +69,13 @@ __global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
|
|||
// params pts_assign: (B, N)
|
||||
// params pts_idx: (B, M, 512)
|
||||
// params pooled_empty_flag: (B, M)
|
||||
|
||||
int boxes_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (boxes_idx >= boxes_num) {
|
||||
return;
|
||||
}
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(boxes_idx, boxes_num) {
|
||||
int bs_idx = blockIdx.y;
|
||||
|
||||
int cnt = 0;
|
||||
for (int k = 0; k < pts_num; k++) {
|
||||
if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num + boxes_idx]) {
|
||||
if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num +
|
||||
boxes_idx]) {
|
||||
if (cnt < sampled_pts_num) {
|
||||
pts_idx[bs_idx * boxes_num * sampled_pts_num +
|
||||
boxes_idx * sampled_pts_num + cnt] = k;
|
||||
|
@ -101,6 +97,7 @@ __global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void roipoint_pool3d_forward(
|
||||
|
@ -112,19 +109,11 @@ __global__ void roipoint_pool3d_forward(
|
|||
// params pts_feature: (B, N, C)
|
||||
// params pooled_features: (B, M, 512, 3+C)
|
||||
// params pooled_empty_flag: (B, M)
|
||||
|
||||
int sample_pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int box_idx = blockIdx.y;
|
||||
int bs_idx = blockIdx.z;
|
||||
|
||||
if (sample_pt_idx >= sampled_pts_num || box_idx >= boxes_num ||
|
||||
bs_idx >= batch_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) {
|
||||
return;
|
||||
}
|
||||
CUDA_1D_KERNEL_LOOP(sample_pt_idx, sampled_pts_num) {
|
||||
if (box_idx >= boxes_num || bs_idx >= batch_size) return;
|
||||
if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) return;
|
||||
|
||||
int temp_idx = bs_idx * boxes_num * sampled_pts_num +
|
||||
box_idx * sampled_pts_num + sample_pt_idx;
|
||||
|
@ -140,5 +129,6 @@ __global__ void roipoint_pool3d_forward(
|
|||
memcpy(pooled_features + dst_feature_offset + 3,
|
||||
pts_feature + src_feature_offset, feature_in_len * sizeof(T));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // ROIPOINT_POOL3D_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -20,9 +20,8 @@ __global__ void three_interpolate_forward_cuda_kernel(
|
|||
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, n) {
|
||||
if (bs_idx >= b || c_idx >= c) return;
|
||||
|
||||
weight += bs_idx * n * 3 + pt_idx * 3;
|
||||
points += bs_idx * c * m + c_idx * m;
|
||||
|
@ -32,6 +31,7 @@ __global__ void three_interpolate_forward_cuda_kernel(
|
|||
out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
|
||||
weight[2] * points[idx[2]];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void three_interpolate_backward_cuda_kernel(
|
||||
|
@ -44,9 +44,8 @@ __global__ void three_interpolate_backward_cuda_kernel(
|
|||
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, n) {
|
||||
if (bs_idx >= b || c_idx >= c) return;
|
||||
|
||||
grad_out += bs_idx * c * n + c_idx * n + pt_idx;
|
||||
weight += bs_idx * n * 3 + pt_idx * 3;
|
||||
|
@ -57,5 +56,6 @@ __global__ void three_interpolate_backward_cuda_kernel(
|
|||
atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
|
||||
atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
|
||||
}
|
||||
}
|
||||
|
||||
#endif // THREE_INTERPOLATE_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -19,8 +19,8 @@ __global__ void three_nn_forward_cuda_kernel(int b, int n, int m,
|
|||
// idx: (B, N, 3)
|
||||
|
||||
int bs_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (bs_idx >= b || pt_idx >= n) return;
|
||||
CUDA_1D_KERNEL_LOOP(pt_idx, n) {
|
||||
if (bs_idx >= b) return;
|
||||
|
||||
unknown += bs_idx * n * 3 + pt_idx * 3;
|
||||
known += bs_idx * m * 3;
|
||||
|
@ -62,5 +62,6 @@ __global__ void three_nn_forward_cuda_kernel(int b, int n, int m,
|
|||
idx[1] = besti2;
|
||||
idx[2] = besti3;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // THREE_NN_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -101,7 +101,7 @@ __global__ void point_to_voxelidx_kernel(const T_int* coor,
|
|||
CUDA_1D_KERNEL_LOOP(index, num_points) {
|
||||
auto coor_offset = coor + index * NDim;
|
||||
// skip invalid points
|
||||
if ((index >= num_points) || (coor_offset[0] == -1)) return;
|
||||
if (coor_offset[0] == -1) return;
|
||||
|
||||
int num = 0;
|
||||
int coor_x = coor_offset[0];
|
||||
|
|
|
@ -6,8 +6,6 @@
|
|||
|
||||
using namespace at;
|
||||
|
||||
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
|
||||
|
||||
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CPU(x) \
|
||||
|
|
|
@ -73,8 +73,8 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num,
|
|||
int64_t *keep_data = keep.data_ptr<int64_t>();
|
||||
int64_t *keep_num_data = keep_num.data_ptr<int64_t>();
|
||||
|
||||
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
|
||||
|
||||
const int col_blocks =
|
||||
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
|
||||
Tensor mask =
|
||||
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
|
||||
unsigned long long *mask_data =
|
||||
|
|
|
@ -13,7 +13,7 @@ void AssignScoreWithKForwardCUDAKernelLauncher(
|
|||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 blocks(DIVUP(B * O * N1 * K, THREADS_PER_BLOCK));
|
||||
dim3 blocks(GET_BLOCKS(B * O * N1 * K, THREADS_PER_BLOCK));
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
@ -36,9 +36,9 @@ void AssignScoreWithKBackwardCUDAKernelLauncher(
|
|||
at::cuda::CUDAGuard device_guard(grad_out.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 blocks1(DIVUP(B * M * O, THREADS_PER_BLOCK));
|
||||
dim3 blocks1(GET_BLOCKS(B * M * O, THREADS_PER_BLOCK));
|
||||
dim3 threads1(THREADS_PER_BLOCK);
|
||||
dim3 blocks2(DIVUP(B * N1 * K * M, THREADS_PER_BLOCK));
|
||||
dim3 blocks2(GET_BLOCKS(B * N1 * K * M, THREADS_PER_BLOCK));
|
||||
dim3 threads2(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
|
|
@ -22,7 +22,7 @@ void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);
|
||||
dim3 blocks(GET_BLOCKS(m, THREADS_PER_BLOCK), b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
|
|
@ -16,7 +16,7 @@ void GatherPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);
|
||||
dim3 blocks(GET_BLOCKS(npoints, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
@ -43,7 +43,7 @@ void GatherPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);
|
||||
dim3 blocks(GET_BLOCKS(npoints, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
|
|
@ -19,7 +19,7 @@ void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
|
||||
dim3 blocks(GET_BLOCKS(npoints * nsample, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
@ -46,7 +46,7 @@ void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
|
||||
dim3 blocks(GET_BLOCKS(npoints * nsample, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
|
|
@ -21,8 +21,8 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D),
|
||||
DIVUP(num_a, THREADS_PER_BLOCK_IOU3D));
|
||||
dim3 blocks(GET_BLOCKS(num_b, THREADS_PER_BLOCK_IOU3D),
|
||||
GET_BLOCKS(num_a, THREADS_PER_BLOCK_IOU3D));
|
||||
dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D);
|
||||
|
||||
iou3d_boxes_overlap_bev_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
|
||||
|
@ -41,8 +41,8 @@ void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D),
|
||||
DIVUP(num_a, THREADS_PER_BLOCK_IOU3D));
|
||||
dim3 blocks(GET_BLOCKS(num_b, THREADS_PER_BLOCK_IOU3D),
|
||||
GET_BLOCKS(num_a, THREADS_PER_BLOCK_IOU3D));
|
||||
dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D);
|
||||
|
||||
iou3d_boxes_iou_bev_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
|
||||
|
@ -58,8 +58,8 @@ void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes,
|
|||
at::cuda::CUDAGuard device_guard(boxes.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
|
||||
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
|
||||
dim3 blocks(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS),
|
||||
GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS));
|
||||
dim3 threads(THREADS_PER_BLOCK_NMS);
|
||||
|
||||
nms_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
|
||||
|
@ -75,8 +75,8 @@ void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes,
|
|||
at::cuda::CUDAGuard device_guard(boxes.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
|
||||
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
|
||||
dim3 blocks(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS),
|
||||
GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS));
|
||||
dim3 threads(THREADS_PER_BLOCK_NMS);
|
||||
|
||||
nms_normal_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
|
||||
|
|
|
@ -19,7 +19,7 @@ void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);
|
||||
dim3 blocks(GET_BLOCKS(m, THREADS_PER_BLOCK), b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
|
|
@ -13,10 +13,11 @@ Tensor NMSCUDAKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
|
|||
auto boxes_sorted = boxes.index_select(0, order_t);
|
||||
|
||||
int boxes_num = boxes.size(0);
|
||||
const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
|
||||
const int col_blocks = (boxes_num + threadsPerBlock - 1) / threadsPerBlock;
|
||||
const int col_blocks_alloc = GET_BLOCKS(boxes_num, threadsPerBlock);
|
||||
Tensor mask =
|
||||
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
|
||||
dim3 blocks(col_blocks, col_blocks);
|
||||
dim3 blocks(col_blocks_alloc, col_blocks_alloc);
|
||||
dim3 threads(threadsPerBlock);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
nms_cuda<<<blocks, threads, 0, stream>>>(
|
||||
|
|
|
@ -21,7 +21,7 @@ void PointsInBoxesPartForwardCUDAKernelLauncher(int batch_size, int boxes_num,
|
|||
at::cuda::CUDAGuard device_guard(boxes.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
|
||||
dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
@ -47,7 +47,7 @@ void PointsInBoxesAllForwardCUDAKernelLauncher(int batch_size, int boxes_num,
|
|||
at::cuda::CUDAGuard device_guard(boxes.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
|
||||
dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
|
|
@ -26,7 +26,7 @@ void RoiawarePool3dForwardCUDAKernelLauncher(
|
|||
Tensor pts_mask =
|
||||
-at::ones({boxes_num, pts_num}, pts_feature.options().dtype(at::kInt));
|
||||
|
||||
dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num);
|
||||
dim3 blocks_mask(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
@ -42,7 +42,7 @@ void RoiawarePool3dForwardCUDAKernelLauncher(
|
|||
|
||||
// TODO: Merge the collect and pool functions, SS
|
||||
|
||||
dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK));
|
||||
dim3 blocks_collect(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK));
|
||||
|
||||
AT_DISPATCH_INTEGRAL_TYPES(
|
||||
pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] {
|
||||
|
@ -55,8 +55,8 @@ void RoiawarePool3dForwardCUDAKernelLauncher(
|
|||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
dim3 blocks_pool(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
|
||||
boxes_num);
|
||||
dim3 blocks_pool(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK),
|
||||
channels, boxes_num);
|
||||
if (pool_method == 0) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
pts_feature.scalar_type(), "roiaware_maxpool3d", [&] {
|
||||
|
@ -93,7 +93,7 @@ void RoiawarePool3dBackwardCUDAKernelLauncher(
|
|||
at::cuda::CUDAGuard device_guard(grad_out.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 blocks(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
|
||||
dim3 blocks(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
|
||||
boxes_num);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ void RoIPointPool3dForwardCUDAKernelLauncher(
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);
|
||||
dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
@ -38,14 +38,14 @@ void RoIPointPool3dForwardCUDAKernelLauncher(
|
|||
boxes3d.options().dtype(at::kInt));
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks2(DIVUP(boxes_num, THREADS_PER_BLOCK), batch_size);
|
||||
dim3 blocks2(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK), batch_size);
|
||||
|
||||
get_pooled_idx<<<blocks2, threads, 0, stream>>>(
|
||||
batch_size, pts_num, boxes_num, sampled_pts_num,
|
||||
pts_assign.data_ptr<int>(), pts_idx.data_ptr<int>(),
|
||||
pooled_empty_flag.data_ptr<int>());
|
||||
|
||||
dim3 blocks_pool(DIVUP(sampled_pts_num, THREADS_PER_BLOCK), boxes_num,
|
||||
dim3 blocks_pool(GET_BLOCKS(sampled_pts_num, THREADS_PER_BLOCK), boxes_num,
|
||||
batch_size);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
|
|
@ -23,7 +23,7 @@ void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
|
||||
dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
@ -51,7 +51,7 @@ void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
|
||||
dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
|
|
@ -21,7 +21,7 @@ void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown,
|
|||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);
|
||||
dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
|
|
|
@ -73,7 +73,8 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num,
|
|||
int64_t *keep_data = keep.data_ptr<int64_t>();
|
||||
int64_t *keep_num_data = keep_num.data_ptr<int64_t>();
|
||||
|
||||
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
|
||||
const int col_blocks =
|
||||
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
|
||||
|
||||
Tensor mask =
|
||||
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
|
||||
|
@ -117,7 +118,8 @@ void iou3d_nms_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num,
|
|||
int64_t *keep_data = keep.data_ptr<int64_t>();
|
||||
int64_t *keep_num_data = keep_num.data_ptr<int64_t>();
|
||||
|
||||
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
|
||||
const int col_blocks =
|
||||
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
|
||||
|
||||
Tensor mask =
|
||||
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
|
||||
|
|
|
@ -85,7 +85,7 @@ void CornerPoolForwardLauncher(const scalar_t *input, scalar_t *output,
|
|||
case 0:
|
||||
case 1:
|
||||
nthreads = batch_size * channels * width;
|
||||
col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
|
||||
col_block = GET_BLOCKS(nthreads, THREADS_PER_BLOCK);
|
||||
top_bottom_pool_kernel<scalar_t>
|
||||
<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
|
||||
input, output, batch_size, channels, height, width, pool_type);
|
||||
|
@ -93,7 +93,7 @@ void CornerPoolForwardLauncher(const scalar_t *input, scalar_t *output,
|
|||
case 2:
|
||||
case 3:
|
||||
nthreads = batch_size * channels * height;
|
||||
col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
|
||||
col_block = GET_BLOCKS(nthreads, THREADS_PER_BLOCK);
|
||||
left_right_pool_kernel<scalar_t>
|
||||
<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
|
||||
input, output, batch_size, channels, height, width, pool_type);
|
||||
|
|
|
@ -67,7 +67,7 @@ void CumMaxMinForwardLauncher(const scalar_t *input, scalar_t *output_value,
|
|||
const int data_size =
|
||||
tensor_desc.stride[0] * tensor_desc.shape[0] / tensor_desc.shape[cum_dim];
|
||||
|
||||
const int col_block = DIVUP(data_size, THREADS_PER_BLOCK);
|
||||
const int col_block = GET_BLOCKS(data_size, THREADS_PER_BLOCK);
|
||||
|
||||
cummaxmin_kernel<scalar_t><<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
|
||||
input, output_value, output_index, tensor_desc, cum_dim, cum_type);
|
||||
|
|
|
@ -114,7 +114,8 @@ size_t get_onnxnms_workspace_size(size_t num_batches, size_t spatial_dimension,
|
|||
mmcv::getAlignedSize(spatial_dimension * boxes_word_size);
|
||||
size_t boxes_workspace =
|
||||
mmcv::getAlignedSize(spatial_dimension * 4 * boxes_word_size);
|
||||
const int col_blocks = DIVUP(spatial_dimension, threadsPerBlock);
|
||||
const int col_blocks =
|
||||
(spatial_dimension + threadsPerBlock - 1) / threadsPerBlock;
|
||||
size_t mask_workspace = mmcv::getAlignedSize(spatial_dimension * col_blocks *
|
||||
sizeof(unsigned long long));
|
||||
size_t index_template_workspace =
|
||||
|
@ -163,7 +164,8 @@ void TRTNMSCUDAKernelLauncher_float(const float* boxes, const float* scores,
|
|||
int spatial_dimension, int num_classes,
|
||||
size_t output_length, void* workspace,
|
||||
cudaStream_t stream) {
|
||||
const int col_blocks = DIVUP(spatial_dimension, threadsPerBlock);
|
||||
const int col_blocks =
|
||||
(spatial_dimension + threadsPerBlock - 1) / threadsPerBlock;
|
||||
float* boxes_sorted = (float*)workspace;
|
||||
workspace = static_cast<char*>(workspace) +
|
||||
mmcv::getAlignedSize(spatial_dimension * 4 * sizeof(float));
|
||||
|
|
|
@ -67,7 +67,7 @@ void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices,
|
|||
num_update_indice *= indice_desc.shape[i];
|
||||
}
|
||||
// scatter
|
||||
const int col_block = DIVUP(num_update_indice, threadsPerBlock);
|
||||
const int col_block = GET_BLOCKS(num_update_indice, threadsPerBlock);
|
||||
onnx_scatternd_kernel<<<col_block, threadsPerBlock, 0, stream>>>(
|
||||
num_update_indice, indices, update, output, tensor_desc, indice_desc);
|
||||
}
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#define TRT_CUDA_HELPER_HPP
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
|
||||
|
||||
#define cudaCheckError() \
|
||||
{ \
|
||||
cudaError_t e = cudaGetLastError(); \
|
||||
|
|
Loading…
Reference in New Issue