From 508a322fba81aaf9dd0c215ca8fd44a511a53da4 Mon Sep 17 00:00:00 2001
From: BigBigDream <919056489@qq.com>
Date: Wed, 16 Dec 2020 11:49:44 +0800
Subject: [PATCH] Support aligned mode for box_iou_rotated (#677)

* support aligned and parrots cpu for box_iou_rotated

* add aligned doc

* fix lint

* fix lint

* fix lint

* fix lint

* fix bug

* fix bug

* fix bug

* fix lint

* fix lint

* fix bug

* fix bug
---
 mmcv/ops/box_iou_rotated.py                   | 30 ++++---
 mmcv/ops/csrc/box_iou_rotated_cuda.cuh        | 86 +++++++++++--------
 mmcv/ops/csrc/parrots/box_iou_rotated.cpp     | 40 +++++++--
 mmcv/ops/csrc/parrots/box_iou_rotated_cpu.cpp | 36 ++++++++
 mmcv/ops/csrc/parrots/box_iou_rotated_cuda.cu | 29 ++-----
 mmcv/ops/csrc/pytorch/box_iou_rotated.cpp     | 15 ++--
 mmcv/ops/csrc/pytorch/box_iou_rotated_cpu.cpp | 40 ++++-----
 mmcv/ops/csrc/pytorch/box_iou_rotated_cuda.cu | 35 +++-----
 mmcv/ops/csrc/pytorch/pybind.cpp              |  6 +-
 tests/test_ops/test_box_iou_rotated.py        | 67 ++++++++++++---
 10 files changed, 238 insertions(+), 146 deletions(-)
 create mode 100644 mmcv/ops/csrc/parrots/box_iou_rotated_cpu.cpp

diff --git a/mmcv/ops/box_iou_rotated.py b/mmcv/ops/box_iou_rotated.py
index d1f2346bf..19fee1e4c 100644
--- a/mmcv/ops/box_iou_rotated.py
+++ b/mmcv/ops/box_iou_rotated.py
@@ -1,30 +1,36 @@
-import torch
-
 from ..utils import ext_loader
 
 ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
 
 
-def box_iou_rotated(bboxes1, bboxes2):
+def box_iou_rotated(bboxes1, bboxes2, aligned=False):
     """Return intersection-over-union (Jaccard index) of boxes.
 
     Both sets of boxes are expected to be in
     (x_center, y_center, width, height, angle) format.
 
+    If ``aligned`` is ``False``, calculate the ious between each bbox of
+    bboxes1 and each bbox of bboxes2; otherwise, calculate the ious between
+    each aligned pair of bboxes1 and bboxes2.
+
     Arguments:
         boxes1 (Tensor): rotated bboxes 1. \
            It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
        boxes2 (Tensor): rotated bboxes 2. \
-            It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
+            It has shape (M, 5), indicating (x, y, w, h, theta) for each row.
 
     Returns:
-        iou (Tensor[N, M]): the NxM matrix containing the pairwise
-            IoU values for every element in boxes1 and boxes2
+        ious (Tensor): shape (N, M) if aligned == False else shape (N,)
     """
-    if torch.__version__ == 'parrots':
-        out = torch.zeros((bboxes1.shape[0], bboxes2.shape[0]),
-                          dtype=torch.float32).to(bboxes1.device)
-        ext_module.box_iou_rotated(bboxes1, bboxes2, out)
+    rows = bboxes1.size(0)
+    cols = bboxes2.size(0)
+    if aligned:
+        ious = bboxes1.new_zeros(rows)
     else:
-        out = ext_module.box_iou_rotated(bboxes1, bboxes2)
-    return out
+        ious = bboxes1.new_zeros((rows * cols))
+    bboxes1 = bboxes1.contiguous()
+    bboxes2 = bboxes2.contiguous()
+    ext_module.box_iou_rotated(bboxes1, bboxes2, ious, aligned=aligned)
+    if not aligned:
+        ious = ious.view(rows, cols)
+    return ious
diff --git a/mmcv/ops/csrc/box_iou_rotated_cuda.cuh b/mmcv/ops/csrc/box_iou_rotated_cuda.cuh
index 09dfb4362..9c60377e0 100644
--- a/mmcv/ops/csrc/box_iou_rotated_cuda.cuh
+++ b/mmcv/ops/csrc/box_iou_rotated_cuda.cuh
@@ -21,48 +21,60 @@ template <typename T>
 __global__ void box_iou_rotated_cuda_kernel(const int n_boxes1,
                                             const int n_boxes2,
                                             const T* dev_boxes1,
-                                            const T* dev_boxes2, T* dev_ious) {
-  const int row_start = blockIdx.x * blockDim.x;
-  const int col_start = blockIdx.y * blockDim.y;
+                                            const T* dev_boxes2, T* dev_ious,
+                                            const bool aligned) {
+  if (aligned) {
+    CUDA_1D_KERNEL_LOOP(index, n_boxes1) {
+      int b1 = index;
+      int b2 = index;
 
-  const int row_size = min(n_boxes1 - row_start, blockDim.x);
-  const int col_size = min(n_boxes2 - col_start, blockDim.y);
+      int base1 = b1 * 5;
 
-  __shared__ float block_boxes1[BLOCK_DIM_X * 5];
-  __shared__ float block_boxes2[BLOCK_DIM_Y * 5];
+      float block_boxes1[5];
+      float block_boxes2[5];
 
-  // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
-  if (threadIdx.x < row_size && threadIdx.y == 0) {
-    block_boxes1[threadIdx.x * 5 + 0] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 0];
-    block_boxes1[threadIdx.x * 5 + 1] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 1];
-    block_boxes1[threadIdx.x * 5 + 2] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 2];
-    block_boxes1[threadIdx.x * 5 + 3] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 3];
-    block_boxes1[threadIdx.x * 5 + 4] =
-        dev_boxes1[(row_start + threadIdx.x) * 5 + 4];
-  }
+      block_boxes1[0] = dev_boxes1[base1 + 0];
+      block_boxes1[1] = dev_boxes1[base1 + 1];
+      block_boxes1[2] = dev_boxes1[base1 + 2];
+      block_boxes1[3] = dev_boxes1[base1 + 3];
+      block_boxes1[4] = dev_boxes1[base1 + 4];
 
-  if (threadIdx.x < col_size && threadIdx.y == 0) {
-    block_boxes2[threadIdx.x * 5 + 0] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 0];
-    block_boxes2[threadIdx.x * 5 + 1] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 1];
-    block_boxes2[threadIdx.x * 5 + 2] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 2];
-    block_boxes2[threadIdx.x * 5 + 3] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 3];
-    block_boxes2[threadIdx.x * 5 + 4] =
-        dev_boxes2[(col_start + threadIdx.x) * 5 + 4];
-  }
-  __syncthreads();
+      int base2 = b2 * 5;
 
-  if (threadIdx.x < row_size && threadIdx.y < col_size) {
-    int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y;
-    dev_ious[offset] = single_box_iou_rotated<T>(
-        block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
+      block_boxes2[0] = dev_boxes2[base2 + 0];
+      block_boxes2[1] = dev_boxes2[base2 + 1];
+      block_boxes2[2] = dev_boxes2[base2 + 2];
+      block_boxes2[3] = dev_boxes2[base2 + 3];
+      block_boxes2[4] = dev_boxes2[base2 + 4];
+
+      dev_ious[index] = single_box_iou_rotated<T>(block_boxes1, block_boxes2);
+    }
+  } else {
+    CUDA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) {
+      int b1 = index / n_boxes2;
+      int b2 = index % n_boxes2;
+
+      int base1 = b1 * 5;
+
+      float block_boxes1[5];
+      float block_boxes2[5];
+
+      block_boxes1[0] = dev_boxes1[base1 + 0];
+      block_boxes1[1] = dev_boxes1[base1 + 1];
+      block_boxes1[2] = dev_boxes1[base1 + 2];
+      block_boxes1[3] = dev_boxes1[base1 + 3];
+      block_boxes1[4] = dev_boxes1[base1 + 4];
+
+      int base2 = b2 * 5;
+
+      block_boxes2[0] = dev_boxes2[base2 + 0];
+      block_boxes2[1] = dev_boxes2[base2 + 1];
+      block_boxes2[2] = dev_boxes2[base2 + 2];
+      block_boxes2[3] = dev_boxes2[base2 + 3];
+      block_boxes2[4] = dev_boxes2[base2 + 4];
+
+      dev_ious[index] = single_box_iou_rotated<T>(block_boxes1, block_boxes2);
+    }
   }
 }
diff --git a/mmcv/ops/csrc/parrots/box_iou_rotated.cpp b/mmcv/ops/csrc/parrots/box_iou_rotated.cpp
index 42272e525..bf4423ff5 100644
--- a/mmcv/ops/csrc/parrots/box_iou_rotated.cpp
+++ b/mmcv/ops/csrc/parrots/box_iou_rotated.cpp
@@ -3,22 +3,46 @@
 // https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
 #include "parrots_cpp_helper.hpp"
 
-DArrayLite box_iou_rotated_cuda(const DArrayLite boxes1,
-                                const DArrayLite boxes2, cudaStream_t stream,
-                                CudaContext& ctx);
+void box_iou_rotated_cpu_launcher(const DArrayLite boxes1,
+                                  const DArrayLite boxes2, DArrayLite ious,
+                                  const bool aligned);
 
-void box_iou_rotated(CudaContext& ctx, const SSElement& attr,
-                     const OperatorBase::in_list_t& ins,
-                     OperatorBase::out_list_t& outs) {
+void box_iou_rotated_cuda_launcher(const DArrayLite boxes1,
+                                   const DArrayLite boxes2, DArrayLite ious,
+                                   const bool aligned, cudaStream_t stream);
+
+void box_iou_rotated_cpu(HostContext& ctx, const SSElement& attr,
+                         const OperatorBase::in_list_t& ins,
+                         OperatorBase::out_list_t& outs) {
   const auto& boxes1 = ins[0];
   const auto& boxes2 = ins[1];
+  bool aligned;
+  SSAttrs(attr).get<bool>("aligned", aligned).done();
 
+  auto& ious = outs[0];
+  box_iou_rotated_cpu_launcher(boxes1, boxes2, ious, aligned);
+}
+
+void box_iou_rotated_cuda(CudaContext& ctx, const SSElement& attr,
+                          const OperatorBase::in_list_t& ins,
+                          OperatorBase::out_list_t& outs) {
+  const auto& boxes1 = ins[0];
+  const auto& boxes2 = ins[1];
+
+  bool aligned;
+  SSAttrs(attr).get<bool>("aligned", aligned).done();
+
   cudaStream_t stream = getStreamNative<CudaDevice>(ctx.getStream());
-  outs[0] = box_iou_rotated_cuda(boxes1, boxes2, stream, ctx);
+  auto& ious = outs[0];
+  box_iou_rotated_cuda_launcher(boxes1, boxes2, ious, aligned, stream);
 }
 
 PARROTS_EXTENSION_REGISTER(box_iou_rotated)
+    .attr("aligned")
     .input(2)
     .output(1)
-    .apply(box_iou_rotated)
+    .apply(box_iou_rotated_cpu)
+#ifdef PARROTS_USE_CUDA
+    .apply(box_iou_rotated_cuda)
+#endif
     .done();
diff --git a/mmcv/ops/csrc/parrots/box_iou_rotated_cpu.cpp b/mmcv/ops/csrc/parrots/box_iou_rotated_cpu.cpp
new file mode 100644
index 000000000..82452e606
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/box_iou_rotated_cpu.cpp
@@ -0,0 +1,36 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+// modified from
+// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
+#include "box_iou_rotated_utils.hpp"
+#include "parrots_cpp_helper.hpp"
+
+template <typename T>
+void box_iou_rotated_cpu_kernel(const DArrayLite boxes1,
+                                const DArrayLite boxes2, DArrayLite ious,
+                                const bool aligned) {
+  int output_size = ious.size();
+  int num_boxes1 = boxes1.dim(0);
+  int num_boxes2 = boxes2.dim(0);
+
+  auto ious_ptr = ious.ptr<T>();
+
+  if (aligned) {
+    for (int i = 0; i < output_size; i++) {
+      ious_ptr[i] =
+          single_box_iou_rotated<T>(boxes1[i].ptr<T>(), boxes2[i].ptr<T>());
+    }
+  } else {
+    for (int i = 0; i < num_boxes1; i++) {
+      for (int j = 0; j < num_boxes2; j++) {
+        ious_ptr[i * num_boxes2 + j] =
+            single_box_iou_rotated<T>(boxes1[i].ptr<T>(), boxes2[j].ptr<T>());
+      }
+    }
+  }
+}
+
+void box_iou_rotated_cpu_launcher(const DArrayLite boxes1,
+                                  const DArrayLite boxes2, DArrayLite ious,
+                                  const bool aligned) {
+  box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious, aligned);
+}
diff --git a/mmcv/ops/csrc/parrots/box_iou_rotated_cuda.cu b/mmcv/ops/csrc/parrots/box_iou_rotated_cuda.cu
index 969d8dabf..97d5c2097 100644
--- a/mmcv/ops/csrc/parrots/box_iou_rotated_cuda.cu
+++ b/mmcv/ops/csrc/parrots/box_iou_rotated_cuda.cu
@@ -4,30 +4,19 @@
 #include "box_iou_rotated_cuda.cuh"
 #include "parrots_cuda_helper.hpp"
 
-DArrayLite box_iou_rotated_cuda(const DArrayLite boxes1,
-                                const DArrayLite boxes2, cudaStream_t stream,
-                                CudaContext& ctx) {
+void box_iou_rotated_cuda_launcher(const DArrayLite boxes1,
+                                   const DArrayLite boxes2, DArrayLite ious,
+                                   const bool aligned, cudaStream_t stream) {
   using scalar_t = float;
 
+  int output_size = ious.size();
   int num_boxes1 = boxes1.dim(0);
   int num_boxes2 = boxes2.dim(0);
 
-  auto ious = ctx.createDArrayLite(
-      DArraySpec::array(Prim::Float32, DArrayShape(num_boxes1 * num_boxes2)));
+  box_iou_rotated_cuda_kernel<scalar_t>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          num_boxes1, num_boxes2, boxes1.ptr<scalar_t>(),
+          boxes2.ptr<scalar_t>(), (scalar_t*)ious.ptr<scalar_t>(), aligned);
 
-  if (num_boxes1 > 0 && num_boxes2 > 0) {
-    const int blocks_x = divideUP(num_boxes1, BLOCK_DIM_X);
-    const int blocks_y = divideUP(num_boxes2, BLOCK_DIM_Y);
-
-    dim3 blocks(blocks_x, blocks_y);
-    dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
-
-    box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-        num_boxes1, num_boxes2, boxes1.ptr<scalar_t>(), boxes2.ptr<scalar_t>(),
-        (scalar_t*)ious.ptr<scalar_t>());
-
-    PARROTS_CUDA_CHECK(cudaGetLastError());
-  }
-
-  return ious.view({num_boxes1, num_boxes2});
+  PARROTS_CUDA_CHECK(cudaGetLastError());
 }
diff --git a/mmcv/ops/csrc/pytorch/box_iou_rotated.cpp b/mmcv/ops/csrc/pytorch/box_iou_rotated.cpp
index e27f9b181..68f3a9b94 100644
--- a/mmcv/ops/csrc/pytorch/box_iou_rotated.cpp
+++ b/mmcv/ops/csrc/pytorch/box_iou_rotated.cpp
@@ -3,24 +3,27 @@
 // https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
 #include "pytorch_cpp_helper.hpp"
 
-Tensor box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2);
+void box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                         const bool aligned);
 
 #ifdef MMCV_WITH_CUDA
-Tensor box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2);
+void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                          const bool aligned);
 #endif
 
 // Interface for Python
 // inline is needed to prevent multiple function definitions when this header is
 // included by different cpps
-Tensor box_iou_rotated(const Tensor boxes1, const Tensor boxes2) {
+void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                     const bool aligned) {
   assert(boxes1.device().is_cuda() == boxes2.device().is_cuda());
   if (boxes1.device().is_cuda()) {
 #ifdef MMCV_WITH_CUDA
-    return box_iou_rotated_cuda(boxes1, boxes2);
+    box_iou_rotated_cuda(boxes1, boxes2, ious, aligned);
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
+  } else {
+    box_iou_rotated_cpu(boxes1, boxes2, ious, aligned);
   }
-
-  return box_iou_rotated_cpu(boxes1, boxes2);
 }
diff --git a/mmcv/ops/csrc/pytorch/box_iou_rotated_cpu.cpp b/mmcv/ops/csrc/pytorch/box_iou_rotated_cpu.cpp
index a79e386b9..c14f2983f 100644
--- a/mmcv/ops/csrc/pytorch/box_iou_rotated_cpu.cpp
+++ b/mmcv/ops/csrc/pytorch/box_iou_rotated_cpu.cpp
@@ -6,35 +6,27 @@ template <typename T>
 void box_iou_rotated_cpu_kernel(const Tensor boxes1, const Tensor boxes2,
-                                Tensor ious) {
-  auto widths1 = boxes1.select(1, 2).contiguous();
-  auto heights1 = boxes1.select(1, 3).contiguous();
-  auto widths2 = boxes2.select(1, 2).contiguous();
-  auto heights2 = boxes2.select(1, 3).contiguous();
-
-  Tensor areas1 = widths1 * heights1;
-  Tensor areas2 = widths2 * heights2;
-
+                                Tensor ious, const bool aligned) {
+  int output_size = ious.numel();
   auto num_boxes1 = boxes1.size(0);
   auto num_boxes2 = boxes2.size(0);
 
-  for (int i = 0; i < num_boxes1; i++) {
-    for (int j = 0; j < num_boxes2; j++) {
-      ious[i * num_boxes2 + j] = single_box_iou_rotated<T>(
-          boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>());
+  if (aligned) {
+    for (int i = 0; i < output_size; i++) {
+      ious[i] = single_box_iou_rotated<T>(boxes1[i].data_ptr<T>(),
+                                          boxes2[i].data_ptr<T>());
+    }
+  } else {
+    for (int i = 0; i < num_boxes1; i++) {
+      for (int j = 0; j < num_boxes2; j++) {
+        ious[i * num_boxes2 + j] = single_box_iou_rotated<T>(
+            boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>());
+      }
     }
   }
 }
 
-Tensor box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2) {
-  auto num_boxes1 = boxes1.size(0);
-  auto num_boxes2 = boxes2.size(0);
-  Tensor ious =
-      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
-
-  box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious);
-
-  // reshape from 1d array to 2d array
-  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
-  return ious.reshape(shape);
+void box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                         const bool aligned) {
+  box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious, aligned);
 }
diff --git a/mmcv/ops/csrc/pytorch/box_iou_rotated_cuda.cu b/mmcv/ops/csrc/pytorch/box_iou_rotated_cuda.cu
index d550ab7e7..15ed934dd 100644
--- a/mmcv/ops/csrc/pytorch/box_iou_rotated_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/box_iou_rotated_cuda.cu
@@ -4,35 +4,22 @@
 #include "box_iou_rotated_cuda.cuh"
 #include "pytorch_cuda_helper.hpp"
 
-Tensor box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2) {
+void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                          const bool aligned) {
   using scalar_t = float;
   AT_ASSERTM(boxes1.type().is_cuda(), "boxes1 must be a CUDA tensor");
   AT_ASSERTM(boxes2.type().is_cuda(), "boxes2 must be a CUDA tensor");
-  at::cuda::CUDAGuard device_guard(boxes1.device());
 
+  int output_size = ious.numel();
   int num_boxes1 = boxes1.size(0);
   int num_boxes2 = boxes2.size(0);
-  Tensor ious =
-      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
-
-  if (num_boxes1 > 0 && num_boxes2 > 0) {
-    const int blocks_x = at::cuda::ATenCeilDiv(num_boxes1, BLOCK_DIM_X);
-    const int blocks_y = at::cuda::ATenCeilDiv(num_boxes2, BLOCK_DIM_Y);
-
-    dim3 blocks(blocks_x, blocks_y);
-    dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
-
-    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-    box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-        num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
-        boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>());
-
-    AT_CUDA_CHECK(cudaGetLastError());
-  }
-
-  // reshape from 1d array to 2d array
-  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
-  return ious.reshape(shape);
+  at::cuda::CUDAGuard device_guard(boxes1.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  box_iou_rotated_cuda_kernel<scalar_t>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
+          boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
+          aligned);
+  AT_CUDA_CHECK(cudaGetLastError());
 }
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index 744c509f7..a7e4d2b0f 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -175,7 +175,8 @@ Tensor top_pool_forward(Tensor input);
 
 Tensor top_pool_backward(Tensor input, Tensor grad_output);
 
-Tensor box_iou_rotated(const Tensor boxes1, const Tensor boxes2);
+void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
+                     const bool aligned);
 
 Tensor nms_rotated(const Tensor dets, Tensor scores, Tensor order,
                    Tensor dets_sorted, const float iou_threshold,
@@ -364,7 +365,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("input"), py::arg("grad_output"),
         py::call_guard<py::gil_scoped_release>());
   m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes",
-        py::arg("boxes1"), py::arg("boxes2"));
+        py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"),
+        py::arg("aligned"));
   m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"),
         py::arg("scores"), py::arg("order"), py::arg("dets_sorted"),
         py::arg("iou_threshold"), py::arg("multi_label"));
diff --git a/tests/test_ops/test_box_iou_rotated.py b/tests/test_ops/test_box_iou_rotated.py
index 191b6a118..784b4ba88 100644
--- a/tests/test_ops/test_box_iou_rotated.py
+++ b/tests/test_ops/test_box_iou_rotated.py
@@ -1,21 +1,62 @@
 import numpy as np
+import pytest
 import torch
 
 
 class TestBoxIoURotated(object):
 
-    def test_box_iou_rotated(self):
-        if not torch.cuda.is_available():
-            return
+    def test_box_iou_rotated_cpu(self):
         from mmcv.ops import box_iou_rotated
-        b1 = torch.tensor(
-            [[1.0, 1.0, 3.0, 4.0], [2.0, 2.0, 3.0, 4.0], [7.0, 7.0, 8.0, 8.0]],
-            dtype=torch.float32).cuda()
-        b2 = torch.tensor([[0.0, 2.0, 2.0, 5.0], [2.0, 1.0, 3.0, 3.0]],
-                          dtype=torch.float32).cuda()
-        expect_output = torch.tensor(
-            [[0.2715, 0.0000], [0.1396, 0.0000], [0.0000, 0.0000]],
-            dtype=torch.float32).cuda()
-        output = box_iou_rotated(b1, b2)
+        np_boxes1 = np.asarray(
+            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
+             [7.0, 7.0, 8.0, 8.0, 0.4]],
+            dtype=np.float32)
+        np_boxes2 = np.asarray(
+            [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
+             [5.0, 5.0, 6.0, 7.0, 0.4]],
+            dtype=np.float32)
+        np_expect_ious = np.asarray(
+            [[0.3708, 0.4351, 0.0000], [0.1104, 0.4487, 0.0424],
+             [0.0000, 0.0000, 0.3622]],
+            dtype=np.float32)
+        np_expect_ious_aligned = np.asarray([0.3708, 0.4487, 0.3622],
+                                            dtype=np.float32)
+
+        boxes1 = torch.from_numpy(np_boxes1)
+        boxes2 = torch.from_numpy(np_boxes2)
+
+        ious = box_iou_rotated(boxes1, boxes2)
+        assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
+
+        ious = box_iou_rotated(boxes1, boxes2, aligned=True)
         assert np.allclose(
-            output.cpu().numpy(), expect_output.cpu().numpy(), atol=1e-4)
+            ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
+
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason='requires CUDA support')
+    def test_box_iou_rotated_cuda(self):
+        from mmcv.ops import box_iou_rotated
+        np_boxes1 = np.asarray(
+            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
+             [7.0, 7.0, 8.0, 8.0, 0.4]],
+            dtype=np.float32)
+        np_boxes2 = np.asarray(
+            [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
+             [5.0, 5.0, 6.0, 7.0, 0.4]],
+            dtype=np.float32)
+        np_expect_ious = np.asarray(
+            [[0.3708, 0.4351, 0.0000], [0.1104, 0.4487, 0.0424],
+             [0.0000, 0.0000, 0.3622]],
+            dtype=np.float32)
+        np_expect_ious_aligned = np.asarray([0.3708, 0.4487, 0.3622],
+                                            dtype=np.float32)
+
+        boxes1 = torch.from_numpy(np_boxes1).cuda()
+        boxes2 = torch.from_numpy(np_boxes2).cuda()
+
+        ious = box_iou_rotated(boxes1, boxes2)
+        assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
+
+        ious = box_iou_rotated(boxes1, boxes2, aligned=True)
+        assert np.allclose(
+            ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
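
A minimal usage sketch of the new ``aligned`` argument (assuming MMCV is built with the ops in this patch); the box values below are arbitrary illustrative numbers, not reference outputs:

    import torch
    from mmcv.ops import box_iou_rotated

    # Each row is (x_center, y_center, width, height, angle).
    boxes1 = torch.tensor([[1.0, 1.0, 3.0, 4.0, 0.5],
                           [2.0, 2.0, 3.0, 4.0, 0.6]], dtype=torch.float32)
    boxes2 = torch.tensor([[0.0, 2.0, 2.0, 5.0, 0.3],
                           [2.0, 1.0, 3.0, 3.0, 0.5]], dtype=torch.float32)

    # Pairwise mode (default): ious has shape (N, M).
    ious = box_iou_rotated(boxes1, boxes2)

    # Aligned mode: both inputs must contain the same number of boxes;
    # ious has shape (N,), one IoU per corresponding pair.
    ious_aligned = box_iou_rotated(boxes1, boxes2, aligned=True)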