mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add Ascend support for bbox_overlaps (#2580)
* Add support for Ascend devices with bbox_overlaps * ... * ...pull/2541/head^2
parent
615d28176a
commit
5a61e534ea
|
@ -7,7 +7,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| ActiveRotatedFilter | √ | √ | | | |
|
||||
| AssignScoreWithK | | √ | | | |
|
||||
| BallQuery | | √ | √ | | |
|
||||
| BBoxOverlaps | | √ | √ | √ | |
|
||||
| BBoxOverlaps | | √ | √ | √ | √ |
|
||||
| BorderAlign | | √ | | | |
|
||||
| BoxIouRotated | √ | √ | | | |
|
||||
| BoxIouQuadri | √ | √ | | | |
|
||||
|
|
|
@ -7,7 +7,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| ActiveRotatedFilter | √ | √ | | | |
|
||||
| AssignScoreWithK | | √ | | | |
|
||||
| BallQuery | | √ | √ | | |
|
||||
| BBoxOverlaps | | √ | √ | √ | |
|
||||
| BBoxOverlaps | | √ | √ | √ | √ |
|
||||
| BorderAlign | | √ | | | |
|
||||
| BoxIouRotated | √ | √ | | | |
|
||||
| BoxIouQuadri | √ | √ | | | |
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
#include "pytorch_npu_helper.hpp"
|
||||
|
||||
using namespace NPU_NAME_SPACE;
|
||||
using namespace std;
|
||||
|
||||
void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
|
||||
const int mode, const bool aligned, const int offset);
|
||||
|
||||
void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
|
||||
const int mode, const bool aligned, const int offset) {
|
||||
string modeStr = "iou";
|
||||
if (mode == 1) {
|
||||
modeStr = "iof";
|
||||
}
|
||||
at::Tensor bboxes = at::ones_like(bboxes2);
|
||||
at::Tensor gtboxes = at::ones_like(bboxes1);
|
||||
bboxes = aligned ? bboxes2.transpose(0, 1) : bboxes2;
|
||||
gtboxes = aligned ? bboxes1.transpose(0, 1) : bboxes1;
|
||||
OpCommand cmd;
|
||||
cmd.Name("Iou")
|
||||
.Input(bboxes)
|
||||
.Input(gtboxes)
|
||||
.Output(ious)
|
||||
.Attr("mode", modeStr)
|
||||
.Attr("eps", (float)offset)
|
||||
.Attr("aligned", aligned)
|
||||
.Run();
|
||||
}
|
||||
|
||||
REGISTER_NPU_IMPL(bbox_overlaps_impl, bbox_overlaps_npu);
|
|
@ -3,7 +3,8 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
|
||||
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE,
|
||||
IS_NPU_AVAILABLE)
|
||||
|
||||
|
||||
class TestBBox:
|
||||
|
@ -55,7 +56,11 @@ class TestBBox:
|
|||
pytest.param(
|
||||
'mps',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MPS_AVAILABLE, reason='requires MPS support'))
|
||||
not IS_MPS_AVAILABLE, reason='requires MPS support')),
|
||||
pytest.param(
|
||||
'npu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_NPU_AVAILABLE, reason='requires NPU support'))
|
||||
])
|
||||
def test_bbox_overlaps_float(self, device):
|
||||
self._test_bbox_overlaps(device, dtype=torch.float)
|
||||
|
@ -68,7 +73,11 @@ class TestBBox:
|
|||
pytest.param(
|
||||
'mlu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support'))
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support')),
|
||||
pytest.param(
|
||||
'npu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_NPU_AVAILABLE, reason='requires NPU support'))
|
||||
])
|
||||
def test_bbox_overlaps_half(self, device):
|
||||
self._test_bbox_overlaps(device, dtype=torch.half)
|
||||
|
|
Loading…
Reference in New Issue