[Fix] Fix bbox overlap fp16 (#1958)

* add CUDA_ARCH check

* add check in cuh
pull/1972/head v1.5.1
q.yao 2022-05-14 19:46:27 +08:00 committed by GitHub
parent a3b4640be8
commit 8708851eca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 0 deletions

View File

@ -88,6 +88,7 @@ __global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
}
}
#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half __half_area(const __half x1, const __half y1,
const __half x2, const __half y2,
const __half offset) {
@ -141,5 +142,6 @@ __device__ void bbox_overlaps_cuda_kernel_half(
ious[index] = __hdiv(interS, baseS);
}
}
#endif // __CUDA_ARCH__ >= 530
#endif // BBOX_OVERLAPS_CUDA_KERNEL_CUH

View File

@ -4,6 +4,7 @@
// Disable fp16 on ROCm device
#ifndef HIP_DIFF
#if __CUDA_ARCH__ >= 530
template <>
__global__ void bbox_overlaps_cuda_kernel<at::Half>(
const at::Half* bbox1, const at::Half* bbox2, at::Half* ious,
@ -14,6 +15,7 @@ __global__ void bbox_overlaps_cuda_kernel<at::Half>(
reinterpret_cast<__half*>(ious), num_bbox1,
num_bbox2, mode, aligned, offset);
}
#endif // __CUDA_ARCH__ >= 530
#endif // HIP_DIFF
void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,