mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add C++ implementation for bbox_overlaps (#2477)
* add ops bbox_overlaps * format code * Return the pytorch version * Intermediate modification * Solve problems in parameter passing * revise bug * "add test case"pull/2574/head
parent
db391c50a3
commit
422816e45c
|
@ -106,25 +106,17 @@ def bbox_overlaps(bboxes1: torch.Tensor,
|
|||
|
||||
rows = bboxes1.size(0)
|
||||
cols = bboxes2.size(0)
|
||||
|
||||
if aligned:
|
||||
assert rows == cols
|
||||
ious = bboxes1.new_zeros(rows)
|
||||
else:
|
||||
ious = bboxes1.new_zeros((rows, cols))
|
||||
|
||||
if rows * cols == 0:
|
||||
return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols)
|
||||
|
||||
if bboxes1.device.type == 'cpu':
|
||||
return _bbox_overlaps_cpu(
|
||||
bboxes1, bboxes2, mode=mode, aligned=aligned, offset=offset)
|
||||
else:
|
||||
if aligned:
|
||||
ious = bboxes1.new_zeros(rows)
|
||||
else:
|
||||
ious = bboxes1.new_zeros((rows, cols))
|
||||
ext_module.bbox_overlaps(
|
||||
bboxes1,
|
||||
bboxes2,
|
||||
ious,
|
||||
mode=mode_flag,
|
||||
aligned=aligned,
|
||||
offset=offset)
|
||||
return ious
|
||||
|
||||
ext_module.bbox_overlaps(
|
||||
bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
|
||||
|
||||
return ious
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright(c) OpenMMLab.All rights reserved.
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
#include "pytorch_device_registry.hpp"
|
||||
|
||||
using torch::indexing::None;
|
||||
using torch::indexing::Slice;
|
||||
|
||||
void bbox_overlaps_cpu_kernel(const Tensor boxes1, const Tensor boxes2,
|
||||
Tensor ious, const int mode_flag,
|
||||
const bool aligned, const int offset) {
|
||||
Tensor temp_ious;
|
||||
if (aligned) {
|
||||
Tensor lt = torch::max(boxes1.index({Slice(None), Slice({None, 2})}),
|
||||
boxes2.index({Slice(None), Slice({None, 2})}));
|
||||
Tensor rb = torch::min(boxes1.index({Slice(None), Slice(2)}),
|
||||
boxes2.index({Slice(None), Slice(2)}));
|
||||
Tensor wh = (rb - lt + offset).clamp(0.f, INT_MAX * 1.f);
|
||||
Tensor overlap = wh.index({Slice(None), 0}) * wh.index({Slice(None), 1});
|
||||
Tensor area1 = (boxes1.index({Slice(None), 2}) -
|
||||
boxes1.index({Slice(None), 0}) + offset) *
|
||||
(boxes1.index({Slice(None), 3}) -
|
||||
boxes1.index({Slice(None), 1}) + offset);
|
||||
if (mode_flag == 0) {
|
||||
Tensor area2 = (boxes2.index({Slice(None), 2}) -
|
||||
boxes2.index({Slice(None), 0}) + offset) *
|
||||
(boxes2.index({Slice(None), 3}) -
|
||||
boxes2.index({Slice(None), 1}) + offset);
|
||||
temp_ious = overlap / (area1 + area2 - overlap);
|
||||
} else {
|
||||
temp_ious = overlap / area1;
|
||||
}
|
||||
} else {
|
||||
Tensor lt = torch::max(boxes1.index({Slice(None), None, Slice({None, 2})}),
|
||||
boxes2.index({Slice(None), Slice({None, 2})}));
|
||||
Tensor rb = torch::min(boxes1.index({Slice(None), None, Slice(2)}),
|
||||
boxes2.index({Slice(None), Slice(2)}));
|
||||
Tensor wh = (rb - lt + offset).clamp(0.f, INT_MAX * 1.f);
|
||||
Tensor overlap = wh.index({"...", 0}) * wh.index({"...", 1});
|
||||
Tensor area1 = (boxes1.index({Slice(None), 2}) -
|
||||
boxes1.index({Slice(None), 0}) + offset) *
|
||||
(boxes1.index({Slice(None), 3}) -
|
||||
boxes1.index({Slice(None), 1}) + offset);
|
||||
if (mode_flag == 0) {
|
||||
Tensor area2 = (boxes2.index({Slice(None), 2}) -
|
||||
boxes2.index({Slice(None), 0}) + offset) *
|
||||
(boxes2.index({Slice(None), 3}) -
|
||||
boxes2.index({Slice(None), 1}) + offset);
|
||||
temp_ious =
|
||||
overlap / (area1.index({Slice(None), None}) + area2 - overlap);
|
||||
} else {
|
||||
temp_ious = overlap / area1.index({Slice(None), None});
|
||||
}
|
||||
}
|
||||
ious.copy_(temp_ious);
|
||||
}
|
||||
|
||||
void bbox_overlaps_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
|
||||
const int mode, const bool aligned, const int offset) {
|
||||
bbox_overlaps_cpu_kernel(boxes1, boxes2, ious, mode, aligned, offset);
|
||||
}
|
||||
|
||||
void bbox_overlaps_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
|
||||
const int mode, const bool aligned, const int offset);
|
||||
|
||||
REGISTER_DEVICE_IMPL(bbox_overlaps_impl, CPU, bbox_overlaps_cpu);
|
|
@ -34,6 +34,14 @@ class TestBBox:
|
|||
out = bbox_overlaps(b1, b2, offset=1)
|
||||
assert np.allclose(out.cpu().numpy(), should_output, 1e-2)
|
||||
|
||||
b1 = torch.tensor([[10.0 + i, 10.0 + i, 30.0 + i, 30.0 + i]
|
||||
for i in range(1000)]).to(device).type(dtype)
|
||||
b2 = torch.tensor([[20.0 + i, 20.0 + i, 40.0 + i, 40.0 + i]
|
||||
for i in range(1000)]).to(device).type(dtype)
|
||||
should_output = np.array([1 / 7] * 1000)
|
||||
out = bbox_overlaps(b1, b2, aligned=True)
|
||||
assert np.allclose(out.cpu().numpy(), should_output, 1e-2)
|
||||
|
||||
@pytest.mark.parametrize('device', [
|
||||
'cpu',
|
||||
pytest.param(
|
||||
|
|
Loading…
Reference in New Issue