pull/2899/head
nijkah 2023-11-01 15:31:38 +09:00
parent f09f814d69
commit 15feb5978b
1 changed files with 4 additions and 4 deletions

View File

@ -17,12 +17,12 @@ void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
at::cuda::CUDAGuard device_guard(boxes1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
boxes1.scalar_type(), "box_iou_rotated_cuda_kernel", [&] {
boxes1.scalar_type(), "box_iou_rotated_cuda_kernel", [&] {
box_iou_rotated_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
mode_flag, aligned);
});
boxes2.data_ptr<scalar_t>(),
(scalar_t*)ious.data_ptr<scalar_t>(), mode_flag, aligned);
});
AT_CUDA_CHECK(cudaGetLastError());
}