[Feature] Add min_area_polygons CUDA op for rotated detection. (#1611)

* init

* Update pybind.cpp

* Update min_area_polygons_cuda.cuh

* Update cudabind.cpp

* fix bug

* Create test_min_area_polygons.py

* add test

* update

* Update min_area_polygons_cuda.cuh

* fix bugs.

* Update min_area_polygons_cuda.cuh

* Update min_area_polygons.py

* Update min_area_polygons_cuda.cuh

* merge these 4 nested loops

* add AT_DISPATCH_FLOATING_TYPES_AND_HALF

* fix lint

* Resolving conflicts
pull/1598/head
Yue Zhou 2022-01-10 11:00:50 +08:00 committed by GitHub
parent b6167d5987
commit 51b40c332a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 397 additions and 1 deletions

View File

@ -19,6 +19,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- GroupPoints
- KNN
- MaskedConv
- MinAreaPolygon
- NMS
- PointsInPolygons
- PSAMask

View File

@ -18,6 +18,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- GeneralizedAttention
- KNN
- MaskedConv
- MinAreaPolygon
- NMS
- PointsInPolygons
- PSAMask

View File

@ -28,6 +28,7 @@ from .info import (get_compiler_version, get_compiling_cuda_version,
from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
from .knn import knn
from .masked_conv import MaskedConv2d, masked_conv2d
from .min_area_polygons import min_area_polygons
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
@ -82,5 +83,5 @@ __all__ = [
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons'
'points_in_polygons', 'min_area_polygons'
]

View File

@ -0,0 +1,301 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef MIN_AREA_POLYGONS_CUDA_KERNEL_CUH
#define MIN_AREA_POLYGONS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
// Capacity of the fixed-size per-thread point and angle buffers declared in
// this header.
#define MAXN 20
// NOTE(review): EPS is not referenced by any code visible in this header --
// confirm whether it is still needed.
const float EPS = 1E-8;
// Single-precision approximation of pi; used when folding edge angles by
// multiples of PI / 2.
const float PI = 3.1415926;
// 2-D point with single-precision coordinates; both constructors are
// device-side.
struct Point {
  float x, y;
  // Default construction leaves x and y uninitialised.
  __device__ Point() {}
  // Constructs the point (px, py).
  __device__ Point(float px, float py) : x(px), y(py) {}
};
// Exchanges the coordinates stored in *a and *b.
__device__ inline void swap1(Point *a, Point *b) {
  const float saved_x = a->x;
  const float saved_y = a->y;
  a->x = b->x;
  a->y = b->y;
  b->x = saved_x;
  b->y = saved_y;
}
// Returns the 2-D cross product of the vectors (a - o) and (b - o).
// The sign tells on which side of the ray o->a the point b lies; zero means
// the three points are collinear.
__device__ inline float cross(Point o, Point a, Point b) {
  const float oa_x = a.x - o.x;
  const float oa_y = a.y - o.y;
  const float ob_x = b.x - o.x;
  const float ob_y = b.y - o.y;
  return oa_x * ob_y - ob_x * oa_y;
}
// Returns the squared Euclidean distance between a and b (no square root is
// taken).
__device__ inline float dis(Point a, Point b) {
  const float dx = a.x - b.x;
  const float dy = a.y - b.y;
  return dx * dx + dy * dy;
}
// Searches for the smallest-area enclosing rectangle of `ps` among the
// rectangles aligned with one of the polygon's edge directions.
//
// ps       : n_points vertices; each consecutive pair ps[i] -> ps[i + 1] is
//            treated as an edge, so n_points - 1 edges are examined.
// n_points : number of entries read from ps (local buffers hold MAXN).
// minbox   : five floats, overwritten whenever a candidate has a smaller area
//            than every earlier one (starting from 1e12):
//            [0] rotation angle, [1] xmin, [2] ymin, [3] xmax, [4] ymax,
//            the extents being measured on the rotated coordinates.
__device__ inline void minBoundingRect(Point *ps, int n_points, float *minbox) {
  // Row 0 holds the x coordinates, row 1 the y coordinates.
  float coords[2][MAXN];
  for (int v = 0; v < n_points; v++) {
    coords[0][v] = ps[v].x;
    coords[1][v] = ps[v].y;
  }

  const int n_edges = n_points - 1;

  // Direction of every edge, shifted by a whole multiple of PI / 2. A
  // rectangle looks the same after a quarter turn, so directions a quarter
  // turn apart describe the same candidate.
  float folded[MAXN];
  for (int e = 0; e < n_edges; e++) {
    const float dx = ps[e + 1].x - ps[e].x;
    const float dy = ps[e + 1].y - ps[e].y;
    folded[e] = atan2((double)dy, (double)dx);
    if (folded[e] >= 0) {
      folded[e] = fmod((double)folded[e], (double)PI / 2);
    } else {
      folded[e] = folded[e] - (int)(folded[e] / (PI / 2) - 1) * (PI / 2);
    }
  }

  // Keep each distinct folded angle once (exact floating-point comparison).
  float candidates[MAXN];
  int n_candidates = 0;
  candidates[n_candidates] = folded[0];
  n_candidates += 1;
  for (int e = 1; e < n_edges; e++) {
    bool seen = false;
    for (int c = 0; c < n_candidates; c++) {
      if (folded[e] == candidates[c]) {
        seen = true;
        break;
      }
    }
    if (!seen) {
      candidates[n_candidates] = folded[e];
      n_candidates += 1;
    }
  }

  float best_area = 1e12;
  for (int c = 0; c < n_candidates; c++) {
    const float theta = candidates[c];

    // Rotation matrix for the candidate angle.
    float R[2][2];
    R[0][0] = cos(theta);
    R[0][1] = -sin(theta);
    R[1][0] = sin(theta);
    R[1][1] = cos(theta);

    // rotated = R x coords
    float rotated[2][MAXN];
    for (int row = 0; row < 2; row++) {
      for (int v = 0; v < n_points; v++) {
        float acc = 0.0;
        for (int k = 0; k < 2; k++) {
          acc = acc + R[row][k] * coords[k][v];
        }
        rotated[row][v] = acc;
      }
    }

    // Axis-aligned extents of the rotated points. A coordinate that is inf or
    // NaN is ignored for its own axis only.
    float xmin = 1e12;
    float ymin = 1e12;
    float xmax = -1e12;
    float ymax = -1e12;
    for (int v = 0; v < n_points; v++) {
      const float rx = rotated[0][v];
      if (!(isinf(rx) || isnan(rx))) {
        if (rx < xmin) {
          xmin = rx;
        }
        if (rx > xmax) {
          xmax = rx;
        }
      }
      const float ry = rotated[1][v];
      if (!(isinf(ry) || isnan(ry))) {
        if (ry < ymin) {
          ymin = ry;
        }
        if (ry > ymax) {
          ymax = ry;
        }
      }
    }

    const float area = (xmax - xmin) * (ymax - ymin);
    if (area < best_area) {
      best_area = area;
      minbox[0] = theta;
      minbox[1] = xmin;
      minbox[2] = ymin;
      minbox[3] = xmax;
      minbox[4] = ymax;
    }
  }
}
// convex_find
// Computes the convex hull of the first n_poly points of in_poly, in place.
//
// in_poly : in/out. On entry the candidate points; on exit in_poly[0..n_poly)
//           holds the hull vertices. The input order is not preserved: the
//           lowest point is swapped into slot 0 while scanning.
// n_poly  : in/out. Number of input points on entry, number of hull vertices
//           on exit. Inputs with fewer than two points are left untouched.
//
// Two walks are made from the lowest point (ties broken by smaller x) to the
// highest point (ties broken by larger x): the first keeps candidates with a
// positive cross product, the second those with a negative one. Collinear
// candidates are resolved in favour of the farther point. The two chains are
// then concatenated, the second one in reverse.
__device__ inline void Jarvis(Point *in_poly, int &n_poly) {
  // With fewer than two points there is nothing to walk between, and the
  // code below would read in_poly[1] / an unset max_index.
  if (n_poly < 2) {
    return;
  }

  Point p_max, p_k;
  int max_index, k_index;
  int Stack[MAXN], top1, top2;
  double sign;
  // Sized like Stack so that a chain can never outgrow its buffer.
  Point right_point[MAXN], left_point[MAXN];

  // Move the lowest point into slot 0 and track the highest point.
  p_max = in_poly[0];
  max_index = 0;
  for (int i = 1; i < n_poly; i++) {
    if (in_poly[i].y < in_poly[0].y ||
        (in_poly[i].y == in_poly[0].y && in_poly[i].x < in_poly[0].x)) {
      swap1(&in_poly[0], &in_poly[i]);
    }
    if (in_poly[i].y > p_max.y ||
        (in_poly[i].y == p_max.y && in_poly[i].x > p_max.x)) {
      p_max = in_poly[i];
      max_index = i;
    }
  }
  // max_index is not refreshed when a swap moves the recorded point out of
  // slot 0, so it can still read 0 here. The walks must end somewhere other
  // than their start, so slot 1 is targeted instead.
  // NOTE(review): relies on slot 1 then holding a point tied for highest --
  // confirm.
  if (max_index == 0) {
    max_index = 1;
    p_max = in_poly[max_index];
  }

  // First walk: Stack[0..top1] records the indices of the visited vertices.
  k_index = 0, Stack[0] = 0, top1 = 0;
  while (k_index != max_index) {
    p_k = p_max;
    k_index = max_index;
    for (int i = 1; i < n_poly; i++) {
      sign = cross(in_poly[Stack[top1]], in_poly[i], p_k);
      if ((sign > 0) ||
          ((sign == 0) && (dis(in_poly[Stack[top1]], in_poly[i]) >
                           dis(in_poly[Stack[top1]], p_k)))) {
        p_k = in_poly[i];
        k_index = i;
      }
    }
    top1++;
    Stack[top1] = k_index;
  }
  for (int i = 0; i <= top1; i++) {
    right_point[i] = in_poly[Stack[i]];
  }

  // Second walk: same start and end, opposite cross-product sign.
  k_index = 0, Stack[0] = 0, top2 = 0;
  while (k_index != max_index) {
    p_k = p_max;
    k_index = max_index;
    for (int i = 1; i < n_poly; i++) {
      sign = cross(in_poly[Stack[top2]], in_poly[i], p_k);
      if ((sign < 0) ||
          ((sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) >
                           dis(in_poly[Stack[top2]], p_k)))) {
        p_k = in_poly[i];
        k_index = i;
      }
    }
    top2++;
    Stack[top2] = k_index;
  }
  // The end point (index top2) is already the last entry of right_point, so
  // it is not copied again.
  for (int i = top2 - 1; i >= 0; i--) {
    left_point[i] = in_poly[Stack[i]];
  }

  // Emit the first chain forwards, then the second chain backwards,
  // stopping before its start point (which is in_poly[0] already).
  for (int i = 0; i < top1 + top2; i++) {
    if (i <= top1) {
      in_poly[i] = right_point[i];
    } else {
      in_poly[i] = left_point[top2 - (i - top1)];
    }
  }
  n_poly = top1 + top2;
}
// Computes the minimum-area rectangle around one point set.
//
// p         : 18 values read as nine consecutive (x, y) pairs.
// minpoints : receives 8 values, the four rectangle corners as (x, y) pairs
//             in the order (xmax, ymin), (xmin, ymin), (xmin, ymax),
//             (xmax, ymax) of the rotated frame, mapped back through the
//             rotation found by minBoundingRect.
template <typename T>
__device__ inline void Findminbox(T const *const p, T *minpoints) {
  // Unpack the nine input points.
  Point hull[MAXN];
  int n_hull = 9;
  for (int k = 0; k < 9; k++) {
    hull[k].x = p[2 * k];
    hull[k].y = p[2 * k + 1];
  }

  // Reduce to the convex hull, then repeat the first vertex at the end so
  // that the closing edge is part of the vertex sequence.
  Jarvis(hull, n_hull);
  hull[n_hull].x = hull[0].x;
  hull[n_hull].y = hull[0].y;

  // box = {angle, xmin, ymin, xmax, ymax}
  float box[5] = {0};
  minBoundingRect(hull, n_hull + 1, box);
  const float theta = box[0];
  const float x_lo = box[1];
  const float y_lo = box[2];
  const float x_hi = box[3];
  const float y_hi = box[4];

  // Entries of the rotation matrix for theta.
  const float r00 = cos(theta);
  const float r01 = -sin(theta);
  const float r10 = sin(theta);
  const float r11 = cos(theta);

  // Corners in the rotated frame, in output order.
  const float corner_x[4] = {x_hi, x_lo, x_lo, x_hi};
  const float corner_y[4] = {y_lo, y_lo, y_hi, y_hi};
  for (int c = 0; c < 4; c++) {
    minpoints[2 * c] = corner_x[c] * r00 + corner_y[c] * r10;
    minpoints[2 * c + 1] = corner_x[c] * r01 + corner_y[c] * r11;
  }
}
// Kernel: one Findminbox call per point set.
//
// ex_n_boxes : number of point sets to process.
// ex_boxes   : input, indexed as 18 values per point set.
// minbox     : output, indexed as 8 values per point set.
//
// Iteration over `index` is provided by CUDA_1D_KERNEL_LOOP (defined in the
// included helper header; presumably covers [0, ex_n_boxes) -- confirm).
template <typename T>
__global__ void min_area_polygons_cuda_kernel(const int ex_n_boxes,
                                              const T *ex_boxes, T *minbox) {
  CUDA_1D_KERNEL_LOOP(index, ex_n_boxes) {
    Findminbox(ex_boxes + index * 18, minbox + index * 8);
  }
}
#endif // MIN_AREA_POLYGONS_CUDA_KERNEL_CUH

View File

@ -1500,3 +1500,12 @@ void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, CUDA,
points_in_polygons_forward_cuda);
// Forward declaration of the kernel launcher called below.
void MinAreaPolygonsCUDAKernelLauncher(const Tensor pointsets, Tensor polygons);

