mirror of https://github.com/open-mmlab/mmcv.git
parent
a3b4640be8
commit
8708851eca
|
@ -88,6 +88,7 @@ __global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
|
|||
}
|
||||
}
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
__device__ __forceinline__ __half __half_area(const __half x1, const __half y1,
|
||||
const __half x2, const __half y2,
|
||||
const __half offset) {
|
||||
|
@ -141,5 +142,6 @@ __device__ void bbox_overlaps_cuda_kernel_half(
|
|||
ious[index] = __hdiv(interS, baseS);
|
||||
}
|
||||
}
|
||||
#endif // __CUDA_ARCH__ >= 530
|
||||
|
||||
#endif // BBOX_OVERLAPS_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
// Disable fp16 on ROCm device
|
||||
#ifndef HIP_DIFF
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
template <>
|
||||
__global__ void bbox_overlaps_cuda_kernel<at::Half>(
|
||||
const at::Half* bbox1, const at::Half* bbox2, at::Half* ious,
|
||||
|
@ -14,6 +15,7 @@ __global__ void bbox_overlaps_cuda_kernel<at::Half>(
|
|||
reinterpret_cast<__half*>(ious), num_bbox1,
|
||||
num_bbox2, mode, aligned, offset);
|
||||
}
|
||||
#endif // __CUDA_ARCH__ >= 530
|
||||
#endif // HIP_DIFF
|
||||
|
||||
void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
|
||||
|
|
Loading…
Reference in New Issue