From 51b40c332aff9d2927fcc252b248d295850a4d55 Mon Sep 17 00:00:00 2001 From: Yue Zhou <592267829@qq.com> Date: Mon, 10 Jan 2022 11:00:50 +0800 Subject: [PATCH] [Feature] Add min_area_polygons CUDA op for rotated detection. (#1611) * init * Update pybind.cpp * Update min_area_polygons_cuda.cuh * Update cudabind.cpp * fix bug * Create test_min_area_polygons.py * add test * update * Update min_area_polygons_cuda.cuh * fix bugs. * Update min_area_polygons_cuda.cuh * Update min_area_polygons.py * Update min_area_polygons_cuda.cuh * merge these 4 nested loops * add AT_DISPATCH_FLOATING_TYPES_AND_HALF * fix lint * Resolving conflicts --- docs/en/understand_mmcv/ops.md | 1 + docs/zh_cn/understand_mmcv/ops.md | 1 + mmcv/ops/__init__.py | 3 +- .../common/cuda/min_area_polygons_cuda.cuh | 301 ++++++++++++++++++ mmcv/ops/csrc/pytorch/cuda/cudabind.cpp | 9 + .../csrc/pytorch/cuda/min_area_polygons.cu | 21 ++ mmcv/ops/csrc/pytorch/min_area_polygons.cpp | 11 + mmcv/ops/csrc/pytorch/pybind.cpp | 4 + mmcv/ops/min_area_polygons.py | 18 ++ tests/test_ops/test_min_area_polygons.py | 29 ++ 10 files changed, 397 insertions(+), 1 deletion(-) create mode 100644 mmcv/ops/csrc/common/cuda/min_area_polygons_cuda.cuh create mode 100644 mmcv/ops/csrc/pytorch/cuda/min_area_polygons.cu create mode 100644 mmcv/ops/csrc/pytorch/min_area_polygons.cpp create mode 100644 mmcv/ops/min_area_polygons.py create mode 100644 tests/test_ops/test_min_area_polygons.py diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 083be9cbf..60e0e56de 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -19,6 +19,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - GroupPoints - KNN - MaskedConv +- MinAreaPolygon - NMS - PointsInPolygons - PSAMask diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 961ff60a9..2ccf2098e 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -18,6 +18,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 - GeneralizedAttention - KNN - MaskedConv +- MinAreaPolygon - NMS - PointsInPolygons - PSAMask diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index 44240f65e..4670d7ba5 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -28,6 +28,7 @@ from .info import (get_compiler_version, get_compiling_cuda_version, from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev from .knn import knn from .masked_conv import MaskedConv2d, masked_conv2d +from .min_area_polygons import min_area_polygons from .modulated_deform_conv import (ModulatedDeformConv2d, ModulatedDeformConv2dPack, modulated_deform_conv2d) @@ -82,5 +83,5 @@ __all__ = [ 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all', - 'points_in_polygons' + 'points_in_polygons', 'min_area_polygons' ] diff --git a/mmcv/ops/csrc/common/cuda/min_area_polygons_cuda.cuh b/mmcv/ops/csrc/common/cuda/min_area_polygons_cuda.cuh new file mode 100644 index 000000000..f06288f0f --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/min_area_polygons_cuda.cuh @@ -0,0 +1,301 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef MIN_AREA_POLYGONS_CUDA_KERNEL_CUH +#define MIN_AREA_POLYGONS_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +#define MAXN 20 +const float EPS = 1E-8; +const float PI = 3.1415926; + +struct Point { + float x, y; + __device__ Point() {} + __device__ Point(float x, float y) : x(x), y(y) {} +}; + +__device__ inline void swap1(Point *a, Point *b) { + Point temp; + temp.x = a->x; + temp.y = a->y; + + a->x = b->x; + a->y = b->y; + + b->x = temp.x; + b->y = temp.y; +} +__device__ inline float cross(Point o, Point a, Point b) { + return (a.x - o.x) * (b.y - o.y) - (b.x - o.x) * (a.y - o.y); +} + +__device__ inline float dis(Point a, Point b) { + return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y); +} +__device__ inline void minBoundingRect(Point *ps, int n_points, float *minbox) { + float convex_points[2][MAXN]; + for (int j = 0; j < n_points; j++) { + convex_points[0][j] = ps[j].x; + } + for (int j = 0; j < n_points; j++) { + convex_points[1][j] = ps[j].y; + } + + Point edges[MAXN]; + float edges_angles[MAXN]; + float unique_angles[MAXN]; + int n_edges = n_points - 1; + int n_unique = 0; + int unique_flag = 0; + + for (int i = 0; i < n_edges; i++) { + edges[i].x = ps[i + 1].x - ps[i].x; + edges[i].y = ps[i + 1].y - ps[i].y; + } + for (int i = 0; i < n_edges; i++) { + edges_angles[i] = atan2((double)edges[i].y, (double)edges[i].x); + if (edges_angles[i] >= 0) { + edges_angles[i] = fmod((double)edges_angles[i], (double)PI / 2); + } else { + edges_angles[i] = + edges_angles[i] - (int)(edges_angles[i] / (PI / 2) - 1) * (PI / 2); + } + } + unique_angles[0] = edges_angles[0]; + n_unique += 1; + for (int i = 1; i < n_edges; i++) { + for (int j = 0; j < n_unique; j++) { + if (edges_angles[i] == unique_angles[j]) { + unique_flag += 1; + } + } + if (unique_flag == 0) { + unique_angles[n_unique] = edges_angles[i]; + n_unique += 1; + unique_flag = 0; + } else { + unique_flag = 0; + } + } + + float minarea = 1e12; + for (int i = 0; i < n_unique; i++) { + float R[2][2]; + float rot_points[2][MAXN]; + R[0][0] = cos(unique_angles[i]); + R[0][1] = -sin(unique_angles[i]); + R[1][0] = sin(unique_angles[i]); + R[1][1] = cos(unique_angles[i]); + // R x Points + for (int m = 0; m < 2; m++) { + for (int n = 0; n < n_points; n++) { + float sum = 0.0; + for (int k = 0; k < 2; k++) { + sum = sum + R[m][k] * convex_points[k][n]; + } + rot_points[m][n] = sum; + } + } + + // xmin; + float xmin, ymin, xmax, ymax; + xmin = 1e12; + for (int j = 0; j < n_points; j++) { + if (isinf(rot_points[0][j]) || isnan(rot_points[0][j])) { + continue; + } else { + if (rot_points[0][j] < xmin) { + xmin = rot_points[0][j]; + } + } + } + // ymin + ymin = 1e12; + for (int j = 0; j < n_points; j++) { + if (isinf(rot_points[1][j]) || isnan(rot_points[1][j])) { + continue; + } else { + if (rot_points[1][j] < ymin) { + ymin = rot_points[1][j]; + } + } + } + // xmax + xmax = -1e12; + for (int j = 0; j < n_points; j++) { + if (isinf(rot_points[0][j]) || isnan(rot_points[0][j])) { + continue; + } else { + if (rot_points[0][j] > xmax) { + xmax = rot_points[0][j]; + } + } + } + // ymax + ymax = -1e12; + for (int j = 0; j < n_points; j++) { + if (isinf(rot_points[1][j]) || isnan(rot_points[1][j])) { + continue; + } else { + if (rot_points[1][j] > ymax) { + ymax = rot_points[1][j]; + } + } + } + float area = (xmax - xmin) * (ymax - ymin); + if (area < minarea) { + minarea = area; + minbox[0] = unique_angles[i]; + minbox[1] = xmin; + minbox[2] = ymin; + minbox[3] = xmax; + minbox[4] = ymax; + } + } +} + +// convex_find +__device__ inline void Jarvis(Point *in_poly, int &n_poly) { + int n_input = n_poly; + Point input_poly[20]; + for (int i = 0; i < n_input; i++) { + input_poly[i].x = in_poly[i].x; + input_poly[i].y = in_poly[i].y; + } + Point p_max, p_k; + int max_index, k_index; + int Stack[20], top1, top2; + // float sign; + double sign; + Point right_point[10], left_point[10]; + + for (int i = 0; i < n_poly; i++) { + if (in_poly[i].y < in_poly[0].y || + in_poly[i].y == in_poly[0].y && in_poly[i].x < in_poly[0].x) { + Point *j = &(in_poly[0]); + Point *k = &(in_poly[i]); + swap1(j, k); + } + if (i == 0) { + p_max = in_poly[0]; + max_index = 0; + } + if (in_poly[i].y > p_max.y || + in_poly[i].y == p_max.y && in_poly[i].x > p_max.x) { + p_max = in_poly[i]; + max_index = i; + } + } + if (max_index == 0) { + max_index = 1; + p_max = in_poly[max_index]; + } + + k_index = 0, Stack[0] = 0, top1 = 0; + while (k_index != max_index) { + p_k = p_max; + k_index = max_index; + for (int i = 1; i < n_poly; i++) { + sign = cross(in_poly[Stack[top1]], in_poly[i], p_k); + if ((sign > 0) || ((sign == 0) && (dis(in_poly[Stack[top1]], in_poly[i]) > + dis(in_poly[Stack[top1]], p_k)))) { + p_k = in_poly[i]; + k_index = i; + } + } + top1++; + Stack[top1] = k_index; + } + + for (int i = 0; i <= top1; i++) { + right_point[i] = in_poly[Stack[i]]; + } + + k_index = 0, Stack[0] = 0, top2 = 0; + + while (k_index != max_index) { + p_k = p_max; + k_index = max_index; + for (int i = 1; i < n_poly; i++) { + sign = cross(in_poly[Stack[top2]], in_poly[i], p_k); + if ((sign < 0) || (sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) > + dis(in_poly[Stack[top2]], p_k))) { + p_k = in_poly[i]; + k_index = i; + } + } + top2++; + Stack[top2] = k_index; + } + + for (int i = top2 - 1; i >= 0; i--) { + left_point[i] = in_poly[Stack[i]]; + } + + for (int i = 0; i < top1 + top2; i++) { + if (i <= top1) { + in_poly[i] = right_point[i]; + } else { + in_poly[i] = left_point[top2 - (i - top1)]; + } + } + n_poly = top1 + top2; +} + +template +__device__ inline void Findminbox(T const *const p, T *minpoints) { + Point ps1[MAXN]; + Point convex[MAXN]; + for (int i = 0; i < 9; i++) { + convex[i].x = p[i * 2]; + convex[i].y = p[i * 2 + 1]; + } + int n_convex = 9; + Jarvis(convex, n_convex); + int n1 = n_convex; + for (int i = 0; i < n1; i++) { + ps1[i].x = convex[i].x; + ps1[i].y = convex[i].y; + } + ps1[n1].x = convex[0].x; + ps1[n1].y = convex[0].y; + + float minbbox[5] = {0}; + minBoundingRect(ps1, n1 + 1, minbbox); + float angle = minbbox[0]; + float xmin = minbbox[1]; + float ymin = minbbox[2]; + float xmax = minbbox[3]; + float ymax = minbbox[4]; + float R[2][2]; + + R[0][0] = cos(angle); + R[0][1] = -sin(angle); + R[1][0] = sin(angle); + R[1][1] = cos(angle); + + minpoints[0] = xmax * R[0][0] + ymin * R[1][0]; + minpoints[1] = xmax * R[0][1] + ymin * R[1][1]; + minpoints[2] = xmin * R[0][0] + ymin * R[1][0]; + minpoints[3] = xmin * R[0][1] + ymin * R[1][1]; + minpoints[4] = xmin * R[0][0] + ymax * R[1][0]; + minpoints[5] = xmin * R[0][1] + ymax * R[1][1]; + minpoints[6] = xmax * R[0][0] + ymax * R[1][0]; + minpoints[7] = xmax * R[0][1] + ymax * R[1][1]; +} + +template +__global__ void min_area_polygons_cuda_kernel(const int ex_n_boxes, + const T *ex_boxes, T *minbox) { + CUDA_1D_KERNEL_LOOP(index, ex_n_boxes) { + const T *cur_box = ex_boxes + index * 18; + T *cur_min_box = minbox + index * 8; + Findminbox(cur_box, cur_min_box); + } +} + +#endif // MIN_AREA_POLYGONS_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp index 36e874d56..c9448896f 100644 --- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp +++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp @@ -1500,3 +1500,12 @@ void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons, REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, CUDA, points_in_polygons_forward_cuda); + +void MinAreaPolygonsCUDAKernelLauncher(const Tensor pointsets, Tensor polygons); + +void min_area_polygons_cuda(const Tensor pointsets, Tensor polygons) { + MinAreaPolygonsCUDAKernelLauncher(pointsets, polygons); +} + +void min_area_polygons_impl(const Tensor pointsets, Tensor polygons); +REGISTER_DEVICE_IMPL(min_area_polygons_impl, CUDA, min_area_polygons_cuda); diff --git a/mmcv/ops/csrc/pytorch/cuda/min_area_polygons.cu b/mmcv/ops/csrc/pytorch/cuda/min_area_polygons.cu new file mode 100644 index 000000000..9314f2dda --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/min_area_polygons.cu @@ -0,0 +1,21 @@ +// Copyright (c) OpenMMLab. All rights reserved +// modified from +// https://github.com/SDL-GuoZonghao/BeyondBoundingBox/blob/main/mmdet/ops/minareabbox/src/minareabbox_kernel.cu +#include "min_area_polygons_cuda.cuh" +#include "pytorch_cuda_helper.hpp" + +void MinAreaPolygonsCUDAKernelLauncher(const Tensor pointsets, + Tensor polygons) { + int num_pointsets = pointsets.size(0); + const int output_size = polygons.numel(); + at::cuda::CUDAGuard device_guard(pointsets.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + pointsets.scalar_type(), "min_area_polygons_cuda_kernel", ([&] { + min_area_polygons_cuda_kernel + <<>>( + num_pointsets, pointsets.data_ptr(), + polygons.data_ptr()); + })); + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/min_area_polygons.cpp b/mmcv/ops/csrc/pytorch/min_area_polygons.cpp new file mode 100644 index 000000000..8ff996dc8 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/min_area_polygons.cpp @@ -0,0 +1,11 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void min_area_polygons_impl(const Tensor pointsets, Tensor polygons) { + DISPATCH_DEVICE_IMPL(min_area_polygons_impl, pointsets, polygons); +} + +void min_area_polygons(const Tensor pointsets, Tensor polygons) { + min_area_polygons_impl(pointsets, polygons); +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 02b665f09..bd87bedef 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -363,6 +363,8 @@ void riroi_align_rotated_backward(Tensor top_grad, Tensor rois, void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output); +void min_area_polygons(const Tensor pointsets, Tensor polygons); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), @@ -731,4 +733,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("points_in_polygons_forward", &points_in_polygons_forward, "points_in_polygons_forward", py::arg("points"), py::arg("polygons"), py::arg("output")); + m.def("min_area_polygons", &min_area_polygons, "min_area_polygons", + py::arg("pointsets"), py::arg("polygons")); } diff --git a/mmcv/ops/min_area_polygons.py b/mmcv/ops/min_area_polygons.py new file mode 100644 index 000000000..9f42a8be1 --- /dev/null +++ b/mmcv/ops/min_area_polygons.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['min_area_polygons']) + + +def min_area_polygons(pointsets): + """Find the smallest polygons that surrounds all points in the point sets. + + Args: + pointsets (Tensor): point sets with shape (N, 18). + + Returns: + torch.Tensor: Return the smallest polygons with shape (N, 8). + """ + polygons = pointsets.new_zeros((pointsets.size(0), 8)) + ext_module.min_area_polygons(pointsets, polygons) + return polygons diff --git a/tests/test_ops/test_min_area_polygons.py b/tests/test_ops/test_min_area_polygons.py new file mode 100644 index 000000000..a335321ff --- /dev/null +++ b/tests/test_ops/test_min_area_polygons.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest +import torch + +from mmcv.ops import min_area_polygons + +np_pointsets = np.asarray([[ + 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0, 3.0, 3.0, 1.0, 2.0, 3.0, 3.0, + 2.0, 1.5, 1.5 +], + [ + 1.5, 1.5, 2.5, 2.5, 1.5, 2.5, 2.5, 1.5, 1.5, + 3.5, 3.5, 1.5, 2.5, 3.5, 3.5, 2.5, 2.0, 2.0 + ]]) + +expected_polygons = np.asarray( + [[3.0000, 1.0000, 1.0000, 1.0000, 1.0000, 3.0000, 3.0000, 3.0000], + [3.5000, 1.5000, 1.5000, 1.5000, 1.5000, 3.5000, 3.5000, 3.5000]]) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_min_area_polygons(): + pointsets = torch.from_numpy(np_pointsets).cuda().float() + + assert np.allclose( + min_area_polygons(pointsets).cpu().numpy(), + expected_polygons, + atol=1e-4)