From 4c4ba6cb174526cd62fb18aab7df7868b230a549 Mon Sep 17 00:00:00 2001 From: liuhw <71241939+dflhw@users.noreply.github.com> Date: Fri, 24 Mar 2023 19:08:57 +0800 Subject: [PATCH] [Fix] Force bbox_overlaps calculation with FP32 for ascend device (#2697) * modify bbox_overlaps op adapter * update --- .../csrc/pytorch/npu/bbox_overlaps_npu.cpp | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp b/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp index ebe190f7e..59d0f1ef7 100644 --- a/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp @@ -12,23 +12,33 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, if (mode == 1) { modeStr = "iof"; } - float offset_ = 1; - if (offset == 0) { - offset_ = 0.01; + at::Tensor bboxesFP32 = bboxes2; + at::Tensor gtboxesFP32 = bboxes1; + if (bboxes2.scalar_type() != at::ScalarType::Float) { + bboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxes2, at::kFloat); + gtboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxes1, at::kFloat); } - at::Tensor bboxes = at::ones_like(bboxes2); - at::Tensor gtboxes = at::ones_like(bboxes1); - bboxes = aligned ? bboxes2.transpose(0, 1) : bboxes2; - gtboxes = aligned ? bboxes1.transpose(0, 1) : bboxes1; + c10::SmallVector iousSize = {gtboxesFP32.size(0), + bboxesFP32.size(0)}; + if (aligned) { + iousSize = {gtboxesFP32.size(0), 1}; + } + at::Tensor iousFP32 = OpPreparation::ApplyTensor(bboxesFP32, iousSize); + bboxesFP32 = aligned ? bboxesFP32.transpose(0, 1) : bboxesFP32; + gtboxesFP32 = aligned ? gtboxesFP32.transpose(0, 1) : gtboxesFP32; OpCommand cmd; cmd.Name("Iou") - .Input(bboxes) - .Input(gtboxes) - .Output(ious) + .Input(bboxesFP32) + .Input(gtboxesFP32) + .Output(iousFP32) .Attr("mode", modeStr) - .Attr("eps", offset_) + .Attr("eps", (float)offset) .Attr("aligned", aligned) .Run(); + if (bboxes2.scalar_type() != at::ScalarType::Float) { + iousFP32 = NPUNativeFunctions::npu_dtype_cast(iousFP32, at::kHalf); + } + ious.copy_(iousFP32); } REGISTER_NPU_IMPL(bbox_overlaps_impl, bbox_overlaps_npu);