[Refactor] Replace DIVUP with GET_BLOCKS (#1586)

* [Improve] migrating DIVUP to GET_BLOCKS

* [Fix] use GET_BLOCKS only for block alloc and del useless statements

* [Fix] add kernel loop for nms and del useless statements
pull/1515/merge
Jiazhen Wang 2022-01-08 11:35:16 +08:00 committed by GitHub
parent cf754db983
commit b586cc2f6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 588 additions and 589 deletions

View File

@ -22,8 +22,7 @@ __global__ void assign_score_withk_forward_cuda_kernel(
const int O, const int aggregate, const T* points, const T* centers,
const T* scores, const int64_t* knn_idx, T* output) {
// ----- parallel loop for B, N1, K and O ---------
long i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= B * N1 * K * O) return;
CUDA_1D_KERNEL_LOOP(i, B * O * N1 * K) {
// ------- loop for M ----------
const int b = (int)(i / (O * N1 * K));
const int o = (int)(i % (O * N1 * K) / (N1 * K));
@ -51,6 +50,7 @@ __global__ void assign_score_withk_forward_cuda_kernel(
}
output[out_idx] = val;
}
}
template <typename T>
__global__ void assign_score_withk_points_backward_cuda_kernel(
@ -58,8 +58,7 @@ __global__ void assign_score_withk_points_backward_cuda_kernel(
const int O, const int aggregate, const T* grad_out, const T* scores,
const int64_t* knn_idx, T* grad_points, T* grad_centers) {
// ----- parallel loop for B, M, O ---------
long i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= B * M * O) return;
CUDA_1D_KERNEL_LOOP(i, B * M * O) {
int b = (int)(i / (M * O));
int m = (int)(i % (M * O) / O);
int o = (int)(i % O);
@ -69,8 +68,8 @@ __global__ void assign_score_withk_points_backward_cuda_kernel(
for (int k = 0; k < K; k++) {
int kn = knn_idx[b * N * K + n * K + k];
int cn = knn_idx[b * N * K + n * K + 0];
if (kn >= N0 ||
kn < 0) { // if index overflows, it is out of the neighborhood range
if (kn >= N0 || kn < 0) { // if index overflows, it is out of the
// neighborhood range
continue;
}
atomicAdd(grad_points + b * N0 * M * O + kn * M * O + m * O + o,
@ -82,6 +81,7 @@ __global__ void assign_score_withk_points_backward_cuda_kernel(
}
}
}
}
template <typename T>
__global__ void assign_score_withk_scores_backward_cuda_kernel(
@ -89,8 +89,7 @@ __global__ void assign_score_withk_scores_backward_cuda_kernel(
const int O, const int aggregate, const T* grad_out, const T* points,
const T* centers, const int64_t* knn_idx, T* grad_scores) {
// ----- parallel loop for B, N, K, M ---------
long i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= B * N * K * M) return;
CUDA_1D_KERNEL_LOOP(i, B * N * K * M) {
const int b = (int)(i / (N * M * K));
const int n = (int)(i % (N * M * K) / M / K);
const int k = (int)(i % (M * K) / M);
@ -112,5 +111,6 @@ __global__ void assign_score_withk_scores_backward_cuda_kernel(
}
grad_scores[out_idx] = val;
}
}
#endif // ASSIGN_SCORE_WITHK_CUDA_KERNEL_CUH

View File

@ -21,8 +21,8 @@ __global__ void ball_query_forward_cuda_kernel(int b, int n, int m,
// output:
// idx: (B, M, nsample)
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;
CUDA_1D_KERNEL_LOOP(pt_idx, m) {
if (bs_idx >= b) return;
new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
@ -53,5 +53,6 @@ __global__ void ball_query_forward_cuda_kernel(int b, int n, int m,
}
}
}
}
#endif // BALL_QUERY_CUDA_KERNEL_CUH

View File

@ -7,12 +7,20 @@
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \
j += blockDim.y * gridDim.y)
#define CUDA_2D_KERNEL_BLOCK_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x; i < (n); i += gridDim.x) \
for (size_t j = blockIdx.y; j < (m); j += gridDim.y)
#define THREADS_PER_BLOCK 512
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
inline int GET_BLOCKS(const int N) {
int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) {
int optimal_block_num = (N + num_threads - 1) / num_threads;
int max_block_num = 4096;
return min(optimal_block_num, max_block_num);
}

View File

@ -22,14 +22,15 @@ __global__ void gather_points_forward_cuda_kernel(int b, int c, int n, int m,
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
CUDA_1D_KERNEL_LOOP(pt_idx, m) {
if (bs_idx >= b || c_idx >= c) return;
out += bs_idx * c * m + c_idx * m + pt_idx;
idx += bs_idx * m + pt_idx;
points += bs_idx * c * n + c_idx * n;
out[0] = points[idx[0]];
}
}
template <typename T>
__global__ void gather_points_backward_cuda_kernel(int b, int c, int n, int m,
@ -43,8 +44,8 @@ __global__ void gather_points_backward_cuda_kernel(int b, int c, int n, int m,
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
CUDA_1D_KERNEL_LOOP(pt_idx, m) {
if (bs_idx >= b || c_idx >= c) return;
grad_out += bs_idx * c * m + c_idx * m + pt_idx;
idx += bs_idx * m + pt_idx;
@ -52,5 +53,6 @@ __global__ void gather_points_backward_cuda_kernel(int b, int c, int n, int m,
atomicAdd(grad_points + idx[0], grad_out[0]);
}
}
#endif // GATHER_POINTS_CUDA_KERNEL_CUH

View File

@ -22,10 +22,10 @@ __global__ void group_points_forward_cuda_kernel(int b, int c, int n,
// out: (B, C, npoints, nsample)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int pt_idx = index / nsample;
if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
CUDA_1D_KERNEL_LOOP(index, npoints * nsample) {
if (bs_idx >= b || c_idx >= c) return;
int pt_idx = index / nsample;
int sample_idx = index % nsample;
idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
@ -35,6 +35,7 @@ __global__ void group_points_forward_cuda_kernel(int b, int c, int n,
out[out_idx] = points[in_idx];
}
}
template <typename T>
__global__ void group_points_backward_cuda_kernel(int b, int c, int n,
@ -48,9 +49,9 @@ __global__ void group_points_backward_cuda_kernel(int b, int c, int n,
// grad_points: (B, C, N)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int index = blockIdx.x * blockDim.x + threadIdx.x;
CUDA_1D_KERNEL_LOOP(index, npoints * nsample) {
int pt_idx = index / nsample;
if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
if (bs_idx >= b || c_idx >= c) return;
int sample_idx = index % nsample;
grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
@ -59,5 +60,6 @@ __global__ void group_points_backward_cuda_kernel(int b, int c, int n,
atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
}
}
#endif // GROUP_POINTS_CUDA_KERNEL_CUH

View File

@ -220,9 +220,7 @@ __device__ inline float iou_bev(const float *box_a, const float *box_b) {
__global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel(
const int num_a, const float *boxes_a, const int num_b,
const float *boxes_b, float *ans_overlap) {
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) {
if (a_idx >= num_a || b_idx >= num_b) {
return;
}
@ -231,15 +229,14 @@ __global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel(
float s_overlap = box_overlap(cur_box_a, cur_box_b);
ans_overlap[a_idx * num_b + b_idx] = s_overlap;
}
}
__global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a,
const float *boxes_a,
const int num_b,
const float *boxes_b,
float *ans_iou) {
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) {
if (a_idx >= num_a || b_idx >= num_b) {
return;
}
@ -249,6 +246,7 @@ __global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a,
float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);
ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;
}
}
__global__ void nms_forward_cuda_kernel(const int boxes_num,
const float nms_overlap_thresh,
@ -256,10 +254,9 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
unsigned long long *mask) {
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
const int blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
CUDA_2D_KERNEL_BLOCK_LOOP(col_start, blocks, row_start, blocks) {
// if (row_start > col_start) return;
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
@ -298,10 +295,12 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
}
__device__ inline float iou_normal(float const *const a, float const *const b) {
float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
@ -320,9 +319,9 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
const int blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
CUDA_2D_KERNEL_BLOCK_LOOP(col_start, blocks, row_start, blocks) {
// if (row_start > col_start) return;
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
@ -361,9 +360,11 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
}
#endif // IOU3D_CUDA_KERNEL_CUH

View File

@ -51,8 +51,8 @@ __global__ void knn_forward_cuda_kernel(int b, int n, int m, int nsample,
const T *xyz, const T *new_xyz,
int *__restrict__ idx, T *dist2) {
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;
CUDA_1D_KERNEL_LOOP(pt_idx, m) {
if (bs_idx >= b) return;
new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
@ -87,5 +87,6 @@ __global__ void knn_forward_cuda_kernel(int b, int n, int m, int nsample,
dist2[i] = best_dist[i];
}
}
}
#endif // KNN_CUDA_KERNEL_CUH

View File

@ -15,9 +15,6 @@
#include "pytorch_cuda_helper.hpp"
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(

View File

@ -30,8 +30,8 @@ __device__ inline bool devIoU(float const *const a, float const *const b,
__global__ void nms_cuda(const int n_boxes, const float iou_threshold,
const int offset, const float *dev_boxes,
unsigned long long *dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
int blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock;
CUDA_2D_KERNEL_BLOCK_LOOP(col_start, blocks, row_start, blocks) {
const int tid = threadIdx.x;
if (row_start > col_start) return;
@ -71,4 +71,5 @@ __global__ void nms_cuda(const int n_boxes, const float iou_threshold,
dev_mask[cur_box_idx * gridDim.y + col_start] = t;
}
}
}
#endif // NMS_CUDA_KERNEL_CUH

View File

@ -45,8 +45,8 @@ __global__ void points_in_boxes_part_forward_cuda_kernel(
// (B, npoints), default -1
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;
CUDA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (bs_idx >= batch_size) return;
boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
@ -62,6 +62,7 @@ __global__ void points_in_boxes_part_forward_cuda_kernel(
}
}
}
}
template <typename T>
__global__ void points_in_boxes_all_forward_cuda_kernel(
@ -73,8 +74,8 @@ __global__ void points_in_boxes_all_forward_cuda_kernel(
// (B, npoints), default -1
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;
CUDA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (bs_idx >= batch_size) return;
boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
@ -89,5 +90,6 @@ __global__ void points_in_boxes_all_forward_cuda_kernel(
}
}
}
}
#endif // POINT_IN_BOXES_CUDA_KERNEL_CUH

View File

@ -44,9 +44,9 @@ __global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
// coordinate params pts: (npoints, 3) [x, y, z] params pts_mask: (N,
// npoints): -1 means point does not in this box, otherwise: encode (x_idxs,
// y_idxs, z_idxs) by binary bit
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
int box_idx = blockIdx.y;
if (pt_idx >= pts_num || box_idx >= boxes_num) return;
CUDA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (box_idx >= boxes_num) return;
pts += pt_idx * 3;
rois += box_idx * 7;
@ -77,6 +77,7 @@ __global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
pts_mask[0] = idx_encoding;
}
}
}
template <typename T>
__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
@ -86,10 +87,7 @@ __global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
T *pts_idx_of_voxels) {
// params pts_mask: (N, npoints) 0 or 1
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
int box_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (box_idx >= boxes_num) return;
CUDA_1D_KERNEL_LOOP(box_idx, boxes_num) {
int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel;
@ -110,6 +108,7 @@ __global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
}
}
}
}
template <typename T>
__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
@ -124,14 +123,11 @@ __global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
CUDA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
if (box_idx >= boxes_num || channel_idx >= channels) return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
@ -147,7 +143,8 @@ __global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
int total_pts = pts_idx_of_voxels[0];
for (int k = 1; k <= total_pts; k++) {
if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] > max_val) {
if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] >
max_val) {
max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
argmax_idx = pts_idx_of_voxels[k];
}
@ -158,6 +155,7 @@ __global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
}
argmax[0] = argmax_idx;
}
}
template <typename T>
__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
@ -172,14 +170,11 @@ __global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
CUDA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
if (box_idx >= boxes_num || channel_idx >= channels) return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
@ -198,6 +193,7 @@ __global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
pooled_features[0] = sum_val / total_pts;
}
}
}
template <typename T>
__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
@ -210,14 +206,11 @@ __global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
CUDA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
if (box_idx >= boxes_num || channel_idx >= channels) return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
argmax += box_idx * out_x * out_y * out_z * channels +
@ -229,6 +222,7 @@ __global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1);
}
}
template <typename T>
__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
@ -242,14 +236,11 @@ __global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
CUDA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) {
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
if (box_idx >= boxes_num || channel_idx >= channels) return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
@ -264,5 +255,6 @@ __global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
grad_out[0] * cur_grad);
}
}
}
#endif // ROIAWARE_POOL3D_CUDA_KERNEL_CUH

