mirror of https://github.com/open-mmlab/mmcv.git
[Refactor] Refactor the nms3d op to support MLU (#2296)
* update iou3d
* fix parrots
* update to device
* replace count_nonzero with fill
* update build.yml
parent b1711db048
commit c001e2fcba
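Note: the refactor below changes the nms3d extension interface so that each device backend (CUDA here, MLU via its own registered impl) writes the kept box indices and their count directly into pre-allocated output tensors, instead of returning a raw suppression bitmask for the host to decode. A minimal sketch of the user-facing call after this change (illustrative only; the boxes and scores are made up, and a CUDA device is assumed):

import torch
from mmcv.ops import nms3d

# (x, y, z, dx, dy, dz, heading) boxes with per-box scores on the GPU;
# the op dispatches to whichever backend matches the boxes' device.
boxes = torch.tensor([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0],
                      [1.1, 1.1, 1.1, 2.0, 2.0, 2.0, 0.0],
                      [6.0, 6.0, 6.0, 2.0, 2.0, 2.0, 0.0]], device='cuda')
scores = torch.tensor([0.9, 0.8, 0.7], device='cuda')

keep = nms3d(boxes, scores, iou_threshold=0.3)  # indices of the surviving boxes
print(keep)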
@@ -157,10 +157,8 @@ jobs:
     strategy:
       matrix:
         python-version: [3.7]
-        torch: [1.3.1, 1.5.1+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
+        torch: [1.5.1+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
         include:
-          - torch: 1.3.1
-            torchvision: 0.4.2
           - torch: 1.5.1+cu101
             torchvision: 0.6.1+cu101
           - torch: 1.6.0+cu101
@@ -362,10 +360,8 @@ jobs:
     runs-on: macos-latest
     strategy:
       matrix:
-        torch: [1.3.1, 1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
+        torch: [1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
         include:
-          - torch: 1.3.1
-            torchvision: 0.4.2
           - torch: 1.5.1
             torchvision: 0.6.1
           - torch: 1.6.0
@@ -27,9 +27,9 @@ __device__ inline bool devIoU(float const *const a, float const *const b,
   return interS > threshold * (Sa + Sb - interS);
 }
 
-__global__ void nms_cuda(const int n_boxes, const float iou_threshold,
-                         const int offset, const float *dev_boxes,
-                         unsigned long long *dev_mask) {
+__global__ static void nms_cuda(const int n_boxes, const float iou_threshold,
+                                const int offset, const float *dev_boxes,
+                                unsigned long long *dev_mask) {
   int blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock;
   CUDA_2D_KERNEL_BLOCK_LOOP(col_start, blocks, row_start, blocks) {
     const int tid = threadIdx.x;
@@ -73,9 +73,9 @@ __global__ void nms_cuda(const int n_boxes, const float iou_threshold,
   }
 }
 
-__global__ void gather_keep_from_mask(bool *keep,
-                                      const unsigned long long *dev_mask,
-                                      const int n_boxes) {
+__global__ static void gather_keep_from_mask(bool *keep,
+                                             const unsigned long long *dev_mask,
+                                             const int n_boxes) {
   const int col_blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock;
   const int tid = threadIdx.x;
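Note: both kernels gain internal linkage (static), presumably because nms_cuda_kernel.cuh is now also included from the iou3d CUDA source (see the #include added further down), so the header can be compiled into more than one translation unit without duplicate-symbol errors. For intuition, nms_cuda fills a bitmask where word j of row i records, one bit per box, which boxes in block j are suppressed by box i; gather_keep_from_mask then performs the greedy reduction over that mask. A rough Python equivalent of that reduction (a sketch for intuition, not the CUDA code):

import numpy as np

def gather_keep_from_mask_py(mask, n_boxes, threads_per_block=64):
    # mask: (n_boxes, col_blocks) array of uint64 suppression words.
    col_blocks = (n_boxes + threads_per_block - 1) // threads_per_block
    removed = np.zeros(col_blocks, dtype=np.uint64)  # accumulated suppression
    keep = np.zeros(n_boxes, dtype=bool)
    for i in range(n_boxes):  # boxes are assumed pre-sorted by score
        block, bit = divmod(i, threads_per_block)
        if not (int(removed[block]) >> bit) & 1:  # box i not suppressed yet
            keep[i] = True
            for j in range(block, col_blocks):    # suppress everything i removes
                removed[j] |= mask[i][j]
    return keep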
@@ -570,14 +570,12 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
                                                    const Tensor boxes_b,
                                                    Tensor ans_overlap);
 
-void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes,
-                                         unsigned long long* mask,
-                                         int boxes_num,
+void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
+                                         Tensor& keep_num,
                                          float nms_overlap_thresh);
 
-void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes,
-                                               unsigned long long* mask,
-                                               int boxes_num,
+void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
+                                               Tensor& keep_num,
                                                float nms_overlap_thresh);
 
 void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
@@ -587,16 +585,16 @@ void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
                                            ans_overlap);
 };
 
-void iou3d_nms3d_forward_cuda(const Tensor boxes, unsigned long long* mask,
-                              int boxes_num, float nms_overlap_thresh) {
-  IoU3DNMS3DForwardCUDAKernelLauncher(boxes, mask, boxes_num,
+void iou3d_nms3d_forward_cuda(const Tensor boxes, Tensor& keep,
+                              Tensor& keep_num, float nms_overlap_thresh) {
+  IoU3DNMS3DForwardCUDAKernelLauncher(boxes, keep, keep_num,
                                       nms_overlap_thresh);
 };
 
-void iou3d_nms3d_normal_forward_cuda(const Tensor boxes,
-                                     unsigned long long* mask, int boxes_num,
+void iou3d_nms3d_normal_forward_cuda(const Tensor boxes, Tensor& keep,
+                                     Tensor& keep_num,
                                      float nms_overlap_thresh) {
-  IoU3DNMS3DNormalForwardCUDAKernelLauncher(boxes, mask, boxes_num,
+  IoU3DNMS3DNormalForwardCUDAKernelLauncher(boxes, keep, keep_num,
                                             nms_overlap_thresh);
 };
@@ -604,11 +602,11 @@ void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
                                           const int num_b, const Tensor boxes_b,
                                           Tensor ans_overlap);
 
-void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long* mask,
-                              int boxes_num, float nms_overlap_thresh);
+void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor& keep,
+                              Tensor& keep_num, float nms_overlap_thresh);
 
-void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
-                                     unsigned long long* mask, int boxes_num,
+void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor& keep,
+                                     Tensor& keep_num,
                                      float nms_overlap_thresh);
 
 REGISTER_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, CUDA,
@@ -10,6 +10,7 @@ All Rights Reserved 2019-2020.
 #include <stdio.h>
 
 #include "iou3d_cuda_kernel.cuh"
+#include "nms_cuda_kernel.cuh"
 #include "pytorch_cuda_helper.hpp"
 
 void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
@@ -32,36 +33,72 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes,
-                                         unsigned long long *mask,
-                                         int boxes_num,
+void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
+                                         Tensor& keep_num,
                                          float nms_overlap_thresh) {
+  using namespace at::indexing;
   at::cuda::CUDAGuard device_guard(boxes.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
+  int boxes_num = boxes.size(0);
+
+  const int col_blocks =
+      (boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
+  Tensor mask =
+      at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
+
   dim3 blocks(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS),
               GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS));
   dim3 threads(THREADS_PER_BLOCK_NMS);
 
   iou3d_nms3d_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
-      boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(), mask);
+      boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(),
+      (unsigned long long*)mask.data_ptr<int64_t>());
 
+  at::Tensor keep_t = at::zeros(
+      {boxes_num}, boxes.options().dtype(at::kBool).device(at::kCUDA));
+  gather_keep_from_mask<<<1, min(col_blocks, THREADS_PER_BLOCK),
+                          col_blocks * sizeof(unsigned long long), stream>>>(
+      keep_t.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
+      boxes_num);
+
+  auto keep_data = keep_t.nonzero().index({Slice(), 0});
+  keep_num.fill_(at::Scalar(keep_data.size(0)));
+  keep.index_put_({Slice(0, keep_data.size(0))}, keep_data);
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes,
-                                               unsigned long long *mask,
-                                               int boxes_num,
+void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
+                                               Tensor& keep_num,
                                                float nms_overlap_thresh) {
+  using namespace at::indexing;
   at::cuda::CUDAGuard device_guard(boxes.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
+  int boxes_num = boxes.size(0);
+
+  const int col_blocks =
+      (boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
+  Tensor mask =
+      at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
+
   dim3 blocks(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS),
               GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS));
   dim3 threads(THREADS_PER_BLOCK_NMS);
 
   iou3d_nms3d_normal_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
-      boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(), mask);
+      boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(),
+      (unsigned long long*)mask.data_ptr<int64_t>());
 
+  at::Tensor keep_t = at::zeros(
+      {boxes_num}, boxes.options().dtype(at::kBool).device(at::kCUDA));
+  gather_keep_from_mask<<<1, min(col_blocks, THREADS_PER_BLOCK),
+                          col_blocks * sizeof(unsigned long long), stream>>>(
+      keep_t.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
+      boxes_num);
+
+  auto keep_data = keep_t.nonzero().index({Slice(), 0});
+  keep_num.fill_(at::Scalar(keep_data.size(0)));
+  keep.index_put_({Slice(0, keep_data.size(0))}, keep_data);
   AT_CUDA_CHECK(cudaGetLastError());
 }
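Note: with the mask allocation and the gather_keep_from_mask pass moved into the launcher, the kept indices and their count never leave the device; the final extraction uses nonzero()/fill_()/index_put_() rather than count_nonzero (the "replace count_nonzero with fill" item in the commit message). A PyTorch sketch of that last step, with illustrative stand-in tensors:

import torch

keep_t = torch.tensor([True, False, True, True])      # per-box survival flags
keep = torch.zeros(keep_t.numel(), dtype=torch.long)  # pre-allocated index output
keep_num = torch.zeros((), dtype=torch.long)          # pre-allocated scalar output

keep_data = keep_t.nonzero()[:, 0]    # indices of kept boxes -> tensor([0, 2, 3])
keep_num.fill_(keep_data.numel())     # write the count in place ...
keep[:keep_data.numel()] = keep_data  # ... and the indices into the prefix of keep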
@@ -19,16 +19,16 @@ void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
                        num_b, boxes_b, ans_overlap);
 }
 
-void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long *mask,
-                              int boxes_num, float nms_overlap_thresh) {
-  DISPATCH_DEVICE_IMPL(iou3d_nms3d_forward_impl, boxes, mask, boxes_num,
+void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep,
+                              Tensor &keep_num, float nms_overlap_thresh) {
+  DISPATCH_DEVICE_IMPL(iou3d_nms3d_forward_impl, boxes, keep, keep_num,
                        nms_overlap_thresh);
 }
 
-void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
-                                     unsigned long long *mask, int boxes_num,
+void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep,
+                                     Tensor &keep_num,
                                      float nms_overlap_thresh) {
-  DISPATCH_DEVICE_IMPL(iou3d_nms3d_normal_forward_impl, boxes, mask, boxes_num,
+  DISPATCH_DEVICE_IMPL(iou3d_nms3d_normal_forward_impl, boxes, keep, keep_num,
                        nms_overlap_thresh);
 }
@@ -51,41 +51,7 @@ void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num,
   CHECK_CONTIGUOUS(boxes);
   CHECK_CONTIGUOUS(keep);
 
-  int boxes_num = boxes.size(0);
-  int64_t *keep_data = keep.data_ptr<int64_t>();
-  int64_t *keep_num_data = keep_num.data_ptr<int64_t>();
-
-  const int col_blocks =
-      (boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
-
-  Tensor mask =
-      at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
-  unsigned long long *mask_data =
-      (unsigned long long *)mask.data_ptr<int64_t>();
-  iou3d_nms3d_forward_impl(boxes, mask_data, boxes_num, nms_overlap_thresh);
-
-  at::Tensor mask_cpu = mask.to(at::kCPU);
-  unsigned long long *mask_host =
-      (unsigned long long *)mask_cpu.data_ptr<int64_t>();
-
-  std::vector<unsigned long long> remv_cpu(col_blocks);
-  memset(&remv_cpu[0], 0, sizeof(unsigned long long) * col_blocks);
-
-  int num_to_keep = 0;
-
-  for (int i = 0; i < boxes_num; i++) {
-    int nblock = i / THREADS_PER_BLOCK_NMS;
-    int inblock = i % THREADS_PER_BLOCK_NMS;
-
-    if (!(remv_cpu[nblock] & (1ULL << inblock))) {
-      keep_data[num_to_keep++] = i;
-      unsigned long long *p = &mask_host[0] + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv_cpu[j] |= p[j];
-      }
-    }
-    *keep_num_data = num_to_keep;
-  }
+  iou3d_nms3d_forward_impl(boxes, keep, keep_num, nms_overlap_thresh);
 }
 
 void iou3d_nms3d_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num,
@@ -96,40 +62,5 @@ void iou3d_nms3d_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num,
   CHECK_CONTIGUOUS(boxes);
   CHECK_CONTIGUOUS(keep);
 
-  int boxes_num = boxes.size(0);
-  int64_t *keep_data = keep.data_ptr<int64_t>();
-  int64_t *keep_num_data = keep_num.data_ptr<int64_t>();
-
-  const int col_blocks =
-      (boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
-
-  Tensor mask =
-      at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
-  unsigned long long *mask_data =
-      (unsigned long long *)mask.data_ptr<int64_t>();
-  iou3d_nms3d_normal_forward_impl(boxes, mask_data, boxes_num,
-                                  nms_overlap_thresh);
-
-  at::Tensor mask_cpu = mask.to(at::kCPU);
-  unsigned long long *mask_host =
-      (unsigned long long *)mask_cpu.data_ptr<int64_t>();
-
-  std::vector<unsigned long long> remv_cpu(col_blocks);
-  memset(&remv_cpu[0], 0, sizeof(unsigned long long) * col_blocks);
-  int num_to_keep = 0;
-
-  for (int i = 0; i < boxes_num; i++) {
-    int nblock = i / THREADS_PER_BLOCK_NMS;
-    int inblock = i % THREADS_PER_BLOCK_NMS;
-
-    if (!(remv_cpu[nblock] & (1ULL << inblock))) {
-      keep_data[num_to_keep++] = i;
-      unsigned long long *p = &mask_host[0] + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv_cpu[j] |= p[j];
-      }
-    }
-  }
-
-  *keep_num_data = num_to_keep;
+  iou3d_nms3d_normal_forward_impl(boxes, keep, keep_num, nms_overlap_thresh);
 }
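Note: with the host-side bitmask decoding deleted, the binding only checks contiguity and forwards the pre-allocated outputs to the device implementation. A sketch of driving the raw extension op the way the binding now expects (illustrative; the random boxes and the CUDA device are assumptions, mirroring the wrapper shown below):

import torch
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['iou3d_nms3d_forward'])

boxes = torch.rand(8, 7, device='cuda').contiguous()  # (x, y, z, dx, dy, dz, heading)
keep = boxes.new_zeros(boxes.size(0), dtype=torch.long)  # filled by the backend
num_out = boxes.new_zeros((), dtype=torch.long)          # number of valid entries
ext_module.iou3d_nms3d_forward(
    boxes, keep, num_out, nms_overlap_thresh=0.3)
kept = keep[:num_out]  # only the first num_out indices are meaningful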
@@ -82,11 +82,11 @@ def nms3d(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
     order = scores.sort(0, descending=True)[1]
     boxes = boxes[order].contiguous()
 
-    keep = torch.zeros(boxes.size(0), dtype=torch.long)
-    num_out = torch.zeros(size=(), dtype=torch.long)
+    keep = boxes.new_zeros(boxes.size(0), dtype=torch.long)
+    num_out = boxes.new_zeros(size=(), dtype=torch.long)
     ext_module.iou3d_nms3d_forward(
         boxes, keep, num_out, nms_overlap_thresh=iou_threshold)
-    keep = order[keep[:num_out].cuda(boxes.device)].contiguous()
+    keep = order[keep[:num_out].to(boxes.device)].contiguous()
     return keep
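Note: allocating keep / num_out with boxes.new_zeros (instead of torch.zeros on the CPU) and moving results with .to(boxes.device) (instead of .cuda()) keeps the wrapper free of any hard-coded device, which is what lets the same code path serve CUDA and MLU. A tiny sketch of the difference (illustrative, CUDA assumed):

import torch

boxes = torch.rand(4, 7, device='cuda')
keep = boxes.new_zeros(boxes.size(0), dtype=torch.long)
print(keep.device)  # cuda:0 -- inherits the boxes' device, nothing hard-coded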
@@ -109,11 +109,11 @@ def nms3d_normal(boxes: Tensor, scores: Tensor,
     order = scores.sort(0, descending=True)[1]
     boxes = boxes[order].contiguous()
 
-    keep = torch.zeros(boxes.size(0), dtype=torch.long)
-    num_out = torch.zeros(size=(), dtype=torch.long)
+    keep = boxes.new_zeros(boxes.size(0), dtype=torch.long)
+    num_out = boxes.new_zeros(size=(), dtype=torch.long)
     ext_module.iou3d_nms3d_normal_forward(
         boxes, keep, num_out, nms_overlap_thresh=iou_threshold)
-    return order[keep[:num_out].cuda(boxes.device)].contiguous()
+    return order[keep[:num_out].to(boxes.device)].contiguous()
 
 
 def _xyxyr2xywhr(boxes: Tensor) -> Tensor:
@@ -4,11 +4,16 @@ import pytest
 import torch
 
 from mmcv.ops import boxes_iou3d, boxes_overlap_bev, nms3d, nms3d_normal
+from mmcv.utils import IS_CUDA_AVAILABLE
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_boxes_overlap_bev():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+])
+def test_boxes_overlap_bev(device):
     np_boxes1 = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0],
                             [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0],
                             [3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 0.0]],
@@ -22,8 +27,8 @@ def test_boxes_overlap_bev():
                             (3 + 2 * 2**0.5)], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]],
                            dtype=np.float32)
 
-    boxes1 = torch.from_numpy(np_boxes1).cuda()
-    boxes2 = torch.from_numpy(np_boxes2).cuda()
+    boxes1 = torch.from_numpy(np_boxes1).to(device)
+    boxes2 = torch.from_numpy(np_boxes2).to(device)
 
     # test for 3 boxes
     overlaps = boxes_overlap_bev(boxes1, boxes2)
@@ -37,9 +42,13 @@ def test_boxes_overlap_bev():
         overlaps.cpu().numpy(), np_expect_overlaps.repeat(555, 1), atol=1e-4)
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_boxes_iou3d():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+])
+def test_boxes_iou3d(device):
     np_boxes1 = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0],
                             [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0],
                             [3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 0.0]],
@@ -53,16 +62,20 @@ def test_boxes_iou3d():
                             [0.0, 0.0, 0.0]],
                            dtype=np.float32)
 
-    boxes1 = torch.from_numpy(np_boxes1).cuda()
-    boxes2 = torch.from_numpy(np_boxes2).cuda()
+    boxes1 = torch.from_numpy(np_boxes1).to(device)
+    boxes2 = torch.from_numpy(np_boxes2).to(device)
 
     ious = boxes_iou3d(boxes1, boxes2)
     assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_nms3d():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+])
+def test_nms3d(device):
     # test for 5 boxes
     np_boxes = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0],
                            [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0],
@@ -74,7 +87,7 @@ def test_nms3d():
     np_inds = np.array([1, 0, 3])
     boxes = torch.from_numpy(np_boxes)
     scores = torch.from_numpy(np_scores)
-    inds = nms3d(boxes.cuda(), scores.cuda(), iou_threshold=0.3)
+    inds = nms3d(boxes.to(device), scores.to(device), iou_threshold=0.3)
 
     assert np.allclose(inds.cpu().numpy(), np_inds)
 
@@ -84,14 +97,18 @@ def test_nms3d():
     np_scores = np.random.rand(555).astype(np.float32)
     boxes = torch.from_numpy(np_boxes)
     scores = torch.from_numpy(np_scores)
-    inds = nms3d(boxes.cuda(), scores.cuda(), iou_threshold=0.3)
+    inds = nms3d(boxes.to(device), scores.to(device), iou_threshold=0.3)
 
     assert len(inds.cpu().numpy()) == 176
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_nms3d_normal():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+])
+def test_nms3d_normal(device):
     # test for 5 boxes
     np_boxes = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0],
                            [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0],
@@ -103,7 +120,7 @@ def test_nms3d_normal():
     np_inds = np.array([1, 0, 3])
     boxes = torch.from_numpy(np_boxes)
     scores = torch.from_numpy(np_scores)
-    inds = nms3d_normal(boxes.cuda(), scores.cuda(), iou_threshold=0.3)
+    inds = nms3d_normal(boxes.to(device), scores.to(device), iou_threshold=0.3)
 
     assert np.allclose(inds.cpu().numpy(), np_inds)
 
@@ -113,6 +130,6 @@ def test_nms3d_normal():
     np_scores = np.random.rand(555).astype(np.float32)
     boxes = torch.from_numpy(np_boxes)
     scores = torch.from_numpy(np_scores)
-    inds = nms3d_normal(boxes.cuda(), scores.cuda(), iou_threshold=0.3)
+    inds = nms3d_normal(boxes.to(device), scores.to(device), iou_threshold=0.3)
 
     assert len(inds.cpu().numpy()) == 148
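Note: the device parametrization above leaves room for further backends; a hypothetical MLU entry for the same tests could look like the following (assuming mmcv.utils exposes IS_MLU_AVAILABLE alongside IS_CUDA_AVAILABLE; the test name is illustrative):

import pytest
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_nms3d_on_device(device):  # hypothetical: same body as test_nms3d above
    ...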