[Feature] Add box_iou_quadri & nms_quadri (#2277)

* update

* update

* fix link

* fix bug

* update nms_quadri

* fix lint

* Update test_nms_quadri.py

* Update box_iou_quadri.py

* fix bug

* Update test_nms_quadri.py

* Update box_iou_rotated_utils.hpp

* Update box_iou_quadri.py

* Update mmcv/ops/nms.py
Branch: pull/2371/head
Author: Yue Zhou, 2022-10-13 17:26:14 +08:00 (committed by Zaida Zhou)
Parent: f48975b2e3
Commit: a4f304a5f5
18 changed files with 864 additions and 9 deletions

@@ -10,6 +10,7 @@ We implement common ops used in detection, segmentation, etc.
| BBoxOverlaps | | √ | √ | √ |
| BorderAlign | | √ | | |
| BoxIouRotated | √ | √ | | |
| BoxIouQuadri | √ | √ | | |
| CARAFE | | √ | √ | |
| ChamferDistance | | √ | | |
| CrissCrossAttention | | √ | | |
@@ -35,6 +36,7 @@ We implement common ops used in detection, segmentation, etc.
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| NMSQuadri | √ | √ | | |
| PixelGroup | √ | | | |
| PointsInBoxes | √ | √ | | |
| PointsInPolygons | | √ | | |

@@ -10,6 +10,7 @@ MMCV provides the operators commonly used in detection, segmentation and other tasks
| BBoxOverlaps | | √ | √ | √ |
| BorderAlign | | √ | | |
| BoxIouRotated | √ | √ | | |
| BoxIouQuadri | √ | √ | | |
| CARAFE | | √ | √ | |
| ChamferDistance | | √ | | |
| CrissCrossAttention | | √ | | |
@@ -35,6 +36,7 @@ MMCV provides the operators commonly used in detection, segmentation and other tasks
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| NMSQuadri | √ | √ | | |
| PixelGroup | √ | | | |
| PointsInBoxes | √ | √ | | |
| PointsInPolygons | | √ | | |

@@ -4,6 +4,7 @@ from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
from .bbox import bbox_overlaps
from .border_align import BorderAlign, border_align
from .box_iou_quadri import box_iou_quadri
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
@@ -38,7 +39,7 @@ from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
from .multi_scale_deform_attn import MultiScaleDeformableAttention
-from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
+from .nms import batched_nms, nms, nms_match, nms_quadri, nms_rotated, soft_nms
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
@@ -84,13 +85,14 @@ __all__ = [
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
-    'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
-    'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
-    'rotated_feature_align', 'RiRoIAlignRotated', 'riroi_align_rotated',
-    'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
-    'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
-    'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
-    'border_align', 'gather_points', 'furthest_point_sample',
+    'box_iou_rotated', 'box_iou_quadri', 'RoIPointPool3d', 'nms_rotated',
+    'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
+    'fused_bias_leakyrelu', 'rotated_feature_align', 'RiRoIAlignRotated',
+    'riroi_align_rotated', 'RoIAlignRotated', 'roi_align_rotated',
+    'pixel_group', 'QueryAndGroup', 'GroupAll', 'grouping_operation',
+    'contour_expand', 'three_nn', 'three_interpolate',
+    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
+    'gather_points', 'furthest_point_sample', 'nms_quadri',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'boxes_iou3d', 'boxes_iou_bev', 'boxes_overlap_bev', 'nms_bev',
'nms_normal_bev', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization',

@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['box_iou_quadri'])
def box_iou_quadri(bboxes1: torch.Tensor,
bboxes2: torch.Tensor,
mode: str = 'iou',
aligned: bool = False) -> torch.Tensor:
"""Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in
(x1, y1, ..., x4, y4) format.
If ``aligned`` is ``False``, then calculate the ious between each bbox
of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
bboxes1 and bboxes2.

    Args:
        bboxes1 (torch.Tensor): quadrilateral bboxes 1. It has shape (N, 8),
            indicating (x1, y1, ..., x4, y4) for each row.
        bboxes2 (torch.Tensor): quadrilateral bboxes 2. It has shape (M, 8),
            indicating (x1, y1, ..., x4, y4) for each row.
        mode (str): "iou" (intersection over union) or "iof" (intersection
            over foreground).
        aligned (bool): If ``True``, compute the ious between each aligned
            pair of bboxes1 and bboxes2; otherwise compute the full pairwise
            ious. Defaults to ``False``.

    Returns:
        torch.Tensor: The ious between the boxes. If ``aligned`` is ``False``,
            the shape of the ious is (N, M); otherwise it is (N,).
"""
assert mode in ['iou', 'iof']
mode_dict = {'iou': 0, 'iof': 1}
mode_flag = mode_dict[mode]
rows = bboxes1.size(0)
cols = bboxes2.size(0)
if aligned:
ious = bboxes1.new_zeros(rows)
else:
ious = bboxes1.new_zeros(rows * cols)
bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous()
ext_module.box_iou_quadri(
bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
if not aligned:
ious = ious.view(rows, cols)
return ious
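
A quick usage sketch of the new Python wrapper (the boxes are made-up unit squares, not values taken from this diff; running it requires the compiled `_ext` extension):

    import torch
    from mmcv.ops import box_iou_quadri

    # two unit squares given as (x1, y1, ..., x4, y4), overlapping by half
    boxes1 = torch.tensor([[0., 0., 0., 1., 1., 1., 1., 0.]])
    boxes2 = torch.tensor([[0.5, 0., 0.5, 1., 1.5, 1., 1.5, 0.]])

    ious = box_iou_quadri(boxes1, boxes2)                         # pairwise, shape (1, 1)
    iofs = box_iou_quadri(boxes1, boxes2, mode='iof')             # intersection over area of boxes1
    ious_aligned = box_iou_quadri(boxes1, boxes2, aligned=True)   # shape (1,)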

@@ -270,6 +270,17 @@ HOST_DEVICE_INLINE int convex_hull_graham(const Point<T> (&p)[24],
return m;
}
template <typename T>
HOST_DEVICE_INLINE T quadri_box_area(const Point<T> (&q)[4]) {
T area = 0;
#pragma unroll
for (int i = 1; i < 3; i++) {
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
}
return area / 2.0;
}
template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
if (m <= 2) {
@@ -308,6 +319,25 @@ HOST_DEVICE_INLINE T rotated_boxes_intersection(const RotatedBox<T>& box1,
return polygon_area<T>(orderedPts, num_convex);
}
template <typename T>
HOST_DEVICE_INLINE T quadri_boxes_intersection(const Point<T> (&pts1)[4],
const Point<T> (&pts2)[4]) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point<T> intersectPts[24], orderedPts[24];
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
if (num <= 2) {
return 0.0;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
return polygon_area<T>(orderedPts, num_convex);
}
} // namespace
template <typename T>
@@ -345,3 +375,52 @@ HOST_DEVICE_INLINE T single_box_iou_rotated(T const* const box1_raw,
const T iou = intersection / baseS;
return iou;
}
template <typename T>
HOST_DEVICE_INLINE T single_box_iou_quadri(T const* const pts1_raw,
T const* const pts2_raw,
const int mode_flag) {
// shift center to the middle point to achieve higher precision in result
Point<T> pts1[4], pts2[4];
auto center_shift_x =
(pts1_raw[0] + pts2_raw[0] + pts1_raw[2] + pts2_raw[2] + pts1_raw[4] +
pts2_raw[4] + pts1_raw[6] + pts2_raw[6]) /
8.0;
auto center_shift_y =
(pts1_raw[1] + pts2_raw[1] + pts1_raw[3] + pts2_raw[3] + pts1_raw[5] +
pts2_raw[5] + pts1_raw[7] + pts2_raw[7]) /
8.0;
pts1[0].x = pts1_raw[0] - center_shift_x;
pts1[0].y = pts1_raw[1] - center_shift_y;
pts1[1].x = pts1_raw[2] - center_shift_x;
pts1[1].y = pts1_raw[3] - center_shift_y;
pts1[2].x = pts1_raw[4] - center_shift_x;
pts1[2].y = pts1_raw[5] - center_shift_y;
pts1[3].x = pts1_raw[6] - center_shift_x;
pts1[3].y = pts1_raw[7] - center_shift_y;
pts2[0].x = pts2_raw[0] - center_shift_x;
pts2[0].y = pts2_raw[1] - center_shift_y;
pts2[1].x = pts2_raw[2] - center_shift_x;
pts2[1].y = pts2_raw[3] - center_shift_y;
pts2[2].x = pts2_raw[4] - center_shift_x;
pts2[2].y = pts2_raw[5] - center_shift_y;
pts2[3].x = pts2_raw[6] - center_shift_x;
pts2[3].y = pts2_raw[7] - center_shift_y;
const T area1 = quadri_box_area<T>(pts1);
const T area2 = quadri_box_area<T>(pts2);
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
}
const T intersection = quadri_boxes_intersection<T>(pts1, pts2);
T baseS = 1.0;
if (mode_flag == 0) {
baseS = (area1 + area2 - intersection);
} else if (mode_flag == 1) {
baseS = area1;
}
const T iou = intersection / baseS;
return iou;
}
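
quadri_box_area sums a fan of two triangles rooted at the first vertex, and quadri_boxes_intersection measures the area of the convex hull of all pairwise intersection/contained points. For convex, consistently ordered quadrilaterals the same quantity can be cross-checked in Python with shapely (an illustration only; shapely is not a dependency of this PR):

    from shapely.geometry import Polygon

    def quadri_iou_reference(pts1, pts2, mode='iou'):
        """pts1/pts2: length-8 sequences (x1, y1, ..., x4, y4)."""
        p1 = Polygon(list(zip(pts1[0::2], pts1[1::2])))
        p2 = Polygon(list(zip(pts2[0::2], pts2[1::2])))
        inter = p1.intersection(p2).area
        base = p1.union(p2).area if mode == 'iou' else p1.area  # mode_flag 0 / 1
        return inter / base if base > 1e-14 else 0.0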

@@ -0,0 +1,91 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#ifndef BOX_IOU_QUADRI_CUDA_CUH
#define BOX_IOU_QUADRI_CUDA_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include "box_iou_rotated_utils.hpp"
// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;
inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
template <typename T>
__global__ void box_iou_quadri_cuda_kernel(
const int n_boxes1, const int n_boxes2, const T* dev_boxes1,
const T* dev_boxes2, T* dev_ious, const int mode_flag, const bool aligned) {
if (aligned) {
CUDA_1D_KERNEL_LOOP(index, n_boxes1) {
int b1 = index;
int b2 = index;
int base1 = b1 * 8;
float block_boxes1[8];
float block_boxes2[8];
block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
block_boxes1[2] = dev_boxes1[base1 + 2];
block_boxes1[3] = dev_boxes1[base1 + 3];
block_boxes1[4] = dev_boxes1[base1 + 4];
block_boxes1[5] = dev_boxes1[base1 + 5];
block_boxes1[6] = dev_boxes1[base1 + 6];
block_boxes1[7] = dev_boxes1[base1 + 7];
int base2 = b2 * 8;
block_boxes2[0] = dev_boxes2[base2 + 0];
block_boxes2[1] = dev_boxes2[base2 + 1];
block_boxes2[2] = dev_boxes2[base2 + 2];
block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4];
block_boxes2[5] = dev_boxes2[base2 + 5];
block_boxes2[6] = dev_boxes2[base2 + 6];
block_boxes2[7] = dev_boxes2[base2 + 7];
dev_ious[index] =
single_box_iou_quadri<T>(block_boxes1, block_boxes2, mode_flag);
}
} else {
CUDA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) {
int b1 = index / n_boxes2;
int b2 = index % n_boxes2;
int base1 = b1 * 8;
float block_boxes1[8];
float block_boxes2[8];
block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
block_boxes1[2] = dev_boxes1[base1 + 2];
block_boxes1[3] = dev_boxes1[base1 + 3];
block_boxes1[4] = dev_boxes1[base1 + 4];
block_boxes1[5] = dev_boxes1[base1 + 5];
block_boxes1[6] = dev_boxes1[base1 + 6];
block_boxes1[7] = dev_boxes1[base1 + 7];
int base2 = b2 * 8;
block_boxes2[0] = dev_boxes2[base2 + 0];
block_boxes2[1] = dev_boxes2[base2 + 1];
block_boxes2[2] = dev_boxes2[base2 + 2];
block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4];
block_boxes2[5] = dev_boxes2[base2 + 5];
block_boxes2[6] = dev_boxes2[base2 + 6];
block_boxes2[7] = dev_boxes2[base2 + 7];
dev_ious[index] =
single_box_iou_quadri<T>(block_boxes1, block_boxes2, mode_flag);
}
}
}
#endif
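
Both branches of the kernel copy eight coordinates per box into registers; the only real difference is how the flat CUDA_1D_KERNEL_LOOP index is mapped to a pair of boxes. Expressed as a small Python sketch (names are illustrative):

    def boxes_for_index(index, n_boxes2, aligned):
        # aligned: the i-th box of set 1 is paired with the i-th box of set 2
        # pairwise: the flat index walks the (n_boxes1, n_boxes2) grid row-major
        return (index, index) if aligned else divmod(index, n_boxes2)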

@@ -0,0 +1,141 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#ifndef NMS_QUADRI_CUDA_CUH
#define NMS_QUADRI_CUDA_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include "box_iou_rotated_utils.hpp"
__host__ __device__ inline int divideUP(const int x, const int y) {
return (((x) + (y)-1) / (y));
}
namespace {
int const threadsPerBlock = sizeof(unsigned long long) * 8;
}
template <typename T>
__global__ void nms_quadri_cuda_kernel(const int n_boxes,
const float iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask,
const int multi_label) {
if (multi_label == 1) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
    // Compared to nms_cuda_kernel, where each box is represented with 4 values
    // (x1, y1, x2, y2), each quadrilateral box is represented with 8 values
    // (x1, y1, ..., x4, y4) here.
__shared__ T block_boxes[threadsPerBlock * 8];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 8 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 0];
block_boxes[threadIdx.x * 8 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 1];
block_boxes[threadIdx.x * 8 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 2];
block_boxes[threadIdx.x * 8 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 3];
block_boxes[threadIdx.x * 8 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 4];
block_boxes[threadIdx.x * 8 + 5] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 5];
block_boxes[threadIdx.x * 8 + 6] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 6];
block_boxes[threadIdx.x * 8 + 7] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 7];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 9;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
// Instead of devIoU used by original horizontal nms, here
// we use the single_box_iou_quadri function from
// box_iou_rotated_utils.h
if (single_box_iou_quadri<T>(cur_box, block_boxes + i * 8, 0) >
iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = divideUP(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
} else {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
    // Compared to nms_cuda_kernel, where each box is represented with 4 values
    // (x1, y1, x2, y2), each quadrilateral box is represented with 8 values
    // (x1, y1, ..., x4, y4) here.
__shared__ T block_boxes[threadsPerBlock * 8];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 8 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 0];
block_boxes[threadIdx.x * 8 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 1];
block_boxes[threadIdx.x * 8 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 2];
block_boxes[threadIdx.x * 8 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 3];
block_boxes[threadIdx.x * 8 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 4];
block_boxes[threadIdx.x * 8 + 5] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 5];
block_boxes[threadIdx.x * 8 + 6] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 6];
block_boxes[threadIdx.x * 8 + 7] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 7];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 8;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
// Instead of devIoU used by original horizontal nms, here
// we use the single_box_iou_quadri function from
// box_iou_rotated_utils.h
if (single_box_iou_quadri<T>(cur_box, block_boxes + i * 8, 0) >
iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = divideUP(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
}
#endif
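
Each thread writes one 64-bit word per column block: bit i is set when the current box suppresses the i-th box of that block. A pure-Python sketch of how such a mask is folded into the final keep list, mirroring the host-side loop in the CUDA launcher further below (the flat row-major mask layout is assumed):

    def decode_keep(mask, n_boxes, threads_per_block=64):
        # mask: flat list of ints with logical shape (n_boxes, col_blocks)
        col_blocks = (n_boxes + threads_per_block - 1) // threads_per_block
        remv = [0] * col_blocks
        keep = []
        for i in range(n_boxes):
            nblock, inblock = divmod(i, threads_per_block)
            if not (remv[nblock] >> inblock) & 1:
                keep.append(i)
                row = mask[i * col_blocks:(i + 1) * col_blocks]
                for j in range(nblock, col_blocks):
                    remv[j] |= row[j]
        return keep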

@@ -0,0 +1,17 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void box_iou_quadri_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
DISPATCH_DEVICE_IMPL(box_iou_quadri_impl, boxes1, boxes2, ious, mode_flag,
aligned);
}
// Interface for Python
void box_iou_quadri(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
box_iou_quadri_impl(boxes1, boxes2, ious, mode_flag, aligned);
}

@@ -0,0 +1,36 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "box_iou_rotated_utils.hpp"
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
template <typename T>
void box_iou_quadri_cpu_kernel(const Tensor boxes1, const Tensor boxes2,
Tensor ious, const int mode_flag,
const bool aligned) {
int output_size = ious.numel();
auto num_boxes1 = boxes1.size(0);
auto num_boxes2 = boxes2.size(0);
if (aligned) {
for (int i = 0; i < output_size; i++) {
ious[i] = single_box_iou_quadri<T>(boxes1[i].data_ptr<T>(),
boxes2[i].data_ptr<T>(), mode_flag);
}
} else {
for (int i = 0; i < num_boxes1; i++) {
for (int j = 0; j < num_boxes2; j++) {
ious[i * num_boxes2 + j] = single_box_iou_quadri<T>(
boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>(), mode_flag);
}
}
}
}
void box_iou_quadri_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
box_iou_quadri_cpu_kernel<float>(boxes1, boxes2, ious, mode_flag, aligned);
}
void box_iou_quadri_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
REGISTER_DEVICE_IMPL(box_iou_quadri_impl, CPU, box_iou_quadri_cpu);

@@ -0,0 +1,64 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "box_iou_rotated_utils.hpp"
#include "pytorch_cpp_helper.hpp"
template <typename scalar_t>
Tensor nms_quadri_cpu_kernel(const Tensor dets, const Tensor scores,
const float iou_threshold) {
// nms_quadri_cpu_kernel is modified from torchvision's nms_cpu_kernel,
// however, the code in this function is much shorter because
// we delegate the IoU computation for quadri boxes to
// the single_box_iou_quadri function in box_iou_rotated_utils.h
AT_ASSERTM(!dets.is_cuda(), "dets must be a CPU tensor");
AT_ASSERTM(!scores.is_cuda(), "scores must be a CPU tensor");
AT_ASSERTM(dets.scalar_type() == scores.scalar_type(),
"dets should have the same type as scores");
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto ndets = dets.size(0);
Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
auto suppressed = suppressed_t.data_ptr<uint8_t>();
auto keep = keep_t.data_ptr<int64_t>();
auto order = order_t.data_ptr<int64_t>();
int64_t num_to_keep = 0;
for (int64_t _i = 0; _i < ndets; _i++) {
auto i = order[_i];
if (suppressed[i] == 1) {
continue;
}
keep[num_to_keep++] = i;
for (int64_t _j = _i + 1; _j < ndets; _j++) {
auto j = order[_j];
if (suppressed[j] == 1) {
continue;
}
auto ovr = single_box_iou_quadri<scalar_t>(
dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>(), 0);
if (ovr >= iou_threshold) {
suppressed[j] = 1;
}
}
}
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}
Tensor nms_quadri_cpu(const Tensor dets, const Tensor scores,
const float iou_threshold) {
auto result = at::empty({0}, dets.options());
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_quadri", [&] {
result = nms_quadri_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
});
return result;
}
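
The same greedy loop, written against the Python-level box_iou_quadri op instead of single_box_iou_quadri; a reference sketch only, not code from this PR. Note that the CPU kernel suppresses at ovr >= iou_threshold while the CUDA kernel above uses a strict >:

    import torch
    from mmcv.ops import box_iou_quadri

    def nms_quadri_reference(dets, scores, iou_threshold):
        # dets: (N, 8) quadrilaterals, scores: (N,)
        order = scores.sort(descending=True).indices
        suppressed = torch.zeros(len(dets), dtype=torch.bool, device=dets.device)
        keep = []
        for i in order.tolist():
            if suppressed[i]:
                continue
            keep.append(i)
            ious = box_iou_quadri(dets[i:i + 1], dets)[0]
            suppressed |= ious >= iou_threshold  # also marks box i itself, which is already kept
        return torch.tensor(keep, dtype=torch.long)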

@@ -0,0 +1,23 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "box_iou_quadri_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
void box_iou_quadri_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
using scalar_t = float;
AT_ASSERTM(boxes1.is_cuda(), "boxes1 must be a CUDA tensor");
AT_ASSERTM(boxes2.is_cuda(), "boxes2 must be a CUDA tensor");
int output_size = ious.numel();
int num_boxes1 = boxes1.size(0);
int num_boxes2 = boxes2.size(0);
at::cuda::CUDAGuard device_guard(boxes1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
box_iou_quadri_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
mode_flag, aligned);
AT_CUDA_CHECK(cudaGetLastError());
}

@@ -125,6 +125,13 @@ void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
REGISTER_DEVICE_IMPL(box_iou_rotated_impl, CUDA, box_iou_rotated_cuda);
void box_iou_quadri_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
void box_iou_quadri_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
REGISTER_DEVICE_IMPL(box_iou_quadri_impl, CUDA, box_iou_quadri_cuda);
void CARAFEForwardCUDAKernelLauncher(const Tensor features, const Tensor masks,
Tensor rfeatures, Tensor routput,
Tensor rmasks, Tensor output,

@@ -0,0 +1,60 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "nms_quadri_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
Tensor nms_quadri_cuda(const Tensor dets, const Tensor scores,
const Tensor order_t, const Tensor dets_sorted,
float iou_threshold, const int multi_label) {
// using scalar_t = float;
AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());
int dets_num = dets.size(0);
const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);
Tensor mask =
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
dets_sorted.scalar_type(), "nms_quadri_kernel_cuda", [&] {
nms_quadri_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num, iou_threshold, dets_sorted.data_ptr<scalar_t>(),
(unsigned long long*)mask.data_ptr<int64_t>(), multi_label);
});
Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}

@@ -0,0 +1,30 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "pytorch_cpp_helper.hpp"
Tensor nms_quadri_cpu(const Tensor dets, const Tensor scores,
const float iou_threshold);
#ifdef MMCV_WITH_CUDA
Tensor nms_quadri_cuda(const Tensor dets, const Tensor scores,
const Tensor order, const Tensor dets_sorted,
const float iou_threshold, const int multi_label);
#endif
// Interface for Python
Tensor nms_quadri(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold,
const int multi_label) {
assert(dets.device().is_cuda() == scores.device().is_cuda());
if (dets.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
return nms_quadri_cuda(dets, scores, order, dets_sorted, iou_threshold,
multi_label);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return nms_quadri_cpu(dets, scores, iou_threshold);
}

@@ -423,6 +423,13 @@ void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);
void box_iou_quadri(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
Tensor nms_quadri(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold,
const int multi_label);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@@ -853,4 +860,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("input"), py::arg("rois"), py::arg("grad_rois"),
py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("spatial_scale"));
m.def("box_iou_quadri", &box_iou_quadri, "IoU for quadrilateral boxes",
py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"),
py::arg("mode_flag"), py::arg("aligned"));
m.def("nms_quadri", &nms_quadri, "NMS for quadrilateral boxes",
py::arg("dets"), py::arg("scores"), py::arg("order"),
py::arg("dets_sorted"), py::arg("iou_threshold"),
py::arg("multi_label"));
}

@@ -9,7 +9,7 @@ from mmcv.utils import deprecated_api_warning
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
-    '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated'])
+    '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated', 'nms_quadri'])
# This function is modified from: https://github.com/pytorch/vision/
@@ -475,3 +475,45 @@ def nms_rotated(dets: Tensor,
dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
dim=1)
return dets, keep_inds
def nms_quadri(dets: Tensor,
scores: Tensor,
iou_threshold: float,
labels: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Performs non-maximum suppression (NMS) on the quadrilateral boxes
according to their intersection-over-union (IoU).
Quadri NMS iteratively removes lower scoring quadrilateral boxes
which have an IoU greater than iou_threshold with another (higher
scoring) quadrilateral box.
Args:
dets (torch.Tensor): Quadri boxes in shape (N, 8).
They are expected to be in
(x1, y1, ..., x4, y4) format.
        scores (torch.Tensor): scores in shape (N, ).
        iou_threshold (float): IoU threshold for NMS.
        labels (torch.Tensor, optional): boxes' label in shape (N,).

    Returns:
        tuple: kept dets (boxes and scores) and indices, which always have
            the same data type as the input.
"""
if dets.shape[0] == 0:
return dets, None
multi_label = labels is not None
    if multi_label:
        dets_with_labels = \
            torch.cat((dets, labels.unsqueeze(1)), 1)  # type: ignore
    else:
        dets_with_labels = dets
    _, order = scores.sort(0, descending=True)
    dets_sorted = dets_with_labels.index_select(0, order)
    keep_inds = ext_module.nms_quadri(dets_with_labels, scores, order,
                                      dets_sorted, iou_threshold, multi_label)
dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
dim=1)
return dets, keep_inds
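
Typical call patterns for the new wrapper, reusing two of the boxes from the unit tests below:

    import torch
    from mmcv.ops import nms_quadri

    boxes = torch.tensor([[1., 1., 3., 4., 4., 4., 4., 1.],
                          [2., 2., 3., 4., 4., 2., 3., 1.]])
    scores = torch.tensor([0.7, 0.8])
    labels = torch.tensor([0., 1.])

    dets, keep = nms_quadri(boxes, scores, iou_threshold=0.3)   # class-agnostic
    dets, keep = nms_quadri(boxes, scores, 0.3, labels)         # per-class suppression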

@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE
class TestBoxIoUQuadri:
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_box_iou_quadri_cuda(self, device):
from mmcv.ops import box_iou_quadri
np_boxes1 = np.asarray([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0]],
dtype=np.float32)
np_boxes2 = np.asarray([[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0],
[2.0, 1.0, 2.0, 4.0, 4.0, 4.0, 4.0, 1.0],
[7.0, 6.0, 7.0, 8.0, 9.0, 8.0, 9.0, 6.0]],
dtype=np.float32)
np_expect_ious = np.asarray(
[[0.0714, 1.0000, 0.0000], [0.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000]],
dtype=np.float32)
np_expect_ious_aligned = np.asarray([0.0714, 0.5000, 0.5000],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)
ious = box_iou_quadri(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_quadri(boxes1, boxes2, aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_box_iou_quadri_iof_cuda(self, device):
from mmcv.ops import box_iou_quadri
np_boxes1 = np.asarray([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0]],
dtype=np.float32)
np_boxes2 = np.asarray([[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0],
[2.0, 1.0, 2.0, 4.0, 4.0, 4.0, 4.0, 1.0],
[7.0, 6.0, 7.0, 8.0, 9.0, 8.0, 9.0, 6.0]],
dtype=np.float32)
np_expect_ious = np.asarray(
[[0.1111, 1.0000, 0.0000], [0.0000, 1.0000, 0.0000],
[0.0000, 0.0000, 1.0000]],
dtype=np.float32)
np_expect_ious_aligned = np.asarray([0.1111, 1.0000, 1.0000],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)
ious = box_iou_quadri(boxes1, boxes2, mode='iof')
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_quadri(boxes1, boxes2, mode='iof', aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)

@@ -0,0 +1,119 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE
class TestNMSQuadri:
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_ml_nms_quadri(self, device):
from mmcv.ops import nms_quadri
np_boxes = np.array([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 0.7],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0, 0.8],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0, 0.5],
[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.9]],
dtype=np.float32)
np_labels = np.array([1, 0, 1, 0], dtype=np.float32)
np_expect_dets = np.array([[0., 0., 0., 2., 2., 2., 2., 0.],
[2., 2., 3., 4., 4., 2., 3., 1.],
[7., 7., 8., 8., 9., 7., 8., 6.]],
dtype=np.float32)
np_expect_keep_inds = np.array([3, 1, 2], dtype=np.int64)
boxes = torch.from_numpy(np_boxes).to(device)
labels = torch.from_numpy(np_labels).to(device)
dets, keep_inds = nms_quadri(boxes[:, :8], boxes[:, -1], 0.3, labels)
assert np.allclose(dets.cpu().numpy()[:, :8], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_nms_quadri(self, device):
from mmcv.ops import nms_quadri
np_boxes = np.array([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 0.7],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0, 0.8],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0, 0.5],
[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.9]],
dtype=np.float32)
np_expect_dets = np.array([[0., 0., 0., 2., 2., 2., 2., 0.],
[2., 2., 3., 4., 4., 2., 3., 1.],
[7., 7., 8., 8., 9., 7., 8., 6.]],
dtype=np.float32)
np_expect_keep_inds = np.array([3, 1, 2], dtype=np.int64)
boxes = torch.from_numpy(np_boxes).to(device)
dets, keep_inds = nms_quadri(boxes[:, :8], boxes[:, -1], 0.3)
assert np.allclose(dets.cpu().numpy()[:, :8], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_batched_nms(self, device):
# test batched_nms with nms_quadri
from mmcv.ops import batched_nms
np_boxes = np.array([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 0.7],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0, 0.8],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0, 0.5],
[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.9]],
dtype=np.float32)
np_labels = np.array([1, 0, 1, 0], dtype=np.float32)
np_expect_agnostic_dets = np.array([[0., 0., 0., 2., 2., 2., 2., 0.],
[2., 2., 3., 4., 4., 2., 3., 1.],
[7., 7., 8., 8., 9., 7., 8., 6.]],
dtype=np.float32)
np_expect_agnostic_keep_inds = np.array([3, 1, 2], dtype=np.int64)
np_expect_dets = np.array([[0., 0., 0., 2., 2., 2., 2., 0.],
[2., 2., 3., 4., 4., 2., 3., 1.],
[1., 1., 3., 4., 4., 4., 4., 1.],
[7., 7., 8., 8., 9., 7., 8., 6.]],
dtype=np.float32)
np_expect_keep_inds = np.array([3, 1, 0, 2], dtype=np.int64)
nms_cfg = dict(type='nms_quadri', iou_threshold=0.3)
# test class_agnostic is True
boxes, keep = batched_nms(
torch.from_numpy(np_boxes[:, :8]).to(device),
torch.from_numpy(np_boxes[:, -1]).to(device),
torch.from_numpy(np_labels).to(device),
nms_cfg,
class_agnostic=True)
assert np.allclose(boxes.cpu().numpy()[:, :8], np_expect_agnostic_dets)
assert np.allclose(keep.cpu().numpy(), np_expect_agnostic_keep_inds)
# test class_agnostic is False
boxes, keep = batched_nms(
torch.from_numpy(np_boxes[:, :8]).to(device),
torch.from_numpy(np_boxes[:, -1]).to(device),
torch.from_numpy(np_labels).to(device),
nms_cfg,
class_agnostic=False)
assert np.allclose(boxes.cpu().numpy()[:, :8], np_expect_dets)
assert np.allclose(keep.cpu().numpy(), np_expect_keep_inds)