View File

@ -42,14 +42,13 @@ __global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num,
// params boxes3d: (B, M, 7)
// params pts_assign: (B, N, M): idx of the corresponding box3d, -1 means
// background points
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
int box_idx = blockIdx.y;
int bs_idx = blockIdx.z;
CUDA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (box_idx >= boxes_num || bs_idx >= batch_size) return;
if (pt_idx >= pts_num || box_idx >= boxes_num || bs_idx >= batch_size) {
return;
}
int assign_idx = bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;
int assign_idx =
bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;
pts_assign[assign_idx] = 0;
int box_offset = bs_idx * boxes_num * 7 + box_idx * 7;
@ -60,6 +59,7 @@ __global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num,
local_x, local_y);
pts_assign[assign_idx] = cur_in_flag;
}
}
__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
int sampled_pts_num, const int *pts_assign,
@ -69,17 +69,13 @@ __global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
// params pts_assign: (B, N)
// params pts_idx: (B, M, 512)
// params pooled_empty_flag: (B, M)
int boxes_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (boxes_idx >= boxes_num) {
return;
}
CUDA_1D_KERNEL_LOOP(boxes_idx, boxes_num) {
int bs_idx = blockIdx.y;
int cnt = 0;
for (int k = 0; k < pts_num; k++) {
if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num + boxes_idx]) {
if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num +
boxes_idx]) {
if (cnt < sampled_pts_num) {
pts_idx[bs_idx * boxes_num * sampled_pts_num +
boxes_idx * sampled_pts_num + cnt] = k;
@ -101,6 +97,7 @@ __global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num,
}
}
}
}
template <typename T>
__global__ void roipoint_pool3d_forward(
@ -112,19 +109,11 @@ __global__ void roipoint_pool3d_forward(
// params pts_feature: (B, N, C)
// params pooled_features: (B, M, 512, 3+C)
// params pooled_empty_flag: (B, M)
int sample_pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
int box_idx = blockIdx.y;
int bs_idx = blockIdx.z;
if (sample_pt_idx >= sampled_pts_num || box_idx >= boxes_num ||
bs_idx >= batch_size) {
return;
}
if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) {
return;
}
CUDA_1D_KERNEL_LOOP(sample_pt_idx, sampled_pts_num) {
if (box_idx >= boxes_num || bs_idx >= batch_size) return;
if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) return;
int temp_idx = bs_idx * boxes_num * sampled_pts_num +
box_idx * sampled_pts_num + sample_pt_idx;
@ -140,5 +129,6 @@ __global__ void roipoint_pool3d_forward(
memcpy(pooled_features + dst_feature_offset + 3,
pts_feature + src_feature_offset, feature_in_len * sizeof(T));
}
}
#endif // ROIPOINT_POOL3D_CUDA_KERNEL_CUH

