diff --git a/mmcv/ops/csrc/parrots/iou3d.cpp b/mmcv/ops/csrc/parrots/iou3d.cpp index 4d785703e..8af4f05da 100644 --- a/mmcv/ops/csrc/parrots/iou3d.cpp +++ b/mmcv/ops/csrc/parrots/iou3d.cpp @@ -75,6 +75,7 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num, const int col_blocks = (boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS; + Tensor mask = at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong)); unsigned long long *mask_data = @@ -117,7 +118,8 @@ void iou3d_nms_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num, int64_t *keep_data = keep.data_ptr(); int64_t *keep_num_data = keep_num.data_ptr(); - const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + const int col_blocks = + (boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS; Tensor mask = at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));