mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add points_in_polygons CUDA op for rotated detection. (#1600)
parent
a4dc2a72ab
commit
304efbb650
|
@ -20,6 +20,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
|
|||
- KNN
|
||||
- MaskedConv
|
||||
- NMS
|
||||
- PointsInPolygons
|
||||
- PSAMask
|
||||
- RiRoIAlignRotated
|
||||
- RotatedFeatureAlign
|
||||
|
|
|
@ -19,6 +19,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
|
|||
- KNN
|
||||
- MaskedConv
|
||||
- NMS
|
||||
- PointsInPolygons
|
||||
- PSAMask
|
||||
- RotatedFeatureAlign
|
||||
- RoIPointPool3d
|
||||
|
|
|
@ -38,6 +38,7 @@ from .point_sample import (SimpleRoIAlign, point_sample,
|
|||
rel_roi_point_to_rel_img_point)
|
||||
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
|
||||
points_in_boxes_part)
|
||||
from .points_in_polygons import points_in_polygons
|
||||
from .points_sampler import PointsSampler
|
||||
from .psa_mask import PSAMask
|
||||
from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
|
||||
|
@ -80,5 +81,6 @@ __all__ = [
|
|||
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
|
||||
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
|
||||
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
|
||||
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
|
||||
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
|
||||
'points_in_polygons'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
|
||||
#define POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
|
||||
|
||||
#ifdef MMCV_USE_PARROTS
|
||||
#include "parrots_cuda_helper.hpp"
|
||||
#else
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#endif
|
||||
|
||||
struct point {
|
||||
float x, y;
|
||||
};
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void points_in_polygons_forward_cuda_kernel(
|
||||
const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
|
||||
const int rows, const int cols, scalar_t *inside_flag) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
int row = index / cols;
|
||||
int col = index % cols;
|
||||
|
||||
const scalar_t *offset_vertex1 = vertex1 + row * 2;
|
||||
const scalar_t *offset_vertex2 = vertex2 + col * 8;
|
||||
|
||||
point point_[1];
|
||||
point polygon[4];
|
||||
|
||||
point_[0].x = offset_vertex1[0];
|
||||
point_[0].y = offset_vertex1[1];
|
||||
|
||||
polygon[0].x = offset_vertex2[0];
|
||||
polygon[0].y = offset_vertex2[1];
|
||||
polygon[1].x = offset_vertex2[2];
|
||||
polygon[1].y = offset_vertex2[3];
|
||||
polygon[2].x = offset_vertex2[4];
|
||||
polygon[2].y = offset_vertex2[5];
|
||||
polygon[3].x = offset_vertex2[6];
|
||||
polygon[3].y = offset_vertex2[7];
|
||||
|
||||
int nCross = 0;
|
||||
int i, j;
|
||||
float sx, sy, tx, ty, px, py, x;
|
||||
for (i = 0, j = 3; i < 4; j = i, i++) {
|
||||
sx = polygon[i].x;
|
||||
sy = polygon[i].y;
|
||||
tx = polygon[j].x;
|
||||
ty = polygon[j].y;
|
||||
|
||||
px = point_[0].x;
|
||||
py = point_[0].y;
|
||||
|
||||
if (py < min(sy, ty)) continue;
|
||||
if (py > max(sy, ty)) continue;
|
||||
|
||||
if ((sx == px && sy == py) || (tx == px && ty == py)) {
|
||||
break;
|
||||
} else {
|
||||
if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
|
||||
x = sx + (py - sy) * (tx - sx) / (ty - sy);
|
||||
if (x == px) {
|
||||
break;
|
||||
}
|
||||
if (x > px) {
|
||||
nCross++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (nCross % 2 == 1) {
|
||||
inside_flag[index] = 1.0;
|
||||
} else {
|
||||
inside_flag[index] = 0.0;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
|
|
@ -1481,3 +1481,22 @@ REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, CUDA,
|
|||
rotated_feature_align_forward_cuda);
|
||||
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, CUDA,
|
||||
rotated_feature_align_backward_cuda);
|
||||
|
||||
void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
|
||||
const at::Tensor polygons,
|
||||
const int rows, const int cols,
|
||||
at::Tensor output);
|
||||
|
||||
void points_in_polygons_forward_cuda(const Tensor points, const Tensor polygons,
|
||||
Tensor output, const int rows,
|
||||
const int cols) {
|
||||
PointsInPolygonsForwardCUDAKernelLauncher(points, polygons, rows, cols,
|
||||
output);
|
||||
};
|
||||
|
||||
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
|
||||
Tensor output, const int rows,
|
||||
const int cols);
|
||||
|
||||
REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, CUDA,
|
||||
points_in_polygons_forward_cuda);
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
// Modified from
|
||||
// https://github.com/ming71/CUDA/blob/master/point_justify/points_justify_kernel.cu
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include "points_in_polygons_cuda_kernel.cuh"
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
|
||||
void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
|
||||
const at::Tensor polygons,
|
||||
const int rows, const int cols,
|
||||
at::Tensor output) {
|
||||
const int output_size = rows * cols;
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
points.scalar_type(), "points_in_polygons_forward_cuda_kernel", ([&] {
|
||||
const scalar_t *vertex1 = points.data_ptr<scalar_t>();
|
||||
const scalar_t *vertex2 = polygons.data_ptr<scalar_t>();
|
||||
scalar_t *inside_flag = output.data_ptr<scalar_t>();
|
||||
|
||||
points_in_polygons_forward_cuda_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
|
||||
output_size, vertex1, vertex2, rows, cols, inside_flag);
|
||||
}));
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
#include "pytorch_cpp_helper.hpp"
|
||||
#include "pytorch_device_registry.hpp"
|
||||
|
||||
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
|
||||
Tensor output, const int rows,
|
||||
const int cols) {
|
||||
DISPATCH_DEVICE_IMPL(points_in_polygons_forward_impl, points, polygons,
|
||||
output, rows, cols);
|
||||
}
|
||||
|
||||
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output) {
|
||||
int rows = points.size(0);
|
||||
int cols = polygons.size(0);
|
||||
points_in_polygons_forward_impl(points, polygons, output, rows, cols);
|
||||
}
|
|
@ -361,6 +361,8 @@ void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
|
|||
int num_samples, int num_orientations,
|
||||
bool clockwise);
|
||||
|
||||
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
|
||||
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
|
||||
|
@ -726,4 +728,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
py::arg("pooled_width"), py::arg("spatial_scale"),
|
||||
py::arg("num_samples"), py::arg("num_orientations"),
|
||||
py::arg("clockwise"));
|
||||
m.def("points_in_polygons_forward", &points_in_polygons_forward,
|
||||
"points_in_polygons_forward", py::arg("points"), py::arg("polygons"),
|
||||
py::arg("output"));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
import torch
|
||||
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext('_ext', ['points_in_polygons_forward'])
|
||||
|
||||
|
||||
def points_in_polygons(points, polygons):
|
||||
"""Judging whether points are inside polygons, which is used in the ATSS
|
||||
assignment for the rotated boxes.
|
||||
|
||||
It should be noted that when the point is just at the polygon boundary, the
|
||||
judgment will be inaccurate, but the effect on assignment is limited.
|
||||
|
||||
Args:
|
||||
points (torch.Tensor): It has shape (B, 2), indicating (x, y).
|
||||
M means the number of predicted points.
|
||||
polygons (torch.Tensor): It has shape (M, 8), indicating
|
||||
(x1, y1, x2, y2, x3, y3, x4, y4). M means the number of
|
||||
ground truth polygons.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Return the result with the shape of (B, M),
|
||||
1 indicates that the point is inside the polygon,
|
||||
0 indicates that the point is outside the polygon.
|
||||
"""
|
||||
assert points.shape[1] == 2, \
|
||||
'points dimension should be 2, ' \
|
||||
f'but got unexpected shape {points.shape[1]}'
|
||||
assert polygons.shape[1] == 8, \
|
||||
'polygons dimension should be 8, ' \
|
||||
f'but got unexpected shape {polygons.shape[1]}'
|
||||
output = torch.full([points.shape[0], polygons.shape[0]],
|
||||
0.).cuda().float()
|
||||
ext_module.points_in_polygons_forward(points.contiguous(),
|
||||
polygons.contiguous(), output)
|
||||
return output
|
|
@ -0,0 +1,22 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.ops import points_in_polygons
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_points_in_polygons():
|
||||
points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250],
|
||||
[100, 0]])
|
||||
polygons = np.array([[200., 200., 400., 400., 500., 200., 400., 100.],
|
||||
[400., 400., 500., 500., 600., 300., 500., 200.],
|
||||
[300., 300., 600., 700., 700., 700., 700., 100.]])
|
||||
expected_output = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 0.],
|
||||
[1., 0., 0.], [0., 0., 0.]])
|
||||
points = torch.from_numpy(points).cuda().float()
|
||||
polygons = torch.from_numpy(polygons).cuda().float()
|
||||
expected_output = torch.from_numpy(expected_output).cuda().float()
|
||||
assert torch.allclose(
|
||||
points_in_polygons(points, polygons), expected_output, 1e-3)
|
Loading…
Reference in New Issue