[Feature] Add points_in_polygons CUDA op for rotated detection. (#1600)

pull/1612/head
zhouyue 2021-12-24 10:56:48 +08:00 committed by GitHub
parent a4dc2a72ab
commit 304efbb650
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 210 additions and 1 deletions

View File

@ -20,6 +20,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- KNN
- MaskedConv
- NMS
- PointsInPolygons
- PSAMask
- RiRoIAlignRotated
- RotatedFeatureAlign

View File

@ -19,6 +19,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- KNN
- MaskedConv
- NMS
- PointsInPolygons
- PSAMask
- RotatedFeatureAlign
- RoIPointPool3d

View File

@ -38,6 +38,7 @@ from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from .points_in_polygons import points_in_polygons
from .points_sampler import PointsSampler
from .psa_mask import PSAMask
from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
@ -80,5 +81,6 @@ __all__ = [
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons'
]

View File

@ -0,0 +1,79 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
#define POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
// Single-precision 2-D coordinate pair (x, y).
struct point {
  float x;
  float y;
};
/**
 * Even-odd (ray casting) test of 2-D points against quadrilaterals.
 *
 * Each flat `index` in [0, nthreads) is split into
 * `row = index / cols` (which point) and `col = index % cols` (which polygon).
 *
 * @param nthreads     number of (point, polygon) pairs to evaluate.
 * @param vertex1      point coordinates, read as vertex1[row * 2 + {0, 1}]
 *                     = (x, y).
 * @param vertex2      polygon coordinates, read as vertex2[col * 8 + 0..7]
 *                     = (x1, y1, x2, y2, x3, y3, x4, y4).
 * @param rows         not read by the kernel body; kept for interface
 *                     compatibility.
 * @param cols         number of polygons, used to split `index`.
 * @param inside_flag  output, inside_flag[index] is set to 1 when the number
 *                     of edge crossings to the right of the point is odd and
 *                     to 0 otherwise.
 *
 * All coordinates are converted to `float` before the test, whatever
 * `scalar_t` is. When the point coincides with a vertex or lies exactly on an
 * edge, the edge scan stops early and the flag reflects only the crossings
 * counted so far, so results on the boundary are not reliable.
 */
template <typename scalar_t>
__global__ void points_in_polygons_forward_cuda_kernel(
    const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
    const int rows, const int cols, scalar_t *inside_flag) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int row = index / cols;
    const int col = index % cols;
    const scalar_t *offset_vertex1 = vertex1 + row * 2;
    const scalar_t *offset_vertex2 = vertex2 + col * 8;

    // Query point.
    const float px = static_cast<float>(offset_vertex1[0]);
    const float py = static_cast<float>(offset_vertex1[1]);

    // Four corners of the quadrilateral.
    point polygon[4];
    for (int k = 0; k < 4; ++k) {
      polygon[k].x = static_cast<float>(offset_vertex2[2 * k]);
      polygon[k].y = static_cast<float>(offset_vertex2[2 * k + 1]);
    }

    int nCross = 0;
    // Walk the edges (j -> i): (3,0), (0,1), (1,2), (2,3).
    for (int i = 0, j = 3; i < 4; j = i, i++) {
      const float sx = polygon[i].x;
      const float sy = polygon[i].y;
      const float tx = polygon[j].x;
      const float ty = polygon[j].y;

      // Edge lies entirely above or below the horizontal line through py.
      if (py < min(sy, ty)) continue;
      if (py > max(sy, ty)) continue;

      // Point coincides with one endpoint of the edge: stop scanning.
      if ((sx == px && sy == py) || (tx == px && ty == py)) {
        break;
      }

      // Half-open test so a vertex shared by two edges is counted once.
      // It also guarantees sy != ty, so the division below is safe.
      if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
        // x coordinate where the edge meets the line y == py.
        const float x = sx + (py - sy) * (tx - sx) / (ty - sy);
        if (x == px) {
          // Point lies exactly on the edge: stop scanning.
          break;
        }
        if (x > px) {
          nCross++;
        }
      }
    }

    inside_flag[index] =
        static_cast<scalar_t>((nCross % 2 == 1) ? 1.0f : 0.0f);
    // Deliberately no `return` here: it would leave the CUDA_1D_KERNEL_LOOP
    // loop after the first index, and any further indices this thread is
    // given by the loop would never be written.
  }
}
#endif // POINTS_IN_POLYGONS_CUDA_KERNEL_CUH

View File

@ -1481,3 +1481,22 @@ REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, CUDA,
rotated_feature_align_forward_cuda);
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, CUDA,
rotated_feature_align_backward_cuda);
// Kernel launcher; only declared here, defined in another translation unit.
void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
                                               const at::Tensor polygons,
                                               const int rows, const int cols,
                                               at::Tensor output);

// CUDA implementation of the points-in-polygons forward pass: forwards its
// arguments to the kernel launcher (note the launcher takes `output` last).
void points_in_polygons_forward_cuda(const Tensor points, const Tensor polygons,
                                     Tensor output, const int rows,
                                     const int cols) {
  PointsInPolygonsForwardCUDAKernelLauncher(points, polygons, rows, cols,
                                            output);
}

// Dispatching entry point; only declared here.
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
                                     Tensor output, const int rows,
                                     const int cols);

// Register the CUDA function above for the `points_in_polygons_forward_impl`
// key on the CUDA device type.
REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, CUDA,
                     points_in_polygons_forward_cuda);

View File

