mirror of https://github.com/open-mmlab/mmcv.git
Support aligned mode for box_iou_rotated (#677)
* support aligned and parrots cpu for box_iou_rotated
* add aligned doc
* fix lint
* fix bug
parent 8008f475ab
commit 508a322fba
@@ -1,30 +1,36 @@
 import torch
 
 from ..utils import ext_loader
 
 ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
 
 
-def box_iou_rotated(bboxes1, bboxes2):
+def box_iou_rotated(bboxes1, bboxes2, aligned=False):
     """Return intersection-over-union (Jaccard index) of boxes.
 
     Both sets of boxes are expected to be in
     (x_center, y_center, width, height, angle) format.
 
+    If ``aligned`` is ``False``, then calculate the ious between each bbox
+    of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+    bboxes1 and bboxes2.
+
     Arguments:
         boxes1 (Tensor): rotated bboxes 1. \
             It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
         boxes2 (Tensor): rotated bboxes 2. \
-            It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
+            It has shape (M, 5), indicating (x, y, w, h, theta) for each row.
 
     Returns:
-        iou (Tensor[N, M]): the NxM matrix containing the pairwise
-            IoU values for every element in boxes1 and boxes2
+        ious(Tensor): shape (N, M) if aligned == False else shape (N,)
     """
-    if torch.__version__ == 'parrots':
-        out = torch.zeros((bboxes1.shape[0], bboxes2.shape[0]),
-                          dtype=torch.float32).to(bboxes1.device)
-        ext_module.box_iou_rotated(bboxes1, bboxes2, out)
+    rows = bboxes1.size(0)
+    cols = bboxes2.size(0)
+    if aligned:
+        ious = bboxes1.new_zeros(rows)
     else:
-        out = ext_module.box_iou_rotated(bboxes1, bboxes2)
-    return out
+        ious = bboxes1.new_zeros((rows * cols))
+    bboxes1 = bboxes1.contiguous()
+    bboxes2 = bboxes2.contiguous()
+    ext_module.box_iou_rotated(bboxes1, bboxes2, ious, aligned=aligned)
+    if not aligned:
+        ious = ious.view(rows, cols)
+    return ious
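For reference, a usage sketch of the updated Python wrapper (the tensors and shapes below are illustrative, not taken from this commit). Each box is (x_center, y_center, width, height, angle); with aligned=False the full N x M IoU matrix is returned, with aligned=True one IoU per row pair.

    import torch
    from mmcv.ops import box_iou_rotated

    boxes1 = torch.rand(4, 5)  # (N, 5): x, y, w, h, theta
    boxes2 = torch.rand(5, 5)  # (M, 5)

    ious = box_iou_rotated(boxes1, boxes2)  # shape (4, 5)

    # aligned mode expects the same number of boxes on both sides
    boxes3 = torch.rand(4, 5)
    ious_aligned = box_iou_rotated(boxes1, boxes3, aligned=True)  # shape (4,)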
@@ -21,48 +21,60 @@ template <typename T>
 __global__ void box_iou_rotated_cuda_kernel(const int n_boxes1,
                                             const int n_boxes2,
                                             const T* dev_boxes1,
-                                            const T* dev_boxes2, T* dev_ious) {
-  const int row_start = blockIdx.x * blockDim.x;
-  const int col_start = blockIdx.y * blockDim.y;
+                                            const T* dev_boxes2, T* dev_ious,
+                                            const bool aligned) {
+  if (aligned) {
+    CUDA_1D_KERNEL_LOOP(index, n_boxes1) {
+      int b1 = index;
+      int b2 = index;
 
-  const int row_size = min(n_boxes1 - row_start, blockDim.x);
-  const int col_size = min(n_boxes2 - col_start, blockDim.y);
+      int base1 = b1 * 5;
 
-  __shared__ float block_boxes1[BLOCK_DIM_X * 5];
-  __shared__ float block_boxes2[BLOCK_DIM_Y * 5];
+      float block_boxes1[5];
+      float block_boxes2[5];
 
-  // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
-  if (threadIdx.x < row_size && threadIdx.y == 0) {
-    block_boxes1[threadIdx.x * 5 + 0] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 0];
-    block_boxes1[threadIdx.x * 5 + 1] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 1];
-    block_boxes1[threadIdx.x * 5 + 2] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 2];
-    block_boxes1[threadIdx.x * 5 + 3] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 3];
-    block_boxes1[threadIdx.x * 5 + 4] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 4];
-  }
+      block_boxes1[0] = dev_boxes1[base1 + 0];
+      block_boxes1[1] = dev_boxes1[base1 + 1];
+      block_boxes1[2] = dev_boxes1[base1 + 2];
+      block_boxes1[3] = dev_boxes1[base1 + 3];
+      block_boxes1[4] = dev_boxes1[base1 + 4];
 
-  if (threadIdx.x < col_size && threadIdx.y == 0) {
-    block_boxes2[threadIdx.x * 5 + 0] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 0];
-    block_boxes2[threadIdx.x * 5 + 1] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 1];
-    block_boxes2[threadIdx.x * 5 + 2] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 2];
-    block_boxes2[threadIdx.x * 5 + 3] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 3];
-    block_boxes2[threadIdx.x * 5 + 4] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 4];
-  }
-  __syncthreads();
+      int base2 = b2 * 5;
 
-  if (threadIdx.x < row_size && threadIdx.y < col_size) {
-    int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y;
-    dev_ious[offset] = single_box_iou_rotated<T>(
-        block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
+      block_boxes2[0] = dev_boxes2[base2 + 0];
+      block_boxes2[1] = dev_boxes2[base2 + 1];
+      block_boxes2[2] = dev_boxes2[base2 + 2];
+      block_boxes2[3] = dev_boxes2[base2 + 3];
+      block_boxes2[4] = dev_boxes2[base2 + 4];
+
+      dev_ious[index] = single_box_iou_rotated<T>(block_boxes1, block_boxes2);
+    }
+  } else {
+    CUDA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) {
+      int b1 = index / n_boxes2;
+      int b2 = index % n_boxes2;
+
+      int base1 = b1 * 5;
+
+      float block_boxes1[5];
+      float block_boxes2[5];
+
+      block_boxes1[0] = dev_boxes1[base1 + 0];
+      block_boxes1[1] = dev_boxes1[base1 + 1];
+      block_boxes1[2] = dev_boxes1[base1 + 2];
+      block_boxes1[3] = dev_boxes1[base1 + 3];
+      block_boxes1[4] = dev_boxes1[base1 + 4];
+
+      int base2 = b2 * 5;
+
+      block_boxes2[0] = dev_boxes2[base2 + 0];
+      block_boxes2[1] = dev_boxes2[base2 + 1];
+      block_boxes2[2] = dev_boxes2[base2 + 2];
+      block_boxes2[3] = dev_boxes2[base2 + 3];
+      block_boxes2[4] = dev_boxes2[base2 + 4];
+
+      dev_ious[index] = single_box_iou_rotated<T>(block_boxes1, block_boxes2);
+    }
   }
 }
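The rewritten kernel replaces the shared-memory 2-D tiling with a flat CUDA_1D_KERNEL_LOOP over the output array. A small Python sketch of the index-to-pair mapping it relies on (the function name here is illustrative):

    def iou_pair(index, n_boxes2, aligned):
        """Return the (box1, box2) indices compared for a flat output index."""
        if aligned:
            # one IoU per row: the output has n_boxes1 entries
            return index, index
        # full matrix flattened row-major: the output has n_boxes1 * n_boxes2 entries
        return index // n_boxes2, index % n_boxes2

    assert iou_pair(7, n_boxes2=3, aligned=False) == (2, 1)
    assert iou_pair(7, n_boxes2=3, aligned=True) == (7, 7)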
@@ -3,22 +3,46 @@
 // https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
 #include "parrots_cpp_helper.hpp"
 
-DArrayLite box_iou_rotated_cuda(const DArrayLite boxes1,
-                                const DArrayLite boxes2, cudaStream_t stream,
-                                CudaContext& ctx);
+void box_iou_rotated_cpu_launcher(const DArrayLite boxes1,
+                                  const DArrayLite boxes2, DArrayLite ious,
+                                  const bool aligned);
 
-void box_iou_rotated(CudaContext& ctx, const SSElement& attr,
-                     const OperatorBase::in_list_t& ins,
-                     OperatorBase::out_list_t& outs) {
+void box_iou_rotated_cuda_launcher(const DArrayLite boxes1,
+                                   const DArrayLite boxes2, DArrayLite ious,
+                                   const bool aligned, cudaStream_t stream);
+
+void box_iou_rotated_cpu(HostContext& ctx, const SSElement& attr,
+                         const OperatorBase::in_list_t& ins,
+                         OperatorBase::out_list_t& outs) {
+  const auto& boxes1 = ins[0];
+  const auto& boxes2 = ins[1];
+
+  bool aligned;
+  SSAttrs(attr).get<bool>("aligned", aligned).done();
+  auto& ious = outs[0];
+  box_iou_rotated_cpu_launcher(boxes1, boxes2, ious, aligned);
+}
+
+void box_iou_rotated_cuda(CudaContext& ctx, const SSElement& attr,
+                          const OperatorBase::in_list_t& ins,
+                          OperatorBase::out_list_t& outs) {
   const auto& boxes1 = ins[0];
   const auto& boxes2 = ins[1];
 
+  bool aligned;
+  SSAttrs(attr).get<bool>("aligned", aligned).done();
+
   cudaStream_t stream = getStreamNative<CudaDevice>(ctx.getStream());
-  outs[0] = box_iou_rotated_cuda(boxes1, boxes2, stream, ctx);
+  auto& ious = outs[0];
+  box_iou_rotated_cuda_launcher(boxes1, boxes2, ious, aligned, stream);
 }
 
 PARROTS_EXTENSION_REGISTER(box_iou_rotated)
+    .attr("aligned")
     .input(2)
     .output(1)
-    .apply(box_iou_rotated)
+    .apply(box_iou_rotated_cpu)
+#ifdef PARROTS_USE_CUDA
+    .apply(box_iou_rotated_cuda)
+#endif
     .done();
@@ -0,0 +1,36 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+// modified from
+// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
+#include "box_iou_rotated_utils.hpp"
+#include "parrots_cpp_helper.hpp"
+
+template <typename T>
+void box_iou_rotated_cpu_kernel(const DArrayLite boxes1,
+                                const DArrayLite boxes2, DArrayLite ious,
+                                const bool aligned) {
+  int output_size = ious.size();
+  int num_boxes1 = boxes1.dim(0);
+  int num_boxes2 = boxes2.dim(0);
+
+  auto ious_ptr = ious.ptr<float>();
+
+  if (aligned) {
+    for (int i = 0; i < output_size; i++) {
+      ious_ptr[i] =
+          single_box_iou_rotated<T>(boxes1[i].ptr<T>(), boxes2[i].ptr<T>());
+    }
+  } else {
+    for (int i = 0; i < num_boxes1; i++) {
+      for (int j = 0; j < num_boxes2; j++) {
+        ious_ptr[i * num_boxes2 + j] =
+            single_box_iou_rotated<T>(boxes1[i].ptr<T>(), boxes2[j].ptr<T>());
+      }
+    }
+  }
+}
+
+void box_iou_rotated_cpu_launcher(const DArrayLite boxes1,
+                                  const DArrayLite boxes2, DArrayLite ious,
+                                  const bool aligned) {
+  box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious, aligned);
+}
@@ -4,30 +4,19 @@
 #include "box_iou_rotated_cuda.cuh"
 #include "parrots_cuda_helper.hpp"
 
-DArrayLite box_iou_rotated_cuda(const DArrayLite boxes1,
-                                const DArrayLite boxes2, cudaStream_t stream,
-                                CudaContext& ctx) {
+void box_iou_rotated_cuda_launcher(const DArrayLite boxes1,
+                                   const DArrayLite boxes2, DArrayLite ious,
+                                   const bool aligned, cudaStream_t stream) {
   using scalar_t = float;
 
+  int output_size = ious.size();
   int num_boxes1 = boxes1.dim(0);
   int num_boxes2 = boxes2.dim(0);
 
-  auto ious = ctx.createDArrayLite(
-      DArraySpec::array(Prim::Float32, DArrayShape(num_boxes1 * num_boxes2)));
+  box_iou_rotated_cuda_kernel<scalar_t>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          num_boxes1, num_boxes2, boxes1.ptr<scalar_t>(),
+          boxes2.ptr<scalar_t>(), (scalar_t*)ious.ptr<scalar_t>(), aligned);
 
-  if (num_boxes1 > 0 && num_boxes2 > 0) {
-    const int blocks_x = divideUP(num_boxes1, BLOCK_DIM_X);
-    const int blocks_y = divideUP(num_boxes2, BLOCK_DIM_Y);
-
-    dim3 blocks(blocks_x, blocks_y);
-    dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
-
-    box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-        num_boxes1, num_boxes2, boxes1.ptr<scalar_t>(), boxes2.ptr<scalar_t>(),
-        (scalar_t*)ious.ptr<scalar_t>());
-
-    PARROTS_CUDA_CHECK(cudaGetLastError());
-  }
-
-  return ious.view({num_boxes1, num_boxes2});
+  PARROTS_CUDA_CHECK(cudaGetLastError());
 }
@@ -3,24 +3,27 @@
 // https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
 #include "pytorch_cpp_helper.hpp"
 
-Tensor box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2);
+void box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                         const bool aligned);
 
 #ifdef MMCV_WITH_CUDA
-Tensor box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2);
+void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                          const bool aligned);
 #endif
 
 // Interface for Python
 // inline is needed to prevent multiple function definitions when this header is
 // included by different cpps
-Tensor box_iou_rotated(const Tensor boxes1, const Tensor boxes2) {
+void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                     const bool aligned) {
   assert(boxes1.device().is_cuda() == boxes2.device().is_cuda());
   if (boxes1.device().is_cuda()) {
 #ifdef MMCV_WITH_CUDA
-    return box_iou_rotated_cuda(boxes1, boxes2);
+    box_iou_rotated_cuda(boxes1, boxes2, ious, aligned);
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
-  }
-
-  return box_iou_rotated_cpu(boxes1, boxes2);
+  } else {
+    box_iou_rotated_cpu(boxes1, boxes2, ious, aligned);
+  }
 }
@@ -6,35 +6,27 @@
 
 template <typename T>
 void box_iou_rotated_cpu_kernel(const Tensor boxes1, const Tensor boxes2,
-                                Tensor ious) {
-  auto widths1 = boxes1.select(1, 2).contiguous();
-  auto heights1 = boxes1.select(1, 3).contiguous();
-  auto widths2 = boxes2.select(1, 2).contiguous();
-  auto heights2 = boxes2.select(1, 3).contiguous();
-
-  Tensor areas1 = widths1 * heights1;
-  Tensor areas2 = widths2 * heights2;
-
+                                Tensor ious, const bool aligned) {
+  int output_size = ious.numel();
   auto num_boxes1 = boxes1.size(0);
   auto num_boxes2 = boxes2.size(0);
 
-  for (int i = 0; i < num_boxes1; i++) {
-    for (int j = 0; j < num_boxes2; j++) {
-      ious[i * num_boxes2 + j] = single_box_iou_rotated<T>(
-          boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>());
+  if (aligned) {
+    for (int i = 0; i < output_size; i++) {
+      ious[i] = single_box_iou_rotated<T>(boxes1[i].data_ptr<T>(),
+                                          boxes2[i].data_ptr<T>());
+    }
+  } else {
+    for (int i = 0; i < num_boxes1; i++) {
+      for (int j = 0; j < num_boxes2; j++) {
+        ious[i * num_boxes2 + j] = single_box_iou_rotated<T>(
+            boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>());
+      }
     }
   }
 }
 
-Tensor box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2) {
-  auto num_boxes1 = boxes1.size(0);
-  auto num_boxes2 = boxes2.size(0);
-  Tensor ious =
-      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
-
-  box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious);
-
-  // reshape from 1d array to 2d array
-  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
-  return ious.reshape(shape);
+void box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                         const bool aligned) {
+  box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious, aligned);
 }
@@ -4,35 +4,22 @@
 #include "box_iou_rotated_cuda.cuh"
 #include "pytorch_cuda_helper.hpp"
 
-Tensor box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2) {
+void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                          const bool aligned) {
   using scalar_t = float;
   AT_ASSERTM(boxes1.type().is_cuda(), "boxes1 must be a CUDA tensor");
   AT_ASSERTM(boxes2.type().is_cuda(), "boxes2 must be a CUDA tensor");
-  at::cuda::CUDAGuard device_guard(boxes1.device());
 
+  int output_size = ious.numel();
   int num_boxes1 = boxes1.size(0);
   int num_boxes2 = boxes2.size(0);
 
-  Tensor ious =
-      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
-
-  if (num_boxes1 > 0 && num_boxes2 > 0) {
-    const int blocks_x = at::cuda::ATenCeilDiv(num_boxes1, BLOCK_DIM_X);
-    const int blocks_y = at::cuda::ATenCeilDiv(num_boxes2, BLOCK_DIM_Y);
-
-    dim3 blocks(blocks_x, blocks_y);
-    dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
-
-    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-    box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-        num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
-        boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>());
-
-    AT_CUDA_CHECK(cudaGetLastError());
-  }
-
-  // reshape from 1d array to 2d array
-  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
-  return ious.reshape(shape);
+  at::cuda::CUDAGuard device_guard(boxes1.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  box_iou_rotated_cuda_kernel<scalar_t>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
+          boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
+          aligned);
+  AT_CUDA_CHECK(cudaGetLastError());
 }
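Both the parrots and the PyTorch launchers now size a 1-D launch from the flattened output instead of the old 2-D (BLOCK_DIM_X, BLOCK_DIM_Y) grid. Below is a rough Python equivalent of the GET_BLOCKS arithmetic; the 512-thread block size is an assumption, the actual constant lives in the shared CUDA helper headers.

    THREADS_PER_BLOCK = 512  # assumed value, see the CUDA helper headers

    def get_blocks(output_size, threads=THREADS_PER_BLOCK):
        # ceil(output_size / threads): one thread per output element, where
        # output_size is N * M for aligned=False and N for aligned=True
        return (output_size + threads - 1) // threads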
@@ -175,7 +175,8 @@ Tensor top_pool_forward(Tensor input);
 
 Tensor top_pool_backward(Tensor input, Tensor grad_output);
 
-Tensor box_iou_rotated(const Tensor boxes1, const Tensor boxes2);
+void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                     const bool aligned);
 
 Tensor nms_rotated(const Tensor dets, Tensor scores, Tensor order,
                    Tensor dets_sorted, const float iou_threshold,

@@ -364,7 +365,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("input"), py::arg("grad_output"),
         py::call_guard<py::gil_scoped_release>());
   m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes",
-        py::arg("boxes1"), py::arg("boxes2"));
+        py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"),
+        py::arg("aligned"));
   m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"),
         py::arg("scores"), py::arg("order"), py::arg("dets_sorted"),
         py::arg("iou_threshold"), py::arg("multi_label"));
@@ -1,21 +1,62 @@
 import numpy as np
+import pytest
 import torch
 
 
 class TestBoxIoURotated(object):
 
-    def test_box_iou_rotated(self):
-        if not torch.cuda.is_available():
-            return
+    def test_box_iou_rotated_cpu(self):
         from mmcv.ops import box_iou_rotated
-        b1 = torch.tensor(
-            [[1.0, 1.0, 3.0, 4.0], [2.0, 2.0, 3.0, 4.0], [7.0, 7.0, 8.0, 8.0]],
-            dtype=torch.float32).cuda()
-        b2 = torch.tensor([[0.0, 2.0, 2.0, 5.0], [2.0, 1.0, 3.0, 3.0]],
-                          dtype=torch.float32).cuda()
-        expect_output = torch.tensor(
-            [[0.2715, 0.0000], [0.1396, 0.0000], [0.0000, 0.0000]],
-            dtype=torch.float32).cuda()
-        output = box_iou_rotated(b1, b2)
+        np_boxes1 = np.asarray(
+            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
+             [7.0, 7.0, 8.0, 8.0, 0.4]],
+            dtype=np.float32)
+        np_boxes2 = np.asarray(
+            [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
+             [5.0, 5.0, 6.0, 7.0, 0.4]],
+            dtype=np.float32)
+        np_expect_ious = np.asarray(
+            [[0.3708, 0.4351, 0.0000], [0.1104, 0.4487, 0.0424],
+             [0.0000, 0.0000, 0.3622]],
+            dtype=np.float32)
+        np_expect_ious_aligned = np.asarray([0.3708, 0.4487, 0.3622],
+                                            dtype=np.float32)
+
+        boxes1 = torch.from_numpy(np_boxes1)
+        boxes2 = torch.from_numpy(np_boxes2)
+
+        ious = box_iou_rotated(boxes1, boxes2)
+        assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
+
+        ious = box_iou_rotated(boxes1, boxes2, aligned=True)
         assert np.allclose(
-            output.cpu().numpy(), expect_output.cpu().numpy(), atol=1e-4)
+            ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
+
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason='requires CUDA support')
+    def test_box_iou_rotated_cuda(self):
+        from mmcv.ops import box_iou_rotated
+        np_boxes1 = np.asarray(
+            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
+             [7.0, 7.0, 8.0, 8.0, 0.4]],
+            dtype=np.float32)
+        np_boxes2 = np.asarray(
+            [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
+             [5.0, 5.0, 6.0, 7.0, 0.4]],
+            dtype=np.float32)
+        np_expect_ious = np.asarray(
+            [[0.3708, 0.4351, 0.0000], [0.1104, 0.4487, 0.0424],
+             [0.0000, 0.0000, 0.3622]],
+            dtype=np.float32)
+        np_expect_ious_aligned = np.asarray([0.3708, 0.4487, 0.3622],
+                                            dtype=np.float32)
+
+        boxes1 = torch.from_numpy(np_boxes1).cuda()
+        boxes2 = torch.from_numpy(np_boxes2).cuda()
+
+        ious = box_iou_rotated(boxes1, boxes2)
+        assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
+
+        ious = box_iou_rotated(boxes1, boxes2, aligned=True)
+        assert np.allclose(
+            ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)