View File

@ -20,9 +20,8 @@ __global__ void three_interpolate_forward_cuda_kernel(
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
CUDA_1D_KERNEL_LOOP(pt_idx, n) {
if (bs_idx >= b || c_idx >= c) return;
weight += bs_idx * n * 3 + pt_idx * 3;
points += bs_idx * c * m + c_idx * m;
@ -32,6 +31,7 @@ __global__ void three_interpolate_forward_cuda_kernel(
out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
weight[2] * points[idx[2]];
}
}
template <typename T>
__global__ void three_interpolate_backward_cuda_kernel(
@ -44,9 +44,8 @@ __global__ void three_interpolate_backward_cuda_kernel(
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
CUDA_1D_KERNEL_LOOP(pt_idx, n) {
if (bs_idx >= b || c_idx >= c) return;
grad_out += bs_idx * c * n + c_idx * n + pt_idx;
weight += bs_idx * n * 3 + pt_idx * 3;
@ -57,5 +56,6 @@ __global__ void three_interpolate_backward_cuda_kernel(
atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
}
}
#endif // THREE_INTERPOLATE_CUDA_KERNEL_CUH

View File

@ -19,8 +19,8 @@ __global__ void three_nn_forward_cuda_kernel(int b, int n, int m,
// idx: (B, N, 3)
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= n) return;
CUDA_1D_KERNEL_LOOP(pt_idx, n) {
if (bs_idx >= b) return;
unknown += bs_idx * n * 3 + pt_idx * 3;
known += bs_idx * m * 3;
@ -62,5 +62,6 @@ __global__ void three_nn_forward_cuda_kernel(int b, int n, int m,
idx[1] = besti2;
idx[2] = besti3;
}
}
#endif // THREE_NN_CUDA_KERNEL_CUH

View File

@ -101,7 +101,7 @@ __global__ void point_to_voxelidx_kernel(const T_int* coor,
CUDA_1D_KERNEL_LOOP(index, num_points) {
auto coor_offset = coor + index * NDim;
// skip invalid points
if ((index >= num_points) || (coor_offset[0] == -1)) return;
if (coor_offset[0] == -1) return;
int num = 0;
int coor_x = coor_offset[0];

View File

@ -6,8 +6,6 @@
using namespace at;
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CPU(x) \

View File

@ -73,8 +73,8 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num,
int64_t *keep_data = keep.data_ptr<int64_t>();
int64_t *keep_num_data = keep_num.data_ptr<int64_t>();
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
unsigned long long *mask_data =

View File

@ -13,7 +13,7 @@ void AssignScoreWithKForwardCUDAKernelLauncher(
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(B * O * N1 * K, THREADS_PER_BLOCK));
dim3 blocks(GET_BLOCKS(B * O * N1 * K, THREADS_PER_BLOCK));
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@ -36,9 +36,9 @@ void AssignScoreWithKBackwardCUDAKernelLauncher(
at::cuda::CUDAGuard device_guard(grad_out.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks1(DIVUP(B * M * O, THREADS_PER_BLOCK));
dim3 blocks1(GET_BLOCKS(B * M * O, THREADS_PER_BLOCK));
dim3 threads1(THREADS_PER_BLOCK);
dim3 blocks2(DIVUP(B * N1 * K * M, THREADS_PER_BLOCK));
dim3 blocks2(GET_BLOCKS(B * N1 * K * M, THREADS_PER_BLOCK));
dim3 threads2(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

View File

@ -22,7 +22,7 @@ void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);
dim3 blocks(GET_BLOCKS(m, THREADS_PER_BLOCK), b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

View File

@ -16,7 +16,7 @@ void GatherPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);
dim3 blocks(GET_BLOCKS(npoints, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@ -43,7 +43,7 @@ void GatherPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);
dim3 blocks(GET_BLOCKS(npoints, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

View File

@ -19,7 +19,7 @@ void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
dim3 blocks(GET_BLOCKS(npoints * nsample, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@ -46,7 +46,7 @@ void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
dim3 blocks(GET_BLOCKS(npoints * nsample, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

View File

@ -21,8 +21,8 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D),
DIVUP(num_a, THREADS_PER_BLOCK_IOU3D));
dim3 blocks(GET_BLOCKS(num_b, THREADS_PER_BLOCK_IOU3D),
GET_BLOCKS(num_a, THREADS_PER_BLOCK_IOU3D));
dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D);
iou3d_boxes_overlap_bev_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
@ -41,8 +41,8 @@ void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D),
DIVUP(num_a, THREADS_PER_BLOCK_IOU3D));
dim3 blocks(GET_BLOCKS(num_b, THREADS_PER_BLOCK_IOU3D),
GET_BLOCKS(num_a, THREADS_PER_BLOCK_IOU3D));
dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D);
iou3d_boxes_iou_bev_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
@ -58,8 +58,8 @@ void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes,
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 blocks(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS),
GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);
nms_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
@ -75,8 +75,8 @@ void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes,
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 blocks(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS),
GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);
nms_normal_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(

View File

@ -19,7 +19,7 @@ void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);
dim3 blocks(GET_BLOCKS(m, THREADS_PER_BLOCK), b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

View File

@ -13,10 +13,11 @@ Tensor NMSCUDAKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
auto boxes_sorted = boxes.index_select(0, order_t);
int boxes_num = boxes.size(0);
const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
const int col_blocks = (boxes_num + threadsPerBlock - 1) / threadsPerBlock;
const int col_blocks_alloc = GET_BLOCKS(boxes_num, threadsPerBlock);
Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 blocks(col_blocks_alloc, col_blocks_alloc);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
nms_cuda<<<blocks, threads, 0, stream>>>(

View File

@ -21,7 +21,7 @@ void PointsInBoxesPartForwardCUDAKernelLauncher(int batch_size, int boxes_num,
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@ -47,7 +47,7 @@ void PointsInBoxesAllForwardCUDAKernelLauncher(int batch_size, int boxes_num,
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

View File

@ -26,7 +26,7 @@ void RoiawarePool3dForwardCUDAKernelLauncher(
Tensor pts_mask =
-at::ones({boxes_num, pts_num}, pts_feature.options().dtype(at::kInt));
dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num);
dim3 blocks_mask(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@ -42,7 +42,7 @@ void RoiawarePool3dForwardCUDAKernelLauncher(
// TODO: Merge the collect and pool functions, SS
dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK));
dim3 blocks_collect(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK));
AT_DISPATCH_INTEGRAL_TYPES(
pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] {
@ -55,8 +55,8 @@ void RoiawarePool3dForwardCUDAKernelLauncher(
AT_CUDA_CHECK(cudaGetLastError());
dim3 blocks_pool(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
boxes_num);
dim3 blocks_pool(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK),
channels, boxes_num);
if (pool_method == 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pts_feature.scalar_type(), "roiaware_maxpool3d", [&] {
@ -93,7 +93,7 @@ void RoiawarePool3dBackwardCUDAKernelLauncher(
at::cuda::CUDAGuard device_guard(grad_out.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
dim3 blocks(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
boxes_num);
dim3 threads(THREADS_PER_BLOCK);

View File

@ -24,7 +24,7 @@ void RoIPointPool3dForwardCUDAKernelLauncher(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);
dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@ -38,14 +38,14 @@ void RoIPointPool3dForwardCUDAKernelLauncher(
boxes3d.options().dtype(at::kInt));
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks2(DIVUP(boxes_num, THREADS_PER_BLOCK), batch_size);
dim3 blocks2(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK), batch_size);
get_pooled_idx<<<blocks2, threads, 0, stream>>>(
batch_size, pts_num, boxes_num, sampled_pts_num,
pts_assign.data_ptr<int>(), pts_idx.data_ptr<int>(),
pooled_empty_flag.data_ptr<int>());
dim3 blocks_pool(DIVUP(sampled_pts_num, THREADS_PER_BLOCK), boxes_num,
dim3 blocks_pool(GET_BLOCKS(sampled_pts_num, THREADS_PER_BLOCK), boxes_num,
batch_size);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

View File

@ -23,7 +23,7 @@ void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@ -51,7 +51,7 @@ void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

View File

@ -21,7 +21,7 @@ void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);
dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

View File

@ -73,7 +73,8 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num,
int64_t *keep_data = keep.data_ptr<int64_t>();
int64_t *keep_num_data = keep_num.data_ptr<int64_t>();
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
@ -117,7 +118,8 @@ void iou3d_nms_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num,
int64_t *keep_data = keep.data_ptr<int64_t>();
int64_t *keep_num_data = keep_num.data_ptr<int64_t>();
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));

View File

@ -85,7 +85,7 @@ void CornerPoolForwardLauncher(const scalar_t *input, scalar_t *output,
case 0:
case 1:
nthreads = batch_size * channels * width;
col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
col_block = GET_BLOCKS(nthreads, THREADS_PER_BLOCK);
top_bottom_pool_kernel<scalar_t>
<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
input, output, batch_size, channels, height, width, pool_type);
@ -93,7 +93,7 @@ void CornerPoolForwardLauncher(const scalar_t *input, scalar_t *output,
case 2:
case 3:
nthreads = batch_size * channels * height;
col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
col_block = GET_BLOCKS(nthreads, THREADS_PER_BLOCK);
left_right_pool_kernel<scalar_t>
<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
input, output, batch_size, channels, height, width, pool_type);

View File

@ -67,7 +67,7 @@ void CumMaxMinForwardLauncher(const scalar_t *input, scalar_t *output_value,
const int data_size =
tensor_desc.stride[0] * tensor_desc.shape[0] / tensor_desc.shape[cum_dim];
const int col_block = DIVUP(data_size, THREADS_PER_BLOCK);
const int col_block = GET_BLOCKS(data_size, THREADS_PER_BLOCK);
cummaxmin_kernel<scalar_t><<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
input, output_value, output_index, tensor_desc, cum_dim, cum_type);

View File

@ -114,7 +114,8 @@ size_t get_onnxnms_workspace_size(size_t num_batches, size_t spatial_dimension,
mmcv::getAlignedSize(spatial_dimension * boxes_word_size);
size_t boxes_workspace =
mmcv::getAlignedSize(spatial_dimension * 4 * boxes_word_size);
const int col_blocks = DIVUP(spatial_dimension, threadsPerBlock);
const int col_blocks =
(spatial_dimension + threadsPerBlock - 1) / threadsPerBlock;
size_t mask_workspace = mmcv::getAlignedSize(spatial_dimension * col_blocks *
sizeof(unsigned long long));
size_t index_template_workspace =
@ -163,7 +164,8 @@ void TRTNMSCUDAKernelLauncher_float(const float* boxes, const float* scores,
int spatial_dimension, int num_classes,
size_t output_length, void* workspace,
cudaStream_t stream) {
const int col_blocks = DIVUP(spatial_dimension, threadsPerBlock);
const int col_blocks =
(spatial_dimension + threadsPerBlock - 1) / threadsPerBlock;
float* boxes_sorted = (float*)workspace;
workspace = static_cast<char*>(workspace) +
mmcv::getAlignedSize(spatial_dimension * 4 * sizeof(float));

View File

@ -67,7 +67,7 @@ void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices,
num_update_indice *= indice_desc.shape[i];
}
// scatter
const int col_block = DIVUP(num_update_indice, threadsPerBlock);
const int col_block = GET_BLOCKS(num_update_indice, threadsPerBlock);
onnx_scatternd_kernel<<<col_block, threadsPerBlock, 0, stream>>>(
num_update_indice, indices, update, output, tensor_desc, indice_desc);
}

View File

@ -3,8 +3,6 @@
#define TRT_CUDA_HELPER_HPP
#include <cublas_v2.h>
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
#define cudaCheckError() \
{ \
cudaError_t e = cudaGetLastError(); \