// CUDA entry point for min_area_polygons: forwards both tensors unchanged to
// the kernel launcher.
void min_area_polygons_cuda(const Tensor pointsets, Tensor polygons) {
MinAreaPolygonsCUDAKernelLauncher(pointsets, polygons);
}

// Forward declaration so the function can be named in the registration below.
void min_area_polygons_impl(const Tensor pointsets, Tensor polygons);
// Presumably registers min_area_polygons_cuda as the CUDA implementation of
// min_area_polygons_impl -- see REGISTER_DEVICE_IMPL for the exact semantics.
REGISTER_DEVICE_IMPL(min_area_polygons_impl, CUDA, min_area_polygons_cuda);

View File

@ -0,0 +1,21 @@
// Copyright (c) OpenMMLab. All rights reserved
// modified from
// https://github.com/SDL-GuoZonghao/BeyondBoundingBox/blob/main/mmdet/ops/minareabbox/src/minareabbox_kernel.cu
#include "min_area_polygons_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
// Launches min_area_polygons_cuda_kernel on the current CUDA stream of the
// device that holds `pointsets`.
//
// pointsets : input; size(0) is the number of point sets, and the kernel
//             reads 18 values per set through the raw data pointer.
// polygons  : output; the kernel writes 8 values per set through the raw
//             data pointer. Accessed with the scalar type of `pointsets`.
//
// The kernel handles one point set per loop index, so the grid is sized from
// the number of point sets rather than from the number of output elements.
void MinAreaPolygonsCUDAKernelLauncher(const Tensor pointsets,
                                       Tensor polygons) {
  const int num_pointsets = pointsets.size(0);
  // Nothing to compute for an empty batch; returning here also avoids
  // requesting a grid for zero elements.
  if (num_pointsets == 0) {
    return;
  }

  at::cuda::CUDAGuard device_guard(pointsets.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      pointsets.scalar_type(), "min_area_polygons_cuda_kernel", ([&] {
        min_area_polygons_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(num_pointsets), THREADS_PER_BLOCK, 0, stream>>>(
                num_pointsets, pointsets.data_ptr<scalar_t>(),
                polygons.data_ptr<scalar_t>());
      }));
  // Surfaces launch-configuration errors; the launch itself is asynchronous.
  AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -0,0 +1,11 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
// Device-dispatching front end: hands both tensors to DISPATCH_DEVICE_IMPL
// under the name min_area_polygons_impl (presumably selecting the
// implementation registered for the tensors' device -- confirm in
// pytorch_device_registry.hpp).
void min_area_polygons_impl(const Tensor pointsets, Tensor polygons) {
DISPATCH_DEVICE_IMPL(min_area_polygons_impl, pointsets, polygons);
}
// Public entry point: forwards both tensors unchanged to
// min_area_polygons_impl. Results are written into `polygons`.
void min_area_polygons(const Tensor pointsets, Tensor polygons) {
min_area_polygons_impl(pointsets, polygons);
}

View File

@ -363,6 +363,8 @@ void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);
void min_area_polygons(const Tensor pointsets, Tensor polygons);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@ -731,4 +733,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("points_in_polygons_forward", &points_in_polygons_forward,
"points_in_polygons_forward", py::arg("points"), py::arg("polygons"),
py::arg("output"));
m.def("min_area_polygons", &min_area_polygons, "min_area_polygons",
py::arg("pointsets"), py::arg("polygons"));
}

