mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Add the dtype limit of nms_npu to maintain consistency with the GPU (#2724)
* Increase the dtype limit to maintain consistency with the gpu. * update error message Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --------- Co-authored-by: momo609 <963372609.qq.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/2726/head
parent
8ceac9347e
commit
f946a933bb
|
@ -4,6 +4,8 @@ using namespace NPU_NAME_SPACE;
|
|||
using namespace std;
|
||||
|
||||
Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
|
||||
TORCH_CHECK((boxes.scalar_type == at::ScalarType::Float),
|
||||
"The type of boxes tensor passed in nms_npu should be float");
|
||||
int64_t offset_64 = offset;
|
||||
at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor(
|
||||
{}, boxes.options().dtype(at::kFloat), boxes)
|
||||
|
|
Loading…
Reference in New Issue