mirror of https://github.com/open-mmlab/mmcv.git
Fix furthest_sample_point (#1405)
parent
5c25ae1a19
commit
b484abac09
|
@ -67,10 +67,11 @@ __global__ void furthest_point_sampling_forward_cuda_kernel(
|
|||
dists_i[tid] = besti;
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int block_size_thres = 1024; block_size_thres >= 2;
|
||||
block_size_thres /= 2) {
|
||||
int tid_thres = block_size_thres / 2;
|
||||
if (block_size >= block_size_thres) {
|
||||
block_size_thres >>= 1) {
|
||||
const int tid_thres = block_size_thres / 2;
|
||||
if (block_size >= block_size_thres && tid < tid_thres) {
|
||||
__update(dists, dists_i, tid, tid + tid_thres);
|
||||
}
|
||||
__syncthreads();
|
||||
|
@ -133,10 +134,11 @@ __global__ void furthest_point_sampling_with_dist_forward_cuda_kernel(
|
|||
dists_i[tid] = besti;
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int block_size_thres = 1024; block_size_thres >= 2;
|
||||
block_size_thres /= 2) {
|
||||
int tid_thres = block_size_thres / 2;
|
||||
if (block_size >= block_size_thres) {
|
||||
block_size_thres >>= 1) {
|
||||
const int tid_thres = block_size_thres / 2;
|
||||
if (block_size >= block_size_thres && tid < tid_thres) {
|
||||
__update(dists, dists_i, tid, tid + tid_thres);
|
||||
}
|
||||
__syncthreads();
|
||||
|
|
Loading…
Reference in New Issue