From b5ec327d34dfeb93fc02c04beb7d450b2d68323b Mon Sep 17 00:00:00 2001 From: pc Date: Tue, 11 Jan 2022 11:11:08 +0800 Subject: [PATCH] Fix iou3d bug in parrots (#1656) --- mmcv/ops/csrc/parrots/iou3d.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/parrots/iou3d.cpp b/mmcv/ops/csrc/parrots/iou3d.cpp index 4d785703e..8af4f05da 100644 --- a/mmcv/ops/csrc/parrots/iou3d.cpp +++ b/mmcv/ops/csrc/parrots/iou3d.cpp @@ -75,6 +75,7 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num, const int col_blocks = (boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS; + Tensor mask = at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong)); unsigned long long *mask_data = @@ -117,7 +118,8 @@ void iou3d_nms_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num, int64_t *keep_data = keep.data_ptr(); int64_t *keep_num_data = keep_num.data_ptr(); - const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + const int col_blocks = + (boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS; Tensor mask = at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));