Fix iou3d bug in parrots (#1656)

This commit is contained in:
pc 2022-01-11 11:11:08 +08:00 committed by GitHub
parent 48419395e3
commit b5ec327d34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -75,6 +75,7 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num,
const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
unsigned long long *mask_data =
@ -117,7 +118,8 @@ void iou3d_nms_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num,
int64_t *keep_data = keep.data_ptr<int64_t>();
int64_t *keep_num_data = keep_num.data_ptr<int64_t>();
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));