Fix furthest_sample_point (#1405)

pull/1408/head
q.yao 2021-10-15 17:30:04 +08:00 committed by GitHub
parent 5c25ae1a19
commit b484abac09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 6 deletions

View File

@ -67,10 +67,11 @@ __global__ void furthest_point_sampling_forward_cuda_kernel(
dists_i[tid] = besti;
__syncthreads();
#pragma unroll
for (int block_size_thres = 1024; block_size_thres >= 2;
block_size_thres /= 2) {
int tid_thres = block_size_thres / 2;
if (block_size >= block_size_thres) {
block_size_thres >>= 1) {
const int tid_thres = block_size_thres / 2;
if (block_size >= block_size_thres && tid < tid_thres) {
__update(dists, dists_i, tid, tid + tid_thres);
}
__syncthreads();
@ -133,10 +134,11 @@ __global__ void furthest_point_sampling_with_dist_forward_cuda_kernel(
dists_i[tid] = besti;
__syncthreads();
#pragma unroll
for (int block_size_thres = 1024; block_size_thres >= 2;
block_size_thres /= 2) {
int tid_thres = block_size_thres / 2;
if (block_size >= block_size_thres) {
block_size_thres >>= 1) {
const int tid_thres = block_size_thres / 2;
if (block_size >= block_size_thres && tid < tid_thres) {
__update(dists, dists_i, tid, tid + tid_thres);
}
__syncthreads();