mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add points_in_polygons CUDA op for rotated detection. (#1600)
parent
a4dc2a72ab
commit
304efbb650
|
@ -20,6 +20,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
|
||||||
- KNN
|
- KNN
|
||||||
- MaskedConv
|
- MaskedConv
|
||||||
- NMS
|
- NMS
|
||||||
|
- PointsInPolygons
|
||||||
- PSAMask
|
- PSAMask
|
||||||
- RiRoIAlignRotated
|
- RiRoIAlignRotated
|
||||||
- RotatedFeatureAlign
|
- RotatedFeatureAlign
|
||||||
|
|
|
@ -19,6 +19,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
|
||||||
- KNN
|
- KNN
|
||||||
- MaskedConv
|
- MaskedConv
|
||||||
- NMS
|
- NMS
|
||||||
|
- PointsInPolygons
|
||||||
- PSAMask
|
- PSAMask
|
||||||
- RotatedFeatureAlign
|
- RotatedFeatureAlign
|
||||||
- RoIPointPool3d
|
- RoIPointPool3d
|
||||||
|
|
|
@ -38,6 +38,7 @@ from .point_sample import (SimpleRoIAlign, point_sample,
|
||||||
rel_roi_point_to_rel_img_point)
|
rel_roi_point_to_rel_img_point)
|
||||||
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
|
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
|
||||||
points_in_boxes_part)
|
points_in_boxes_part)
|
||||||
|
from .points_in_polygons import points_in_polygons
|
||||||
from .points_sampler import PointsSampler
|
from .points_sampler import PointsSampler
|
||||||
from .psa_mask import PSAMask
|
from .psa_mask import PSAMask
|
||||||
from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
|
from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
|
||||||
|
@ -80,5 +81,6 @@ __all__ = [
|
||||||
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
|
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
|
||||||
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
|
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
|
||||||
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
|
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
|
||||||
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
|
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
|
||||||
|
'points_in_polygons'
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
// Copyright (c) OpenMMLab. All rights reserved
|
||||||
|
#ifndef POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
|
||||||
|
#define POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
|
||||||
|
|
||||||
|
#ifdef MMCV_USE_PARROTS
|
||||||
|
#include "parrots_cuda_helper.hpp"
|
||||||
|
#else
|
||||||
|
#include "pytorch_cuda_helper.hpp"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Plain 2-D coordinate pair stored in single precision; the kernel in this
// header uses it for both the query point and the polygon vertices.
struct point {
  float x;
  float y;
};
|
||||||
|
|
||||||
|
// Tests, for every (point, polygon) pair, whether the point lies inside the
// quadrilateral, using the crossing-number rule: count the polygon edges
// crossed by a horizontal ray cast from the point towards +x; an odd count
// means "inside".
//
// Args:
//   nthreads:    total number of (point, polygon) pairs to evaluate.
//   vertex1:     point coordinates, read as 2 consecutive values (x, y) per
//                point.
//   vertex2:     polygon coordinates, read as 8 consecutive values
//                (x1, y1, ..., x4, y4) per polygon.
//   rows:        number of points (not read inside the kernel; kept so the
//                signature stays unchanged).
//   cols:        number of polygons; pair `index` maps to point index / cols
//                and polygon index % cols.
//   inside_flag: output, one value per pair at position `index`; written as 1
//                when the crossing count is odd and 0 otherwise.
//
// Coordinates are converted to float regardless of scalar_t. A point lying
// exactly on a vertex or an edge stops the edge scan early, so the result for
// boundary points depends on the crossings counted so far.
template <typename scalar_t>
__global__ void points_in_polygons_forward_cuda_kernel(
    const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
    const int rows, const int cols, scalar_t *inside_flag) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int row = index / cols;
    int col = index % cols;

    const scalar_t *offset_vertex1 = vertex1 + row * 2;
    const scalar_t *offset_vertex2 = vertex2 + col * 8;

    // The query point does not change while walking the edges, so read it
    // once instead of once per edge.
    const float px = static_cast<float>(offset_vertex1[0]);
    const float py = static_cast<float>(offset_vertex1[1]);

    point polygon[4];
    polygon[0].x = offset_vertex2[0];
    polygon[0].y = offset_vertex2[1];
    polygon[1].x = offset_vertex2[2];
    polygon[1].y = offset_vertex2[3];
    polygon[2].x = offset_vertex2[4];
    polygon[2].y = offset_vertex2[5];
    polygon[3].x = offset_vertex2[6];
    polygon[3].y = offset_vertex2[7];

    int nCross = 0;
    int i, j;
    float sx, sy, tx, ty, x;
    // Walk the four edges (j -> i), starting with the closing edge 3 -> 0.
    for (i = 0, j = 3; i < 4; j = i, i++) {
      sx = polygon[i].x;
      sy = polygon[i].y;
      tx = polygon[j].x;
      ty = polygon[j].y;

      // Edge lies entirely above or below the ray: it cannot be crossed.
      if (py < min(sy, ty)) continue;
      if (py > max(sy, ty)) continue;

      if ((sx == px && sy == py) || (tx == px && ty == py)) {
        // The point coincides with a vertex; stop scanning.
        break;
      } else {
        // Half-open test so a vertex shared by two edges is counted once.
        // It also guarantees sy != ty, so the division below is safe.
        if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
          // x coordinate where the edge meets the horizontal line y == py.
          x = sx + (py - sy) * (tx - sx) / (ty - sy);
          if (x == px) {
            // The point lies exactly on this edge; stop scanning.
            break;
          }
          if (x > px) {
            nCross++;
          }
        }
      }
    }
    if (nCross % 2 == 1) {
      inside_flag[index] = 1.0;
    } else {
      inside_flag[index] = 0.0;
    }
    // NOTE: no `return` here on purpose. CUDA_1D_KERNEL_LOOP is defined in a
    // helper header not shown here and is presumably a grid-stride loop;
    // returning inside its body would make each thread handle only its first
    // pair and leave the remaining pairs unwritten whenever fewer threads
    // than `nthreads` are launched.
  }
}
|
||||||
|
|
||||||
|
#endif // POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
|
|
@ -1481,3 +1481,22 @@ REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, CUDA,
|
||||||
rotated_feature_align_forward_cuda);
|
rotated_feature_align_forward_cuda);
|
||||||
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, CUDA,
|
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, CUDA,
|
||||||
rotated_feature_align_backward_cuda);
|
rotated_feature_align_backward_cuda);
|
||||||
|
|
||||||
|
void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
|
||||||
|
const at::Tensor polygons,
|
||||||
|
const int rows, const int cols,
|
||||||
|
at::Tensor output);
|
||||||
|
|
||||||
|
void points_in_polygons_forward_cuda(const Tensor points, const Tensor polygons,
|
||||||
|
Tensor output, const int rows,
|
||||||
|
const int cols) {
|
||||||
|
PointsInPolygonsForwardCUDAKernelLauncher(points, polygons, rows, cols,
|
||||||
|
output);
|
||||||
|
};
|
||||||
|
|
||||||
|
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
|
||||||
|
Tensor output, const int rows,
|
||||||
|
const int cols);
|
||||||
|
|
||||||
|
REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, CUDA,
|
||||||
|
points_in_polygons_forward_cuda);
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
// Copyright (c) OpenMMLab. All rights reserved
|
||||||
|
// Modified from
|
||||||
|
// https://github.com/ming71/CUDA/blob/master/point_justify/points_justify_kernel.cu
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "points_in_polygons_cuda_kernel.cuh"
|
||||||
|
#include "pytorch_cuda_helper.hpp"
|
||||||
|
|
||||||
|
// Launches points_in_polygons_forward_cuda_kernel over all rows * cols
// (point, polygon) pairs on the current CUDA stream of the device that holds
// `points`.
//
// Args:
//   points:   tensor whose data is read as 2 values per point.
//   polygons: tensor whose data is read as 8 values per polygon.
//   rows:     number of points.
//   cols:     number of polygons.
//   output:   tensor receiving one flag per pair. Its data pointer is
//             requested with the scalar type of `points`, so both tensors
//             are expected to share a dtype.
void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
                                               const at::Tensor polygons,
                                               const int rows, const int cols,
                                               at::Tensor output) {
  const int output_size = rows * cols;
  // With no points or no polygons there is nothing to compute. Returning
  // early also avoids launching a kernel with an empty grid, which CUDA
  // rejects as an invalid configuration (assuming GET_BLOCKS(0) evaluates to
  // 0 blocks; its definition is not in this file).
  if (output_size == 0) {
    return;
  }

  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      points.scalar_type(), "points_in_polygons_forward_cuda_kernel", ([&] {
        const scalar_t *vertex1 = points.data_ptr<scalar_t>();
        const scalar_t *vertex2 = polygons.data_ptr<scalar_t>();
        scalar_t *inside_flag = output.data_ptr<scalar_t>();

        points_in_polygons_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, vertex1, vertex2, rows, cols, inside_flag);
      }));
  // Surface launch-configuration errors immediately.
  AT_CUDA_CHECK(cudaGetLastError());
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
#include "pytorch_cpp_helper.hpp"
|
||||||
|
#include "pytorch_device_registry.hpp"
|
||||||
|
|
||||||
|
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
|
||||||
|
Tensor output, const int rows,
|
||||||
|
const int cols) {
|
||||||
|
DISPATCH_DEVICE_IMPL(points_in_polygons_forward_impl, points, polygons,
|
||||||
|
output, rows, cols);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Public op entry point: takes the pair-grid extents from the first dimension
// of each input and forwards everything to the dispatching implementation.
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output) {
  const int num_points = points.size(0);
  const int num_polygons = polygons.size(0);
  points_in_polygons_forward_impl(points, polygons, output, num_points,
                                  num_polygons);
}
|
|
@ -361,6 +361,8 @@ void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
|
||||||
int num_samples, int num_orientations,
|
int num_samples, int num_orientations,
|
||||||
bool clockwise);
|
bool clockwise);
|
||||||
|
|
||||||
|
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
|
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
|
||||||
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
|
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
|
||||||
|
@ -726,4 +728,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
py::arg("pooled_width"), py::arg("spatial_scale"),
|
py::arg("pooled_width"), py::arg("spatial_scale"),
|
||||||
py::arg("num_samples"), py::arg("num_orientations"),
|
py::arg("num_samples"), py::arg("num_orientations"),
|
||||||
py::arg("clockwise"));
|
py::arg("clockwise"));
|
||||||
|
m.def("points_in_polygons_forward", &points_in_polygons_forward,
|
||||||
|
"points_in_polygons_forward", py::arg("points"), py::arg("polygons"),
|
||||||
|
py::arg("output"));
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..utils import ext_loader
|
||||||
|
|
||||||
|
ext_module = ext_loader.load_ext('_ext', ['points_in_polygons_forward'])
|
||||||
|
|
||||||
|
|
||||||
|
def points_in_polygons(points: torch.Tensor,
                       polygons: torch.Tensor) -> torch.Tensor:
    """Judging whether points are inside polygons, which is used in the ATSS
    assignment for the rotated boxes.

    It should be noted that when the point is just at the polygon boundary, the
    judgment will be inaccurate, but the effect on assignment is limited.

    Args:
        points (torch.Tensor): It has shape (B, 2), indicating (x, y).
            B means the number of predicted points.
        polygons (torch.Tensor): It has shape (M, 8), indicating
            (x1, y1, x2, y2, x3, y3, x4, y4). M means the number of
            ground truth polygons.

    Returns:
        torch.Tensor: Return the result with the shape of (B, M) and dtype
        float32, allocated on the same device as ``points``.
        1 indicates that the point is inside the polygon,
        0 indicates that the point is outside the polygon.
    """
    assert points.shape[1] == 2, \
        'points dimension should be 2, ' \
        f'but got unexpected shape {points.shape[1]}'
    assert polygons.shape[1] == 8, \
        'polygons dimension should be 8, ' \
        f'but got unexpected shape {polygons.shape[1]}'
    # Allocate the result directly on the device that holds the inputs.
    # Creating it on the CPU and calling ``.cuda()`` would place it on the
    # current default CUDA device, which is wrong when ``points`` lives on
    # another GPU, and would also cost an extra host-to-device copy.
    output = torch.zeros((points.shape[0], polygons.shape[0]),
                         dtype=torch.float32,
                         device=points.device)
    ext_module.points_in_polygons_forward(points.contiguous(),
                                          polygons.contiguous(), output)
    return output
|
|
@ -0,0 +1,22 @@
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmcv.ops import points_in_polygons
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_polygons():
    """Check the op against hand-computed flags for 5 points x 3 polygons."""
    point_coords = np.array([
        [300., 300.],
        [400., 400.],
        [100., 100],
        [300, 250],
        [100, 0],
    ])
    polygon_coords = np.array([
        [200., 200., 400., 400., 500., 200., 400., 100.],
        [400., 400., 500., 500., 600., 300., 500., 200.],
        [300., 300., 600., 700., 700., 700., 700., 100.],
    ])
    # One row per point, one column per polygon.
    expected_flags = np.array([
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.],
    ])

    points = torch.from_numpy(point_coords).cuda().float()
    polygons = torch.from_numpy(polygon_coords).cuda().float()
    expected_output = torch.from_numpy(expected_flags).cuda().float()

    actual_output = points_in_polygons(points, polygons)
    assert torch.allclose(actual_output, expected_output, rtol=1e-3)
|
Loading…
Reference in New Issue