mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add min_area_polygons CUDA op for rotated detection. (#1611)
* init * Update pybind.cpp * Update min_area_polygons_cuda.cuh * Update cudabind.cpp * fix bug * Create test_min_area_polygons.py * add test * update * Update min_area_polygons_cuda.cuh * fix bugs. * Update min_area_polygons_cuda.cuh * Update min_area_polygons.py * Update min_area_polygons_cuda.cuh * merge these 4 nested loops * add AT_DISPATCH_FLOATING_TYPES_AND_HALF * fix lint * Resolving conflictspull/1598/head
parent
b6167d5987
commit
51b40c332a
|
@ -19,6 +19,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
|
|||
- GroupPoints
|
||||
- KNN
|
||||
- MaskedConv
|
||||
- MinAreaPolygon
|
||||
- NMS
|
||||
- PointsInPolygons
|
||||
- PSAMask
|
||||
|
|
|
@ -18,6 +18,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
|
|||
- GeneralizedAttention
|
||||
- KNN
|
||||
- MaskedConv
|
||||
- MinAreaPolygon
|
||||
- NMS
|
||||
- PointsInPolygons
|
||||
- PSAMask
|
||||
|
|
|
@ -28,6 +28,7 @@ from .info import (get_compiler_version, get_compiling_cuda_version,
|
|||
from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
|
||||
from .knn import knn
|
||||
from .masked_conv import MaskedConv2d, masked_conv2d
|
||||
from .min_area_polygons import min_area_polygons
|
||||
from .modulated_deform_conv import (ModulatedDeformConv2d,
|
||||
ModulatedDeformConv2dPack,
|
||||
modulated_deform_conv2d)
|
||||
|
@ -82,5 +83,5 @@ __all__ = [
|
|||
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
|
||||
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
|
||||
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
|
||||
'points_in_polygons'
|
||||
'points_in_polygons', 'min_area_polygons'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,301 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef MIN_AREA_POLYGONS_CUDA_KERNEL_CUH
|
||||
#define MIN_AREA_POLYGONS_CUDA_KERNEL_CUH
|
||||
|
||||
#ifdef MMCV_USE_PARROTS
|
||||
#include "parrots_cuda_helper.hpp"
|
||||
#else
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#endif
|
||||
|
||||
#define MAXN 20
|
||||
const float EPS = 1E-8;
|
||||
const float PI = 3.1415926;
|
||||
|
||||
struct Point {
|
||||
float x, y;
|
||||
__device__ Point() {}
|
||||
__device__ Point(float x, float y) : x(x), y(y) {}
|
||||
};
|
||||
|
||||
__device__ inline void swap1(Point *a, Point *b) {
|
||||
Point temp;
|
||||
temp.x = a->x;
|
||||
temp.y = a->y;
|
||||
|
||||
a->x = b->x;
|
||||
a->y = b->y;
|
||||
|
||||
b->x = temp.x;
|
||||
b->y = temp.y;
|
||||
}
|
||||
__device__ inline float cross(Point o, Point a, Point b) {
|
||||
return (a.x - o.x) * (b.y - o.y) - (b.x - o.x) * (a.y - o.y);
|
||||
}
|
||||
|
||||
__device__ inline float dis(Point a, Point b) {
|
||||
return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
|
||||
}
|
||||
__device__ inline void minBoundingRect(Point *ps, int n_points, float *minbox) {
|
||||
float convex_points[2][MAXN];
|
||||
for (int j = 0; j < n_points; j++) {
|
||||
convex_points[0][j] = ps[j].x;
|
||||
}
|
||||
for (int j = 0; j < n_points; j++) {
|
||||
convex_points[1][j] = ps[j].y;
|
||||
}
|
||||
|
||||
Point edges[MAXN];
|
||||
float edges_angles[MAXN];
|
||||
float unique_angles[MAXN];
|
||||
int n_edges = n_points - 1;
|
||||
int n_unique = 0;
|
||||
int unique_flag = 0;
|
||||
|
||||
for (int i = 0; i < n_edges; i++) {
|
||||
edges[i].x = ps[i + 1].x - ps[i].x;
|
||||
edges[i].y = ps[i + 1].y - ps[i].y;
|
||||
}
|
||||
for (int i = 0; i < n_edges; i++) {
|
||||
edges_angles[i] = atan2((double)edges[i].y, (double)edges[i].x);
|
||||
if (edges_angles[i] >= 0) {
|
||||
edges_angles[i] = fmod((double)edges_angles[i], (double)PI / 2);
|
||||
} else {
|
||||
edges_angles[i] =
|
||||
edges_angles[i] - (int)(edges_angles[i] / (PI / 2) - 1) * (PI / 2);
|
||||
}
|
||||
}
|
||||
unique_angles[0] = edges_angles[0];
|
||||
n_unique += 1;
|
||||
for (int i = 1; i < n_edges; i++) {
|
||||
for (int j = 0; j < n_unique; j++) {
|
||||
if (edges_angles[i] == unique_angles[j]) {
|
||||
unique_flag += 1;
|
||||
}
|
||||
}
|
||||
if (unique_flag == 0) {
|
||||
unique_angles[n_unique] = edges_angles[i];
|
||||
n_unique += 1;
|
||||
unique_flag = 0;
|
||||
} else {
|
||||
unique_flag = 0;
|
||||
}
|
||||
}
|
||||
|
||||
float minarea = 1e12;
|
||||
for (int i = 0; i < n_unique; i++) {
|
||||
float R[2][2];
|
||||
float rot_points[2][MAXN];
|
||||
R[0][0] = cos(unique_angles[i]);
|
||||
R[0][1] = -sin(unique_angles[i]);
|
||||
R[1][0] = sin(unique_angles[i]);
|
||||
R[1][1] = cos(unique_angles[i]);
|
||||
// R x Points
|
||||
for (int m = 0; m < 2; m++) {
|
||||
for (int n = 0; n < n_points; n++) {
|
||||
float sum = 0.0;
|
||||
for (int k = 0; k < 2; k++) {
|
||||
sum = sum + R[m][k] * convex_points[k][n];
|
||||
}
|
||||
rot_points[m][n] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
// xmin;
|
||||
float xmin, ymin, xmax, ymax;
|
||||
xmin = 1e12;
|
||||
for (int j = 0; j < n_points; j++) {
|
||||
if (isinf(rot_points[0][j]) || isnan(rot_points[0][j])) {
|
||||
continue;
|
||||
} else {
|
||||
if (rot_points[0][j] < xmin) {
|
||||
xmin = rot_points[0][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
// ymin
|
||||
ymin = 1e12;
|
||||
for (int j = 0; j < n_points; j++) {
|
||||
if (isinf(rot_points[1][j]) || isnan(rot_points[1][j])) {
|
||||
continue;
|
||||
} else {
|
||||
if (rot_points[1][j] < ymin) {
|
||||
ymin = rot_points[1][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
// xmax
|
||||
xmax = -1e12;
|
||||
for (int j = 0; j < n_points; j++) {
|
||||
if (isinf(rot_points[0][j]) || isnan(rot_points[0][j])) {
|
||||
continue;
|
||||
} else {
|
||||
if (rot_points[0][j] > xmax) {
|
||||
xmax = rot_points[0][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
// ymax
|
||||
ymax = -1e12;
|
||||
for (int j = 0; j < n_points; j++) {
|
||||
if (isinf(rot_points[1][j]) || isnan(rot_points[1][j])) {
|
||||
continue;
|
||||
} else {
|
||||
if (rot_points[1][j] > ymax) {
|
||||
ymax = rot_points[1][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
float area = (xmax - xmin) * (ymax - ymin);
|
||||
if (area < minarea) {
|
||||
minarea = area;
|
||||
minbox[0] = unique_angles[i];
|
||||
minbox[1] = xmin;
|
||||
minbox[2] = ymin;
|
||||
minbox[3] = xmax;
|
||||
minbox[4] = ymax;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// convex_find
|
||||
__device__ inline void Jarvis(Point *in_poly, int &n_poly) {
|
||||
int n_input = n_poly;
|
||||
Point input_poly[20];
|
||||
for (int i = 0; i < n_input; i++) {
|
||||
input_poly[i].x = in_poly[i].x;
|
||||
input_poly[i].y = in_poly[i].y;
|
||||
}
|
||||
Point p_max, p_k;
|
||||
int max_index, k_index;
|
||||
int Stack[20], top1, top2;
|
||||
// float sign;
|
||||
double sign;
|
||||
Point right_point[10], left_point[10];
|
||||
|
||||
for (int i = 0; i < n_poly; i++) {
|
||||
if (in_poly[i].y < in_poly[0].y ||
|
||||
in_poly[i].y == in_poly[0].y && in_poly[i].x < in_poly[0].x) {
|
||||
Point *j = &(in_poly[0]);
|
||||
Point *k = &(in_poly[i]);
|
||||
swap1(j, k);
|
||||
}
|
||||
if (i == 0) {
|
||||
p_max = in_poly[0];
|
||||
max_index = 0;
|
||||
}
|
||||
if (in_poly[i].y > p_max.y ||
|
||||
in_poly[i].y == p_max.y && in_poly[i].x > p_max.x) {
|
||||
p_max = in_poly[i];
|
||||
max_index = i;
|
||||
}
|
||||
}
|
||||
if (max_index == 0) {
|
||||
max_index = 1;
|
||||
p_max = in_poly[max_index];
|
||||
}
|
||||
|
||||
k_index = 0, Stack[0] = 0, top1 = 0;
|
||||
while (k_index != max_index) {
|
||||
p_k = p_max;
|
||||
k_index = max_index;
|
||||
for (int i = 1; i < n_poly; i++) {
|
||||
sign = cross(in_poly[Stack[top1]], in_poly[i], p_k);
|
||||
if ((sign > 0) || ((sign == 0) && (dis(in_poly[Stack[top1]], in_poly[i]) >
|
||||
dis(in_poly[Stack[top1]], p_k)))) {
|
||||
p_k = in_poly[i];
|
||||
k_index = i;
|
||||
}
|
||||
}
|
||||
top1++;
|
||||
Stack[top1] = k_index;
|
||||
}
|
||||
|
||||
for (int i = 0; i <= top1; i++) {
|
||||
right_point[i] = in_poly[Stack[i]];
|
||||
}
|
||||
|
||||
k_index = 0, Stack[0] = 0, top2 = 0;
|
||||
|
||||
while (k_index != max_index) {
|
||||
p_k = p_max;
|
||||
k_index = max_index;
|
||||
for (int i = 1; i < n_poly; i++) {
|
||||
sign = cross(in_poly[Stack[top2]], in_poly[i], p_k);
|
||||
if ((sign < 0) || (sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) >
|
||||
dis(in_poly[Stack[top2]], p_k))) {
|
||||
p_k = in_poly[i];
|
||||
k_index = i;
|
||||
}
|
||||
}
|
||||
top2++;
|
||||
Stack[top2] = k_index;
|
||||
}
|
||||
|
||||
for (int i = top2 - 1; i >= 0; i--) {
|
||||
left_point[i] = in_poly[Stack[i]];
|
||||
}
|
||||
|
||||
for (int i = 0; i < top1 + top2; i++) {
|
||||
if (i <= top1) {
|
||||
in_poly[i] = right_point[i];
|
||||
} else {
|
||||
in_poly[i] = left_point[top2 - (i - top1)];
|
||||
}
|
||||
}
|
||||
n_poly = top1 + top2;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline void Findminbox(T const *const p, T *minpoints) {
|
||||
Point ps1[MAXN];
|
||||
Point convex[MAXN];
|
||||
for (int i = 0; i < 9; i++) {
|
||||
convex[i].x = p[i * 2];
|
||||
convex[i].y = p[i * 2 + 1];
|
||||
}
|
||||
int n_convex = 9;
|
||||
Jarvis(convex, n_convex);
|
||||
int n1 = n_convex;
|
||||
for (int i = 0; i < n1; i++) {
|
||||
ps1[i].x = convex[i].x;
|
||||
ps1[i].y = convex[i].y;
|
||||
}
|
||||
ps1[n1].x = convex[0].x;
|
||||
ps1[n1].y = convex[0].y;
|
||||
|
||||
float minbbox[5] = {0};
|
||||
minBoundingRect(ps1, n1 + 1, minbbox);
|
||||
float angle = minbbox[0];
|
||||
float xmin = minbbox[1];
|
||||
float ymin = minbbox[2];
|
||||
float xmax = minbbox[3];
|
||||
float ymax = minbbox[4];
|
||||
float R[2][2];
|
||||
|
||||
R[0][0] = cos(angle);
|
||||
R[0][1] = -sin(angle);
|
||||
R[1][0] = sin(angle);
|
||||
R[1][1] = cos(angle);
|
||||
|
||||
minpoints[0] = xmax * R[0][0] + ymin * R[1][0];
|
||||
minpoints[1] = xmax * R[0][1] + ymin * R[1][1];
|
||||
minpoints[2] = xmin * R[0][0] + ymin * R[1][0];
|
||||
minpoints[3] = xmin * R[0][1] + ymin * R[1][1];
|
||||
minpoints[4] = xmin * R[0][0] + ymax * R[1][0];
|
||||
minpoints[5] = xmin * R[0][1] + ymax * R[1][1];
|
||||
minpoints[6] = xmax * R[0][0] + ymax * R[1][0];
|
||||
minpoints[7] = xmax * R[0][1] + ymax * R[1][1];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void min_area_polygons_cuda_kernel(const int ex_n_boxes,
|
||||
const T *ex_boxes, T *minbox) {
|
||||
CUDA_1D_KERNEL_LOOP(index, ex_n_boxes) {
|
||||
const T *cur_box = ex_boxes + index * 18;
|
||||
T *cur_min_box = minbox + index * 8;
|
||||
Findminbox(cur_box, cur_min_box);
|
||||
}
|
||||
}
|
||||
|
||||
#endif // MIN_AREA_POLYGONS_CUDA_KERNEL_CUH
|
|
@ -1500,3 +1500,12 @@ void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
|
|||
|
||||
REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, CUDA,
|
||||
points_in_polygons_forward_cuda);
|
||||
|
||||
void MinAreaPolygonsCUDAKernelLauncher(const Tensor pointsets, Tensor polygons);
|
||||
|
||||
void min_area_polygons_cuda(const Tensor pointsets, Tensor polygons) {
|
||||
MinAreaPolygonsCUDAKernelLauncher(pointsets, polygons);
|
||||
}
|
||||
|
||||
void min_area_polygons_impl(const Tensor pointsets, Tensor polygons);
|
||||
REGISTER_DEVICE_IMPL(min_area_polygons_impl, CUDA, min_area_polygons_cuda);
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
// modified from
|
||||
// https://github.com/SDL-GuoZonghao/BeyondBoundingBox/blob/main/mmdet/ops/minareabbox/src/minareabbox_kernel.cu
|
||||
#include "min_area_polygons_cuda.cuh"
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
|
||||
void MinAreaPolygonsCUDAKernelLauncher(const Tensor pointsets,
|
||||
Tensor polygons) {
|
||||
int num_pointsets = pointsets.size(0);
|
||||
const int output_size = polygons.numel();
|
||||
at::cuda::CUDAGuard device_guard(pointsets.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
pointsets.scalar_type(), "min_area_polygons_cuda_kernel", ([&] {
|
||||
min_area_polygons_cuda_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
|
||||
num_pointsets, pointsets.data_ptr<scalar_t>(),
|
||||
polygons.data_ptr<scalar_t>());
|
||||
}));
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
#include "pytorch_device_registry.hpp"
|
||||
|
||||
void min_area_polygons_impl(const Tensor pointsets, Tensor polygons) {
|
||||
DISPATCH_DEVICE_IMPL(min_area_polygons_impl, pointsets, polygons);
|
||||
}
|
||||
|
||||
void min_area_polygons(const Tensor pointsets, Tensor polygons) {
|
||||
min_area_polygons_impl(pointsets, polygons);
|
||||
}
|
|
@ -363,6 +363,8 @@ void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
|
|||
|
||||
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);
|
||||
|
||||
void min_area_polygons(const Tensor pointsets, Tensor polygons);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
|
||||
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
|
||||
|
@ -731,4 +733,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
m.def("points_in_polygons_forward", &points_in_polygons_forward,
|
||||
"points_in_polygons_forward", py::arg("points"), py::arg("polygons"),
|
||||
py::arg("output"));
|
||||
m.def("min_area_polygons", &min_area_polygons, "min_area_polygons",
|
||||
py::arg("pointsets"), py::arg("polygons"));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext('_ext', ['min_area_polygons'])
|
||||
|
||||
|
||||
def min_area_polygons(pointsets):
|
||||
"""Find the smallest polygons that surrounds all points in the point sets.
|
||||
|
||||
Args:
|
||||
pointsets (Tensor): point sets with shape (N, 18).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Return the smallest polygons with shape (N, 8).
|
||||
"""
|
||||
polygons = pointsets.new_zeros((pointsets.size(0), 8))
|
||||
ext_module.min_area_polygons(pointsets, polygons)
|
||||
return polygons
|
|
@ -0,0 +1,29 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.ops import min_area_polygons
|
||||
|
||||
np_pointsets = np.asarray([[
|
||||
1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0, 3.0, 3.0, 1.0, 2.0, 3.0, 3.0,
|
||||
2.0, 1.5, 1.5
|
||||
],
|
||||
[
|
||||
1.5, 1.5, 2.5, 2.5, 1.5, 2.5, 2.5, 1.5, 1.5,
|
||||
3.5, 3.5, 1.5, 2.5, 3.5, 3.5, 2.5, 2.0, 2.0
|
||||
]])
|
||||
|
||||
expected_polygons = np.asarray(
|
||||
[[3.0000, 1.0000, 1.0000, 1.0000, 1.0000, 3.0000, 3.0000, 3.0000],
|
||||
[3.5000, 1.5000, 1.5000, 1.5000, 1.5000, 3.5000, 3.5000, 3.5000]])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_min_area_polygons():
|
||||
pointsets = torch.from_numpy(np_pointsets).cuda().float()
|
||||
|
||||
assert np.allclose(
|
||||
min_area_polygons(pointsets).cpu().numpy(),
|
||||
expected_polygons,
|
||||
atol=1e-4)
|
Loading…
Reference in New Issue