[Enhance] Differentiable rotated IoU (#1854)

* diff_iou_rotated is working

* add test; fix lint

* fix lint for test

* disable cpu build

* refactor files structure

* fix comments

* remove extra .repeat()

* add comment

* fix j-1 bug; update doc

* fix clang lint

* update docstrings

* fix comments

* fix comments
pull/1881/head
Danila Rukhovich 2022-04-15 11:40:07 +04:00 committed by GitHub
parent 7982dd1a06
commit aee596d523
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 555 additions and 1 deletions

View File

@ -13,6 +13,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- DiffIoURotated
- DynamicScatter
- GatherPoints
- FurthestPointSample

View File

@ -13,6 +13,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- DiffIoURotated
- DynamicScatter
- GatherPoints
- FurthestPointSample

View File

@ -18,6 +18,7 @@ from .deprecated_wrappers import Conv2d_deprecated as Conv2d
from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss)
from .furthest_point_sample import (furthest_point_sample,
@ -96,5 +97,5 @@ __all__ = [
'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons', 'min_area_polygons', 'active_rotated_filter',
'convex_iou', 'convex_giou'
'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d'
]

View File

@ -0,0 +1,136 @@
// Copyright (c) OpenMMLab. All rights reserved
// Adapted from
// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#define MAX_NUM_VERT_IDX 9
#define INTERSECTION_OFFSET 8
#define EPSILON 1e-8
inline int opt_n_thread(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
return max(min(1 << pow_2, THREADS_PER_BLOCK), 1);
}
/*
compare normalized vertices (vertices around (0,0))
if vertex1 < vertex2 return true.
order: minimum at x-aixs, become larger in anti-clockwise direction
*/
__device__ bool compare_vertices(float x1, float y1, float x2, float y2) {
if (fabs(x1 - x2) < EPSILON && fabs(y2 - y1) < EPSILON)
return false; // if equal, return false
if (y1 > 0 && y2 < 0) return true;
if (y1 < 0 && y2 > 0) return false;
float n1 = x1 * x1 + y1 * y1 + EPSILON;
float n2 = x2 * x2 + y2 * y2 + EPSILON;
float diff = fabs(x1) * x1 / n1 - fabs(x2) * x2 / n2;
if (y1 > 0 && y2 > 0) {
if (diff > EPSILON)
return true;
else
return false;
}
if (y1 < 0 && y2 < 0) {
if (diff < EPSILON)
return true;
else
return false;
}
}
__global__ void diff_iou_rotated_sort_vertices_forward_cuda_kernel(
int b, int n, int m, const float *__restrict__ vertices,
const bool *__restrict__ mask, const int *__restrict__ num_valid,
int *__restrict__ idx) {
int batch_idx = blockIdx.x;
vertices += batch_idx * n * m * 2;
mask += batch_idx * n * m;
num_valid += batch_idx * n;
idx += batch_idx * n * MAX_NUM_VERT_IDX;
int index = threadIdx.x; // index of polygon
int stride = blockDim.x;
for (int i = index; i < n; i += stride) {
int pad; // index of arbitrary invalid intersection point (not box corner!)
for (int j = INTERSECTION_OFFSET; j < m; ++j) {
if (!mask[i * m + j]) {
pad = j;
break;
}
}
if (num_valid[i] < 3) {
// not enough vertices, take an invalid intersection point
// (zero padding)
for (int j = 0; j < MAX_NUM_VERT_IDX; ++j) {
idx[i * MAX_NUM_VERT_IDX + j] = pad;
}
} else {
// sort the valid vertices
// note the number of valid vertices is known
// note: check that num_valid[i] < MAX_NUM_VERT_IDX
for (int j = 0; j < num_valid[i]; ++j) {
// initialize with a "big" value
float x_min = 1;
float y_min = -EPSILON;
int i_take = 0;
int i2;
float x2, y2;
if (j != 0) {
i2 = idx[i * MAX_NUM_VERT_IDX + j - 1];
x2 = vertices[i * m * 2 + i2 * 2 + 0];
y2 = vertices[i * m * 2 + i2 * 2 + 1];
}
for (int k = 0; k < m; ++k) {
float x = vertices[i * m * 2 + k * 2 + 0];
float y = vertices[i * m * 2 + k * 2 + 1];
if (mask[i * m + k] && compare_vertices(x, y, x_min, y_min)) {
if ((j == 0) || (j != 0 && compare_vertices(x2, y2, x, y))) {
x_min = x;
y_min = y;
i_take = k;
}
}
}
idx[i * MAX_NUM_VERT_IDX + j] = i_take;
}
// duplicate the first idx
idx[i * MAX_NUM_VERT_IDX + num_valid[i]] = idx[i * MAX_NUM_VERT_IDX + 0];
// pad zeros
for (int j = num_valid[i] + 1; j < MAX_NUM_VERT_IDX; ++j) {
idx[i * MAX_NUM_VERT_IDX + j] = pad;
}
// for corner case: the two boxes are exactly the same.
// in this case, idx would have duplicate elements, which makes the
// shoelace formula broken because of the definition, the duplicate
// elements only appear in the first 8 positions (they are "corners in
// box", not "intersection of edges")
if (num_valid[i] == 8) {
int counter = 0;
for (int j = 0; j < 4; ++j) {
int check = idx[i * MAX_NUM_VERT_IDX + j];
for (int k = 4; k < INTERSECTION_OFFSET; ++k) {
if (idx[i * MAX_NUM_VERT_IDX + k] == check) counter++;
}
}
if (counter == 4) {
idx[i * MAX_NUM_VERT_IDX + 4] = idx[i * MAX_NUM_VERT_IDX + 0];
for (int j = 5; j < MAX_NUM_VERT_IDX; ++j) {
idx[i * MAX_NUM_VERT_IDX + j] = pad;
}
}
}
// TODO: still might need to cover some other corner cases :(
}
}
}

View File

@ -1699,3 +1699,19 @@ void convex_giou_impl(const Tensor pointsets, const Tensor polygons,
REGISTER_DEVICE_IMPL(convex_iou_impl, CUDA, convex_iou_cuda);
REGISTER_DEVICE_IMPL(convex_giou_impl, CUDA, convex_giou_cuda);
Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(Tensor vertices,
Tensor mask,
Tensor num_valid);
Tensor diff_iou_rotated_sort_vertices_forward_cuda(Tensor vertices, Tensor mask,
Tensor num_valid) {
return DiffIoURotatedSortVerticesCUDAKernelLauncher(vertices, mask,
num_valid);
}
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid);
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA,
diff_iou_rotated_sort_vertices_forward_cuda);

