mirror of
https://github.com/RE-OWOD/RE-OWOD.git
synced 2025-06-03 14:59:31 +08:00
40 lines
1.1 KiB
C
40 lines
1.1 KiB
C
|
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||
|
#pragma once
|
||
|
#include <torch/types.h>
|
||
|
|
||
|
namespace detectron2 {
|
||
|
|
||
|
at::Tensor nms_rotated_cpu(
|
||
|
const at::Tensor& dets,
|
||
|
const at::Tensor& scores,
|
||
|
const float iou_threshold);
|
||
|
|
||
|
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
||
|
at::Tensor nms_rotated_cuda(
|
||
|
const at::Tensor& dets,
|
||
|
const at::Tensor& scores,
|
||
|
const float iou_threshold);
|
||
|
#endif
|
||
|
|
||
|
// Interface for Python
|
||
|
// inline is needed to prevent multiple function definitions when this header is
|
||
|
// included by different cpps
|
||
|
inline at::Tensor nms_rotated(
|
||
|
const at::Tensor& dets,
|
||
|
const at::Tensor& scores,
|
||
|
const float iou_threshold) {
|
||
|
assert(dets.device().is_cuda() == scores.device().is_cuda());
|
||
|
if (dets.device().is_cuda()) {
|
||
|
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
||
|
return nms_rotated_cuda(
|
||
|
dets.contiguous(), scores.contiguous(), iou_threshold);
|
||
|
#else
|
||
|
AT_ERROR("Not compiled with GPU support");
|
||
|
#endif
|
||
|
}
|
||
|
|
||
|
return nms_rotated_cpu(dets.contiguous(), scores.contiguous(), iou_threshold);
|
||
|
}
|
||
|
|
||
|
} // namespace detectron2
|