@ -0,0 +1,28 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/ming71/CUDA/blob/master/point_justify/points_justify_kernel.cu
#include <stdio.h>
#include "points_in_polygons_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
// Launches points_in_polygons_forward_cuda_kernel over rows * cols
// (point, polygon) pairs on the current CUDA stream of `points`' device.
//
//   points    tensor whose data pointer is passed as the kernel's `vertex1`.
//   polygons  tensor whose data pointer is passed as the kernel's `vertex2`.
//   rows      number of points.
//   cols      number of polygons.
//   output    tensor whose data pointer receives the per-pair flags.
//
// The kernel is instantiated for the scalar type of `points`; `polygons` and
// `output` are accessed with that same scalar type. The launch is
// asynchronous; only launch errors are checked here.
void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
                                               const at::Tensor polygons,
                                               const int rows, const int cols,
                                               at::Tensor output) {
  const int output_size = rows * cols;
  // Nothing to compute when there are no points or no polygons. Returning
  // here also avoids a launch with a zero-sized grid, which is an invalid
  // launch configuration.
  if (output_size <= 0) {
    return;
  }

  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      points.scalar_type(), "points_in_polygons_forward_cuda_kernel", ([&] {
        const scalar_t *vertex1 = points.data_ptr<scalar_t>();
        const scalar_t *vertex2 = polygons.data_ptr<scalar_t>();
        scalar_t *inside_flag = output.data_ptr<scalar_t>();
        points_in_polygons_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, vertex1, vertex2, rows, cols, inside_flag);
      }));
  // Surface launch-configuration errors from the kernel launch above.
  AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -0,0 +1,15 @@
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
// Dispatching entry for the points-in-polygons forward pass.
// Passes all five arguments, unchanged and in order, to DISPATCH_DEVICE_IMPL
// under the key `points_in_polygons_forward_impl`.
// NOTE(review): presumably the macro picks the implementation registered for
// the tensors' device type -- confirm in pytorch_device_registry.hpp.
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols) {
DISPATCH_DEVICE_IMPL(points_in_polygons_forward_impl, points, polygons,
output, rows, cols);
}
// Entry point for the points-in-polygons forward pass.
// Takes the row count from the first dimension of `points` and the column
// count from the first dimension of `polygons`, then forwards everything to
// points_in_polygons_forward_impl, which fills `output`.
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output) {
  const int num_points = points.size(0);
  const int num_polygons = polygons.size(0);
  points_in_polygons_forward_impl(points, polygons, output, num_points,
                                  num_polygons);
}

View File

@ -361,6 +361,8 @@ void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
int num_samples, int num_orientations,
bool clockwise);
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@ -726,4 +728,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("num_samples"), py::arg("num_orientations"),
py::arg("clockwise"));
m.def("points_in_polygons_forward", &points_in_polygons_forward,
"points_in_polygons_forward", py::arg("points"), py::arg("polygons"),
py::arg("output"));
}

View File

@ -0,0 +1,37 @@
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['points_in_polygons_forward'])
def points_in_polygons(points, polygons):
    """Judging whether points are inside polygons, which is used in the ATSS
    assignment for the rotated boxes.

    It should be noted that when the point is just at the polygon boundary, the
    judgment will be inaccurate, but the effect on assignment is limited.

    Args:
        points (torch.Tensor): It has shape (B, 2), indicating (x, y).
            B means the number of predicted points.
        polygons (torch.Tensor): It has shape (M, 8), indicating
            (x1, y1, x2, y2, x3, y3, x4, y4). M means the number of
            ground truth polygons.

    Returns:
        torch.Tensor: Return the result with the shape of (B, M),
        1 indicates that the point is inside the polygon,
        0 indicates that the point is outside the polygon.
        The result has the same dtype and device as ``points``.
    """
    assert points.shape[1] == 2, \
        'points dimension should be 2, ' \
        f'but got unexpected shape {points.shape[1]}'
    assert polygons.shape[1] == 8, \
        'polygons dimension should be 8, ' \
        f'but got unexpected shape {polygons.shape[1]}'
    # Allocate the result next to the inputs: same device (so the kernel never
    # writes to a tensor living on another GPU) and same dtype (the extension
    # reads all three tensors with the scalar type of ``points``).
    output = points.new_zeros((points.shape[0], polygons.shape[0]))
    if output.numel() == 0:
        # No (point, polygon) pairs: nothing for the extension to compute.
        return output
    ext_module.points_in_polygons_forward(points.contiguous(),
                                          polygons.contiguous(), output)
    return output

View File

@ -0,0 +1,22 @@
import numpy as np
import pytest
import torch
from mmcv.ops import points_in_polygons
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_polygons():
    """Check five query points against three quadrilaterals on the GPU."""
    on_gpu = dict(dtype=torch.float32, device='cuda')
    # One (x, y) pair per row.
    query_points = torch.tensor(
        [[300., 300.], [400., 400.], [100., 100.], [300., 250.], [100., 0.]],
        **on_gpu)
    # One (x1, y1, x2, y2, x3, y3, x4, y4) quadrilateral per row.
    quads = torch.tensor(
        [[200., 200., 400., 400., 500., 200., 400., 100.],
         [400., 400., 500., 500., 600., 300., 500., 200.],
         [300., 300., 600., 700., 700., 700., 700., 100.]], **on_gpu)
    # Row i / column j: 1 when point i is reported inside quadrilateral j.
    wanted = torch.tensor(
        [[0., 0., 0.], [0., 0., 1.], [0., 0., 0.], [1., 0., 0.],
         [0., 0., 0.]], **on_gpu)
    result = points_in_polygons(query_points, quads)
    assert torch.allclose(result, wanted, 1e-3)