View File

@ -0,0 +1,35 @@
// Copyright (c) OpenMMLab. All rights reserved
// Adapted from
// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa
#include "diff_iou_rotated_cuda_kernel.cuh"
#include "pytorch_cpp_helper.hpp"
#include "pytorch_cuda_helper.hpp"
at::Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(at::Tensor vertices,
at::Tensor mask,
at::Tensor num_valid) {
at::cuda::CUDAGuard device_guard(vertices.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
CHECK_CONTIGUOUS(vertices);
CHECK_CONTIGUOUS(mask);
CHECK_CONTIGUOUS(num_valid);
CHECK_CUDA(vertices);
CHECK_CUDA(mask);
CHECK_CUDA(num_valid);
int b = vertices.size(0);
int n = vertices.size(1);
int m = vertices.size(2);
at::Tensor idx =
torch::zeros({b, n, MAX_NUM_VERT_IDX},
at::device(vertices.device()).dtype(at::ScalarType::Int));
diff_iou_rotated_sort_vertices_forward_cuda_kernel<<<b, opt_n_thread(n), 0,
stream>>>(
b, n, m, vertices.data_ptr<float>(), mask.data_ptr<bool>(),
num_valid.data_ptr<int>(), idx.data_ptr<int>());
AT_CUDA_CHECK(cudaGetLastError());
return idx;
}

View File

@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid) {
return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl,
vertices, mask, num_valid);
}
Tensor diff_iou_rotated_sort_vertices_forward(Tensor vertices, Tensor mask,
Tensor num_valid) {
return diff_iou_rotated_sort_vertices_forward_impl(vertices, mask, num_valid);
}

View File

