diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index dedcdb4ed..2ee4380ae 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -7,7 +7,7 @@ We implement common ops used in detection, segmentation, etc. | ActiveRotatedFilter | √ | √ | | | | | AssignScoreWithK | | √ | | | | | BallQuery | | √ | √ | | | -| BBoxOverlaps | | √ | √ | √ | | +| BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | | BoxIouRotated | √ | √ | | | | | BoxIouQuadri | √ | √ | | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index ddf68d79f..862946e7c 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -7,7 +7,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | ActiveRotatedFilter | √ | √ | | | | | AssignScoreWithK | | √ | | | | | BallQuery | | √ | √ | | | -| BBoxOverlaps | | √ | √ | √ | | +| BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | | BoxIouRotated | √ | √ | | | | | BoxIouQuadri | √ | √ | | | | diff --git a/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp b/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp new file mode 100644 index 000000000..667d47125 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp @@ -0,0 +1,30 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, + const int mode, const bool aligned, const int offset); + +void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, + const int mode, const bool aligned, const int offset) { + string modeStr = "iou"; + if (mode == 1) { + modeStr = "iof"; + } + at::Tensor bboxes = at::ones_like(bboxes2); + at::Tensor gtboxes = at::ones_like(bboxes1); + bboxes = aligned ? bboxes2.transpose(0, 1) : bboxes2; + gtboxes = aligned ? bboxes1.transpose(0, 1) : bboxes1; + OpCommand cmd; + cmd.Name("Iou") + .Input(bboxes) + .Input(gtboxes) + .Output(ious) + .Attr("mode", modeStr) + .Attr("eps", (float)offset) + .Attr("aligned", aligned) + .Run(); +} + +REGISTER_NPU_IMPL(bbox_overlaps_impl, bbox_overlaps_npu); diff --git a/tests/test_ops/test_bbox.py b/tests/test_ops/test_bbox.py index 1b522ee3c..3d1486eb0 100644 --- a/tests/test_ops/test_bbox.py +++ b/tests/test_ops/test_bbox.py @@ -3,7 +3,8 @@ import numpy as np import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE, + IS_NPU_AVAILABLE) class TestBBox: @@ -55,7 +56,11 @@ class TestBBox: pytest.param( 'mps', marks=pytest.mark.skipif( - not IS_MPS_AVAILABLE, reason='requires MPS support')) + not IS_MPS_AVAILABLE, reason='requires MPS support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_bbox_overlaps_float(self, device): self._test_bbox_overlaps(device, dtype=torch.float) @@ -68,7 +73,11 @@ class TestBBox: pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_bbox_overlaps_half(self, device): self._test_bbox_overlaps(device, dtype=torch.half)