mirror of https://github.com/open-mmlab/mmcv.git
NMS with CUDA only (#1824)
* add gather_keep_from_mask_parallize
* remove unused cache
* move syncthread
* remove unused comment
* add more comments, rename the kernel and variable

pull/1881/head
parent 3270caf6af
commit 74031cc508
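For orientation: the new gather step consumes the pairwise-overlap bitmask that the existing nms_cuda kernel writes, where row i of dev_mask holds col_blocks 64-bit words and a set bit marks a pair of boxes whose IoU exceeds iou_threshold. Below is a minimal sketch of that bit indexing, assuming threadsPerBlock equals 64 (the bit width of unsigned long long); the helper name is ours, not part of the commit:

__host__ __device__ inline bool pair_overlaps(const unsigned long long *dev_mask,
                                              const int col_blocks, const int i,
                                              const int j) {
  // Row i of the mask stores the overlap flags of box i against all boxes.
  const int nblock = j / 64;   // which 64-bit word of row i holds box j
  const int inblock = j % 64;  // which bit inside that word
  return (dev_mask[i * col_blocks + nblock] >> inblock) & 1ULL;
}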
@@ -72,4 +72,46 @@ __global__ void nms_cuda(const int n_boxes, const float iou_threshold,
     }
   }
 }
+
+__global__ void gather_keep_from_mask(bool *keep,
+                                      const unsigned long long *dev_mask,
+                                      const int n_boxes) {
+  const int col_blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock;
+  const int tid = threadIdx.x;
+
+  // mark the bboxes which have been removed.
+  extern __shared__ unsigned long long removed[];
+
+  // initialize removed.
+  for (int i = tid; i < col_blocks; i += blockDim.x) {
+    removed[i] = 0;
+  }
+  __syncthreads();
+
+  for (int nblock = 0; nblock < col_blocks; ++nblock) {
+    auto removed_val = removed[nblock];
+    __syncthreads();
+    const int i_offset = nblock * threadsPerBlock;
+#pragma unroll
+    for (int inblock = 0; inblock < threadsPerBlock; ++inblock) {
+      const int i = i_offset + inblock;
+      if (i >= n_boxes) break;
+      // select a candidate, check if it should be kept.
+      if (!(removed_val & (1ULL << inblock))) {
+        if (tid == 0) {
+          // mark the output.
+          keep[i] = true;
+        }
+        auto p = dev_mask + i * col_blocks;
+        // remove all bboxes which overlap the candidate.
+        for (int j = tid; j < col_blocks; j += blockDim.x) {
+          if (j >= nblock) removed[j] |= p[j];
+        }
+        __syncthreads();
+        removed_val = removed[nblock];
+      }
+    }
+  }
+}
+
 #endif  // NMS_CUDA_KERNEL_CUH
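A rough sense of the launch parameters used by the updated launcher (see the second hunk below): the kernel runs in a single block because the scan over candidate boxes is inherently sequential, and only the per-candidate OR over the col_blocks mask words is spread across the block's threads. Illustrative arithmetic for a hypothetical n_boxes, with the threadsPerBlock and THREADS_PER_BLOCK values assumed:

#include <algorithm>
#include <cstdio>

int main() {
  const int threadsPerBlock = 64;     // bits per mask word (assumed)
  const int THREADS_PER_BLOCK = 512;  // block-size cap (assumed)
  const int n_boxes = 10000;          // illustrative box count
  const int col_blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock;
  const int threads = std::min(col_blocks, THREADS_PER_BLOCK);
  const size_t smem = col_blocks * sizeof(unsigned long long);
  // Prints: col_blocks=157 threads=157 shared_mem=1256 bytes
  std::printf("col_blocks=%d threads=%d shared_mem=%zu bytes\n", col_blocks,
              threads, smem);
  return 0;
}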
@@ -24,31 +24,13 @@ Tensor NMSCUDAKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
       boxes_num, iou_threshold, offset, boxes_sorted.data_ptr<float>(),
       (unsigned long long*)mask.data_ptr<int64_t>());
 
-  at::Tensor mask_cpu = mask.to(at::kCPU);
-  unsigned long long* mask_host =
-      (unsigned long long*)mask_cpu.data_ptr<int64_t>();
-
-  std::vector<unsigned long long> remv(col_blocks);
-  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
-
-  at::Tensor keep_t =
-      at::zeros({boxes_num}, boxes.options().dtype(at::kBool).device(at::kCPU));
-  bool* keep = keep_t.data_ptr<bool>();
-
-  for (int i = 0; i < boxes_num; i++) {
-    int nblock = i / threadsPerBlock;
-    int inblock = i % threadsPerBlock;
-
-    if (!(remv[nblock] & (1ULL << inblock))) {
-      keep[i] = true;
-      // set every overlap box with bit 1 in remv
-      unsigned long long* p = mask_host + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv[j] |= p[j];
-      }
-    }
-  }
-
-  return order_t.masked_select(keep_t.to(at::kCUDA));
+  // Filter the boxes which should be kept.
+  at::Tensor keep_t = at::zeros(
+      {boxes_num}, boxes.options().dtype(at::kBool).device(at::kCUDA));
+  gather_keep_from_mask<<<1, min(col_blocks, THREADS_PER_BLOCK),
+                          col_blocks * sizeof(unsigned long long), stream>>>(
+      keep_t.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
+      boxes_num);
+  AT_CUDA_CHECK(cudaGetLastError());
+  return order_t.masked_select(keep_t);
 }
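For completeness, a minimal sketch of calling the launcher whose signature appears in the hunk header above. The declaration is paraphrased (not copied from mmcv's headers), and the tensors and expected result are illustrative:

#include <torch/torch.h>

// Paraphrased declaration; the real one lives in mmcv's CUDA sources.
at::Tensor NMSCUDAKernelLauncher(at::Tensor boxes, at::Tensor scores,
                                 float iou_threshold, int offset);

void example() {
  // Three [x1, y1, x2, y2] boxes; the first two overlap heavily.
  at::Tensor boxes = torch::tensor({0.f, 0.f, 10.f, 10.f,
                                    1.f, 1.f, 11.f, 11.f,
                                    50.f, 50.f, 60.f, 60.f})
                         .view({3, 4})
                         .to(at::kCUDA);
  at::Tensor scores = torch::tensor({0.9f, 0.8f, 0.7f}).to(at::kCUDA);
  // Returned indices are in descending-score order; with IoU threshold 0.5
  // and offset 0, the second box is suppressed by the first, leaving {0, 2}.
  at::Tensor keep = NMSCUDAKernelLauncher(boxes, scores, 0.5f, 0);
}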