@ -400,6 +400,10 @@ void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious);
void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output);
at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices,
at::Tensor mask,
at::Tensor num_valid);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@ -809,4 +813,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("polygons"), py::arg("ious"));
m.def("convex_giou", &convex_giou, "convex_giou", py::arg("pointsets"),
py::arg("polygons"), py::arg("output"));
m.def("diff_iou_rotated_sort_vertices_forward",
&diff_iou_rotated_sort_vertices_forward,
"diff_iou_rotated_sort_vertices_forward", py::arg("vertices"),
py::arg("mask"), py::arg("num_valid"));
}

View File

@ -0,0 +1,293 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/box_intersection_2d.py # noqa
# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/oriented_iou_loss.py # noqa
import torch
from torch.autograd import Function
from ..utils import ext_loader
EPSILON = 1e-8
ext_module = ext_loader.load_ext('_ext',
['diff_iou_rotated_sort_vertices_forward'])
class SortVertices(Function):
@staticmethod
def forward(ctx, vertices, mask, num_valid):
idx = ext_module.diff_iou_rotated_sort_vertices_forward(
vertices, mask, num_valid)
ctx.mark_non_differentiable(idx)
return idx
@staticmethod
def backward(ctx, gradout):
return ()
def box_intersection(corners1, corners2):
"""Find intersection points of rectangles.
Convention: if two edges are collinear, there is no intersection point.
Args:
corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.
Returns:
Tuple:
- Tensor: (B, N, 4, 4, 2) Intersections.
- Tensor: (B, N, 4, 4) Valid intersections mask.
"""
# build edges from corners
# B, N, 4, 4: Batch, Box, edge, point
line1 = torch.cat([corners1, corners1[:, :, [1, 2, 3, 0], :]], dim=3)
line2 = torch.cat([corners2, corners2[:, :, [1, 2, 3, 0], :]], dim=3)
# duplicate data to pair each edges from the boxes
# (B, N, 4, 4) -> (B, N, 4, 4, 4) : Batch, Box, edge1, edge2, point
line1_ext = line1.unsqueeze(3)
line2_ext = line2.unsqueeze(2)
x1, y1, x2, y2 = line1_ext.split([1, 1, 1, 1], dim=-1)
x3, y3, x4, y4 = line2_ext.split([1, 1, 1, 1], dim=-1)
# math: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection
numerator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
denumerator_t = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)
t = denumerator_t / numerator
t[numerator == .0] = -1.
mask_t = (t > 0) & (t < 1) # intersection on line segment 1
denumerator_u = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)
u = -denumerator_u / numerator
u[numerator == .0] = -1.
mask_u = (u > 0) & (u < 1) # intersection on line segment 2
mask = mask_t * mask_u
# overwrite with EPSILON. otherwise numerically unstable
t = denumerator_t / (numerator + EPSILON)
intersections = torch.stack([x1 + t * (x2 - x1), y1 + t * (y2 - y1)],
dim=-1)
intersections = intersections * mask.float().unsqueeze(-1)
return intersections, mask
def box1_in_box2(corners1, corners2):
"""Check if corners of box1 lie in box2.
Convention: if a corner is exactly on the edge of the other box,
it's also a valid point.
Args:
corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.
Returns:
Tensor: (B, N, 4) Intersection.
"""
# a, b, c, d - 4 vertices of box2
a = corners2[:, :, 0:1, :] # (B, N, 1, 2)
b = corners2[:, :, 1:2, :] # (B, N, 1, 2)
d = corners2[:, :, 3:4, :] # (B, N, 1, 2)
# ab, am, ad - vectors between corresponding vertices
ab = b - a # (B, N, 1, 2)
am = corners1 - a # (B, N, 4, 2)
ad = d - a # (B, N, 1, 2)
prod_ab = torch.sum(ab * am, dim=-1) # (B, N, 4)
norm_ab = torch.sum(ab * ab, dim=-1) # (B, N, 1)
prod_ad = torch.sum(ad * am, dim=-1) # (B, N, 4)
norm_ad = torch.sum(ad * ad, dim=-1) # (B, N, 1)
# NOTE: the expression looks ugly but is stable if the two boxes
# are exactly the same also stable with different scale of bboxes
cond1 = (prod_ab / norm_ab > -1e-6) * (prod_ab / norm_ab < 1 + 1e-6
) # (B, N, 4)
cond2 = (prod_ad / norm_ad > -1e-6) * (prod_ad / norm_ad < 1 + 1e-6
) # (B, N, 4)
return cond1 * cond2
def box_in_box(corners1, corners2):
"""Check if corners of two boxes lie in each other.
Args:
corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.
Returns:
Tuple:
- Tensor: (B, N, 4) True if i-th corner of box1 is in box2.
- Tensor: (B, N, 4) True if i-th corner of box2 is in box1.
"""
c1_in_2 = box1_in_box2(corners1, corners2)
c2_in_1 = box1_in_box2(corners2, corners1)
return c1_in_2, c2_in_1
def build_vertices(corners1, corners2, c1_in_2, c2_in_1, intersections,
valid_mask):
"""Find vertices of intersection area.
Args:
corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.
c1_in_2 (Tensor): (B, N, 4) True if i-th corner of box1 is in box2.
c2_in_1 (Tensor): (B, N, 4) True if i-th corner of box2 is in box1.
intersections (Tensor): (B, N, 4, 4, 2) Intersections.
valid_mask (Tensor): (B, N, 4, 4) Valid intersections mask.
Returns:
Tuple:
- Tensor: (B, N, 24, 2) Vertices of intersection area;
only some elements are valid.
- Tensor: (B, N, 24) Mask of valid elements in vertices.
"""
# NOTE: inter has elements equals zero and has zeros gradient
# (masked by multiplying with 0); can be used as trick
B = corners1.size()[0]
N = corners1.size()[1]
# (B, N, 4 + 4 + 16, 2)
vertices = torch.cat(
[corners1, corners2,
intersections.view([B, N, -1, 2])], dim=2)
# Bool (B, N, 4 + 4 + 16)
mask = torch.cat([c1_in_2, c2_in_1, valid_mask.view([B, N, -1])], dim=2)
return vertices, mask
def sort_indices(vertices, mask):
"""Sort indices.
Note:
why 9? the polygon has maximal 8 vertices.
+1 to duplicate the first element.
the index should have following structure:
(A, B, C, ... , A, X, X, X)
and X indicates the index of arbitrary elements in the last
16 (intersections not corners) with value 0 and mask False.
(cause they have zero value and zero gradient)
Args:
vertices (Tensor): (B, N, 24, 2) Box vertices.
mask (Tensor): (B, N, 24) Mask.
Returns:
Tensor: (B, N, 9) Sorted indices.
"""
num_valid = torch.sum(mask.int(), dim=2).int() # (B, N)
mean = torch.sum(
vertices * mask.float().unsqueeze(-1), dim=2,
keepdim=True) / num_valid.unsqueeze(-1).unsqueeze(-1)
vertices_normalized = vertices - mean # normalization makes sorting easier
return SortVertices.apply(vertices_normalized, mask, num_valid).long()
def calculate_area(idx_sorted, vertices):
"""Calculate area of intersection.
Args:
idx_sorted (Tensor): (B, N, 9) Sorted vertex ids.
vertices (Tensor): (B, N, 24, 2) Vertices.
Returns:
Tuple:
- Tensor (B, N): Area of intersection.
- Tensor: (B, N, 9, 2) Vertices of polygon with zero padding.
"""
idx_ext = idx_sorted.unsqueeze(-1).repeat([1, 1, 1, 2])
selected = torch.gather(vertices, 2, idx_ext)
total = selected[:, :, 0:-1, 0] * selected[:, :, 1:, 1] \
- selected[:, :, 0:-1, 1] * selected[:, :, 1:, 0]
total = torch.sum(total, dim=2)
area = torch.abs(total) / 2
return area, selected
def oriented_box_intersection_2d(corners1, corners2):
"""Calculate intersection area of 2d rotated boxes.
Args:
corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.
Returns:
Tuple:
- Tensor (B, N): Area of intersection.
- Tensor (B, N, 9, 2): Vertices of polygon with zero padding.
"""
intersections, valid_mask = box_intersection(corners1, corners2)
c12, c21 = box_in_box(corners1, corners2)
vertices, mask = build_vertices(corners1, corners2, c12, c21,
intersections, valid_mask)
sorted_indices = sort_indices(vertices, mask)
return calculate_area(sorted_indices, vertices)
def box2corners(box):
"""Convert rotated 2d box coordinate to corners.
Args:
box (Tensor): (B, N, 5) with x, y, w, h, alpha.
Returns:
Tensor: (B, N, 4, 2) Corners.
"""
B = box.size()[0]
x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1)
x4 = torch.FloatTensor([0.5, -0.5, -0.5, 0.5]).to(box.device)
x4 = x4 * w # (B, N, 4)
y4 = torch.FloatTensor([0.5, 0.5, -0.5, -0.5]).to(box.device)
y4 = y4 * h # (B, N, 4)
corners = torch.stack([x4, y4], dim=-1) # (B, N, 4, 2)
sin = torch.sin(alpha)
cos = torch.cos(alpha)
row1 = torch.cat([cos, sin], dim=-1)
row2 = torch.cat([-sin, cos], dim=-1) # (B, N, 2)
rot_T = torch.stack([row1, row2], dim=-2) # (B, N, 2, 2)
rotated = torch.bmm(corners.view([-1, 4, 2]), rot_T.view([-1, 2, 2]))
rotated = rotated.view([B, -1, 4, 2]) # (B * N, 4, 2) -> (B, N, 4, 2)
rotated[..., 0] += x
rotated[..., 1] += y
return rotated
def diff_iou_rotated_2d(box1, box2):
"""Calculate differentiable iou of rotated 2d boxes.
Args:
box1 (Tensor): (B, N, 5) First box.
box2 (Tensor): (B, N, 5) Second box.
Returns:
Tensor: (B, N) IoU.
"""
corners1 = box2corners(box1)
corners2 = box2corners(box2)
intersection, _ = oriented_box_intersection_2d(corners1,
corners2) # (B, N)
area1 = box1[:, :, 2] * box1[:, :, 3]
area2 = box2[:, :, 2] * box2[:, :, 3]
union = area1 + area2 - intersection
iou = intersection / union
return iou
def diff_iou_rotated_3d(box3d1, box3d2):
"""Calculate differentiable iou of rotated 3d boxes.
Args:
box3d1 (Tensor): (B, N, 3+3+1) First box (x,y,z,w,h,l,alpha).
box3d2 (Tensor): (B, N, 3+3+1) Second box (x,y,z,w,h,l,alpha).
Returns:
Tensor: (B, N) IoU.
"""
box1 = box3d1[..., [0, 1, 3, 4, 6]] # 2d box
box2 = box3d2[..., [0, 1, 3, 4, 6]]
corners1 = box2corners(box1)
corners2 = box2corners(box2)
intersection, _ = oriented_box_intersection_2d(corners1, corners2)
zmax1 = box3d1[..., 2] + box3d1[..., 5] * 0.5
zmin1 = box3d1[..., 2] - box3d1[..., 5] * 0.5
zmax2 = box3d2[..., 2] + box3d2[..., 5] * 0.5
zmin2 = box3d2[..., 2] - box3d2[..., 5] * 0.5
z_overlap = (torch.min(zmax1, zmax2) -
torch.max(zmin1, zmin2)).clamp_(min=0.)
intersection_3d = intersection * z_overlap
volume1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5]
volume2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5]
union_3d = volume1 + volume2 - intersection_3d
return intersection_3d / union_3d

View File

@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_diff_iou_rotated_2d():
np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
[0.5, 0.5, 1., 1., .0]]],
dtype=np.float32)
np_boxes2 = np.asarray(
[[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., np.pi / 2],
[0.5, 0.5, 1., 1., np.pi / 4], [1., 1., 1., 1., .0],
[1.5, 1.5, 1., 1., .0]]],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]])
ious = diff_iou_rotated_2d(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_diff_iou_rotated_3d():
np_boxes1 = np.asarray(
[[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
[.5, .5, .5, 1., 1., 1., .0]]],
dtype=np.float32)
np_boxes2 = np.asarray(
[[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 2., np.pi / 2],
[.5, .5, .5, 1., 1., 1., np.pi / 4], [1., 1., 1., 1., 1., 1., .0],
[-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]])
ious = diff_iou_rotated_3d(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)