View File

@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['min_area_polygons'])
def min_area_polygons(pointsets):
    """Find the smallest polygons that surround all points in the point sets.

    Args:
        pointsets (Tensor): point sets with shape (N, 18).

    Returns:
        torch.Tensor: Return the smallest polygons with shape (N, 8).

    Raises:
        ValueError: If ``pointsets`` does not hold exactly 18 values per
            point set.
    """
    num_sets = pointsets.size(0)
    # The extension consumes 18 values per point set; any other amount would
    # make it read the wrong elements or run past the end of the buffer.
    if pointsets.numel() != num_sets * 18:
        raise ValueError(
            'pointsets must hold 18 values per point set, but got shape '
            f'{tuple(pointsets.shape)}')
    # The extension walks the raw buffer row by row, so the memory layout
    # has to be dense. This is a no-op for tensors that already are.
    pointsets = pointsets.contiguous()
    polygons = pointsets.new_zeros((num_sets, 8))
    ext_module.min_area_polygons(pointsets, polygons)
    return polygons

View File

@ -0,0 +1,29 @@
import numpy as np
import pytest
import torch
from mmcv.ops import min_area_polygons
# Input fixture: two point sets of 18 values each.
np_pointsets = np.asarray([[
1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0, 3.0, 3.0, 1.0, 2.0, 3.0, 3.0,
2.0, 1.5, 1.5
],
[
1.5, 1.5, 2.5, 2.5, 1.5, 2.5, 2.5, 1.5, 1.5,
3.5, 3.5, 1.5, 2.5, 3.5, 3.5, 2.5, 2.0, 2.0
]])
# Expected output: one row of 8 values per point set above.
expected_polygons = np.asarray(
[[3.0000, 1.0000, 1.0000, 1.0000, 1.0000, 3.0000, 3.0000, 3.0000],
[3.5000, 1.5000, 1.5000, 1.5000, 1.5000, 3.5000, 3.5000, 3.5000]])
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_min_area_polygons():
    """The CUDA op reproduces the reference polygons within 1e-4."""
    device_pointsets = torch.from_numpy(np_pointsets).cuda().float()
    result = min_area_polygons(device_pointsets).cpu().numpy()
    assert np.allclose(result, expected_polygons, atol=1e-4)