mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add RiRoIAlignRotated CUDA op for rotated detection. (#1599)
parent 2475dc3452
commit 0bcbeadb53
@@ -21,6 +21,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
 - MaskedConv
 - NMS
 - PSAMask
+- RiRoIAlignRotated
 - RotatedFeatureAlign
 - RoIPointPool3d
 - RoIPool
@@ -23,6 +23,7 @@ MMCV provides CUDA ops commonly used in tasks such as detection and segmentation
 - RotatedFeatureAlign
 - RoIPointPool3d
 - RoIPool
+- RiRoIAlignRotated
 - RoIAlign
 - RoIAwarePool3d
 - SimpleRoIAlign
@@ -40,6 +40,7 @@ from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
                               points_in_boxes_part)
 from .points_sampler import PointsSampler
 from .psa_mask import PSAMask
+from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
 from .roi_align import RoIAlign, roi_align
 from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
 from .roi_pool import RoIPool, roi_pool
@@ -71,11 +72,11 @@ __all__ = [
     'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
     'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
     'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
-    'rotated_feature_align', 'RoIAlignRotated', 'roi_align_rotated',
-    'pixel_group', 'QueryAndGroup', 'GroupAll', 'grouping_operation',
-    'contour_expand', 'three_nn', 'three_interpolate',
-    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
-    'gather_points', 'furthest_point_sample',
+    'rotated_feature_align', 'RiRoIAlignRotated', 'riroi_align_rotated',
+    'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
+    'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
+    'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
+    'border_align', 'gather_points', 'furthest_point_sample',
     'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
     'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
     'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
@@ -0,0 +1,242 @@
// Modified from
// https://github.com/csuhan/ReDet/blob/master/mmdet/ops/riroi_align/src/riroi_align_kernel.cu
#ifndef RIROI_ALIGN_ROTATED_CUDA_KERNEL_CUH
#define RIROI_ALIGN_ROTATED_CUDA_KERNEL_CUH

#include <float.h>
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else  // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
#endif  // MMCV_USE_PARROTS

/*** Forward ***/
template <typename scalar_t>
__global__ void riroi_align_rotated_forward_cuda_kernel(
    const int nthreads, const scalar_t *bottom_data,
    const scalar_t *bottom_rois, const scalar_t spatial_scale,
    const int num_samples, const bool clockwise, const int channels,
    const int height, const int width, const int pooled_height,
    const int pooled_width, const int num_orientations, scalar_t *top_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int o = (index / pooled_width / pooled_height) % num_orientations;
    int c =
        (index / pooled_width / pooled_height / num_orientations) % channels;
    int n = index / pooled_width / pooled_height / num_orientations / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not use rounding; this implementation detail is critical
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    // Force malformed ROIs to be 1x1
    roi_width = max(roi_width, (scalar_t)1.);
    roi_height = max(roi_height, (scalar_t)1.);
    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    // find aligned index
    scalar_t ind_float = theta * num_orientations / (2 * M_PI);
    int ind = floor(ind_float);
    scalar_t l_var = ind_float - (scalar_t)ind;
    scalar_t r_var = 1.0 - l_var;
    // correct start channel
    ind = (ind + num_orientations) % num_orientations;
    // rotated channel
    int ind_rot = (o - ind + num_orientations) % num_orientations;
    int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations;
    const scalar_t *offset_bottom_data =
        bottom_data + (roi_batch_ind * channels * num_orientations +
                       c * num_orientations + ind_rot) *
                          height * width;

    const scalar_t *offset_bottom_data_plus =
        bottom_data + (roi_batch_ind * channels * num_orientations +
                       c * num_orientations + ind_rot_plus) *
                          height * width;
    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (num_samples > 0)
                             ? num_samples
                             : ceilf(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of the RoI (x, y).
    // Appropriate translation needs to be applied after.
    if (clockwise) {
      theta = -theta;  // If clockwise, the angle needs to be reversed.
    }
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosscalar_theta = cos(theta);
    scalar_t sinscalar_theta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4

    scalar_t output_val = 0.;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta (counterclockwise) around the center and translate
        scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
        scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;

        scalar_t val = bilinear_interpolate<scalar_t>(
            offset_bottom_data, height, width, y, x, index);
        scalar_t val_plus = bilinear_interpolate<scalar_t>(
            offset_bottom_data_plus, height, width, y, x, index);
        output_val += r_var * val + l_var * val_plus;
      }
    }
    output_val /= count;

    top_data[index] = output_val;
  }
}

/*** Backward ***/
template <typename scalar_t>
__global__ void riroi_align_rotated_backward_cuda_kernel(
    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
    const scalar_t spatial_scale, const int num_samples, const bool clockwise,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, const int num_orientations,
    scalar_t *bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int o = (index / pooled_width / pooled_height) % num_orientations;
    int c =
        (index / pooled_width / pooled_height / num_orientations) % channels;
    int n = index / pooled_width / pooled_height / num_orientations / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not round
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    // Force malformed ROIs to be 1x1
    roi_width = max(roi_width, (scalar_t)1.);
    roi_height = max(roi_height, (scalar_t)1.);

    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    // find aligned index
    scalar_t ind_float = theta * num_orientations / (2 * M_PI);
    int ind = floor(ind_float);
    scalar_t l_var = ind_float - (scalar_t)ind;
    scalar_t r_var = 1.0 - l_var;
    // correct start channel
    ind = (ind + num_orientations) % num_orientations;
    // rotated channel
    int ind_rot = (o - ind + num_orientations) % num_orientations;
    int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations;
    scalar_t *offset_bottom_diff =
        bottom_diff + (roi_batch_ind * channels * num_orientations +
                       c * num_orientations + ind_rot) *
                          height * width;
    scalar_t *offset_bottom_diff_plus =
        bottom_diff + (roi_batch_ind * channels * num_orientations +
                       c * num_orientations + ind_rot_plus) *
                          height * width;
    int top_offset =
        (n * channels * num_orientations + c * num_orientations + o) *
        pooled_height * pooled_width;
    const scalar_t *offset_top_diff = top_diff + top_offset;
    const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (num_samples > 0)
                             ? num_samples
                             : ceilf(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of the RoI (x, y).
    // Appropriate translation needs to be applied after.
    if (clockwise) {
      theta = -theta;  // If clockwise, the angle needs to be reversed.
    }
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosTheta = cos(theta);
    scalar_t sinTheta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta around the center and translate
        scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
        scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;

        scalar_t w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2, w3,
                                                w4, x_low, x_high, y_low,
                                                y_high, index);

        scalar_t g1 = top_diff_this_bin * w1 / count;
        scalar_t g2 = top_diff_this_bin * w2 / count;
        scalar_t g3 = top_diff_this_bin * w3 / count;
        scalar_t g4 = top_diff_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomicAdd(offset_bottom_diff + y_low * width + x_low, g1 * r_var);
          atomicAdd(offset_bottom_diff + y_low * width + x_high, g2 * r_var);
          atomicAdd(offset_bottom_diff + y_high * width + x_low, g3 * r_var);
          atomicAdd(offset_bottom_diff + y_high * width + x_high, g4 * r_var);

          atomicAdd(offset_bottom_diff_plus + y_low * width + x_low,
                    g1 * l_var);
          atomicAdd(offset_bottom_diff_plus + y_low * width + x_high,
                    g2 * l_var);
          atomicAdd(offset_bottom_diff_plus + y_high * width + x_low,
                    g3 * l_var);
          atomicAdd(offset_bottom_diff_plus + y_high * width + x_high,
                    g4 * l_var);

        }  // if
      }    // ix
    }      // iy
  }        // CUDA_1D_KERNEL_LOOP
}  // RiRoIAlignBackward

#endif  // RIROI_ALIGN_ROTATED_CUDA_KERNEL_CUH
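The rotation-invariant part of this kernel is the orientation-channel interpolation: `ind_float = theta * num_orientations / (2 * M_PI)` gives a fractional orientation offset, and each output orientation `o` reads from the two input orientation channels `ind_rot` and `ind_rot_plus`, blended with weights `r_var` and `l_var`. The Python sketch below illustrates only that indexing; it is not part of the patch, and the helper name `orientation_blend_indices` and the example values are made up.

import numpy as np

def orientation_blend_indices(theta, o, num_orientations=8):
    """Mirror the kernel's orientation indexing for one output channel.

    Returns the two input orientation channels that output orientation `o`
    samples from, and their blend weights (r_var for ind_rot, l_var for
    ind_rot_plus).
    """
    ind_float = theta * num_orientations / (2 * np.pi)
    ind = int(np.floor(ind_float))
    l_var = ind_float - ind  # weight of the next orientation channel
    r_var = 1.0 - l_var      # weight of the base orientation channel
    ind = (ind + num_orientations) % num_orientations
    ind_rot = (o - ind + num_orientations) % num_orientations
    ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations
    return ind_rot, ind_rot_plus, r_var, l_var

# Example: an RoI rotated by pi/3 with 8 orientation channels; output
# orientation 0 blends input orientation channels 7 and 0.
print(orientation_blend_indices(np.pi / 3, o=0))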
@@ -992,6 +992,81 @@ REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, CUDA,
 REGISTER_DEVICE_IMPL(roi_align_rotated_backward_impl, CUDA,
                      roi_align_rotated_backward_cuda);

+void RiROIAlignRotatedForwardCUDAKernelLauncher(
+    const at::Tensor features, const at::Tensor rois, const float spatial_scale,
+    const int num_samples, const bool clockwise, const int channels,
+    const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, const int num_orientations,
+    at::Tensor output);
+
+void RiROIAlignRotatedBackwardCUDAKernelLauncher(
+    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
+    const int num_samples, const bool clockwise, const int channels,
+    const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, const int num_orientations,
+    at::Tensor bottom_grad);
+
+void riroi_align_rotated_forward_cuda(Tensor features, Tensor rois,
+                                      Tensor output, int pooled_height,
+                                      int pooled_width, float spatial_scale,
+                                      int num_samples, int num_orientations,
+                                      bool clockwise) {
+  // Number of ROIs
+  int num_rois = rois.size(0);
+  int size_rois = rois.size(1);
+  if (size_rois != 6) {
+    AT_ERROR("wrong roi size");
+  }
+  CHECK_CONTIGUOUS(features);
+  CHECK_CONTIGUOUS(rois);
+  int num_channels = features.size(1) / num_orientations;
+  int data_height = features.size(2);
+  int data_width = features.size(3);
+  RiROIAlignRotatedForwardCUDAKernelLauncher(
+      features, rois, spatial_scale, num_samples, clockwise, num_channels,
+      data_height, data_width, num_rois, pooled_height, pooled_width,
+      num_orientations, output);
+}
+
+void riroi_align_rotated_backward_cuda(Tensor top_grad, Tensor rois,
+                                       Tensor bottom_grad, int pooled_height,
+                                       int pooled_width, float spatial_scale,
+                                       int num_samples, int num_orientations,
+                                       bool clockwise) {
+  // Number of ROIs
+  int num_rois = rois.size(0);
+  int size_rois = rois.size(1);
+  if (size_rois != 6) {
+    AT_ERROR("wrong roi size");
+  }
+  CHECK_CONTIGUOUS(top_grad);
+  CHECK_CONTIGUOUS(rois);
+  int num_channels = bottom_grad.size(1) / num_orientations;
+  int data_height = bottom_grad.size(2);
+  int data_width = bottom_grad.size(3);
+  RiROIAlignRotatedBackwardCUDAKernelLauncher(
+      top_grad, rois, spatial_scale, num_samples, clockwise, num_channels,
+      data_height, data_width, num_rois, pooled_height, pooled_width,
+      num_orientations, bottom_grad);
+}
+
+void riroi_align_rotated_forward_impl(Tensor features, Tensor rois,
+                                      Tensor output, int pooled_height,
+                                      int pooled_width, float spatial_scale,
+                                      int num_samples, int num_orientations,
+                                      bool clockwise);
+
+void riroi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
+                                       Tensor bottom_grad, int pooled_height,
+                                       int pooled_width, float spatial_scale,
+                                       int num_samples, int num_orientations,
+                                       bool clockwise);
+
+REGISTER_DEVICE_IMPL(riroi_align_rotated_forward_impl, CUDA,
+                     riroi_align_rotated_forward_cuda);
+REGISTER_DEVICE_IMPL(riroi_align_rotated_backward_impl, CUDA,
+                     riroi_align_rotated_backward_cuda);
+
 void RoiawarePool3dForwardCUDAKernelLauncher(
     int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
     int out_y, int out_z, const Tensor rois, const Tensor pts,
@@ -0,0 +1,53 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cuda_helper.hpp"
#include "riroi_align_rotated_cuda_kernel.cuh"

void RiROIAlignRotatedForwardCUDAKernelLauncher(
    const at::Tensor features, const at::Tensor rois, const float spatial_scale,
    const int num_samples, const bool clockwise, const int channels,
    const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const int num_orientations,
    at::Tensor output) {
  const int output_size =
      num_rois * pooled_height * pooled_width * channels * num_orientations;
  at::cuda::CUDAGuard device_guard(features.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features.scalar_type(), "riroi_align_rotated_forward_cuda_kernel", ([&] {
        const scalar_t *bottom_data = features.data_ptr<scalar_t>();
        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
        scalar_t *top_data = output.data_ptr<scalar_t>();

        riroi_align_rotated_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, bottom_data, rois_data, scalar_t(spatial_scale),
                num_samples, clockwise, channels, height, width, pooled_height,
                pooled_width, num_orientations, top_data);
      }));

  AT_CUDA_CHECK(cudaGetLastError());
}

void RiROIAlignRotatedBackwardCUDAKernelLauncher(
    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
    const int num_samples, const bool clockwise, const int channels,
    const int height, const int width, const int num_rois,
    const int pooled_height, const int pooled_width, const int num_orientations,
    at::Tensor bottom_grad) {
  const int output_size =
      num_rois * pooled_height * pooled_width * channels * num_orientations;
  at::cuda::CUDAGuard device_guard(top_grad.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "riroi_align_rotated_backward_cuda_kernel", ([&] {
        const scalar_t *top_diff = top_grad.data_ptr<scalar_t>();
        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
        scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
        riroi_align_rotated_backward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, top_diff, rois_data, spatial_scale, num_samples,
                clockwise, channels, height, width, pooled_height, pooled_width,
                num_orientations, bottom_diff);
      }));
  AT_CUDA_CHECK(cudaGetLastError());
}
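Both launchers start one CUDA thread per output element, so output_size = num_rois * pooled_height * pooled_width * channels * num_orientations, where `channels` is the number of channels per orientation (features.size(1) / num_orientations in the binding). For readers checking the memory layout, here is a small Python sketch of how each kernel thread decodes its flat `index` back into (n, c, o, ph, pw); the helper name and the sizes are made up for illustration:

def decode_index(index, pooled_width, pooled_height, num_orientations, channels):
    # Same arithmetic as the CUDA kernels: pw varies fastest, n slowest.
    pw = index % pooled_width
    ph = (index // pooled_width) % pooled_height
    o = (index // pooled_width // pooled_height) % num_orientations
    c = (index // pooled_width // pooled_height // num_orientations) % channels
    n = index // pooled_width // pooled_height // num_orientations // channels
    return n, c, o, ph, pw

# Hypothetical sizes: 2 RoIs, 3 channels per orientation, 8 orientations, 2x2 bins.
num_rois, channels, num_orientations, pooled_h, pooled_w = 2, 3, 8, 2, 2
output_size = num_rois * pooled_h * pooled_w * channels * num_orientations
assert decode_index(output_size - 1, pooled_w, pooled_h, num_orientations,
                    channels) == (num_rois - 1, channels - 1,
                                  num_orientations - 1, pooled_h - 1,
                                  pooled_w - 1)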
@@ -350,6 +350,17 @@ void rotated_feature_align_backward(const Tensor top_grad,
                                     const float spatial_scale,
                                     const int points);

+void riroi_align_rotated_forward(Tensor features, Tensor rois, Tensor output,
+                                 int pooled_height, int pooled_width,
+                                 float spatial_scale, int num_samples,
+                                 int num_orientations, bool clockwise);
+
+void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
+                                  Tensor bottom_grad, int pooled_height,
+                                  int pooled_width, float spatial_scale,
+                                  int num_samples, int num_orientations,
+                                  bool clockwise);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
         py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@@ -704,4 +715,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Feature Refine backward (CUDA)", py::arg("top_grad"),
         py::arg("best_bboxes"), py::arg("bottom_grad"),
         py::arg("spatial_scale"), py::arg("points"));
+  m.def("riroi_align_rotated_forward", &riroi_align_rotated_forward,
+        "riroi_align_rotated forward", py::arg("features"), py::arg("rois"),
+        py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),
+        py::arg("spatial_scale"), py::arg("num_samples"),
+        py::arg("num_orientations"), py::arg("clockwise"));
+  m.def("riroi_align_rotated_backward", &riroi_align_rotated_backward,
+        "riroi_align_rotated backward", py::arg("top_grad"), py::arg("rois"),
+        py::arg("bottom_grad"), py::arg("pooled_height"),
+        py::arg("pooled_width"), py::arg("spatial_scale"),
+        py::arg("num_samples"), py::arg("num_orientations"),
+        py::arg("clockwise"));
 }
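On the Python side these bindings are reached through `ext_loader.load_ext('_ext', ...)`, as the new `riroi_align_rotated.py` below does. A minimal sketch of calling the bound forward op directly with its keyword arguments, assuming a CUDA build of the extension; the tensor sizes are arbitrary and the output buffer must be pre-allocated to (num_rois, C, pooled_h, pooled_w):

import torch
from mmcv.utils import ext_loader

# The names must match the m.def(...) entries registered above.
ext_module = ext_loader.load_ext(
    '_ext', ['riroi_align_rotated_forward', 'riroi_align_rotated_backward'])

features = torch.rand(1, 8, 2, 2, device='cuda')  # C = 1 channel x 8 orientations
rois = torch.tensor([[0., 0.5, 0.5, 1., 1., 0.3]], device='cuda')
output = features.new_zeros(1, 8, 2, 2)
ext_module.riroi_align_rotated_forward(
    features, rois, output, pooled_height=2, pooled_width=2,
    spatial_scale=1.0, num_samples=2, num_orientations=8, clockwise=False)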
@@ -0,0 +1,42 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

void riroi_align_rotated_forward_impl(Tensor features, Tensor rois,
                                      Tensor output, int pooled_height,
                                      int pooled_width, float spatial_scale,
                                      int num_samples, int num_orientations,
                                      bool clockwise) {
  DISPATCH_DEVICE_IMPL(riroi_align_rotated_forward_impl, features, rois, output,
                       pooled_height, pooled_width, spatial_scale, num_samples,
                       num_orientations, clockwise);
}

void riroi_align_rotated_backward_impl(Tensor top_grad, Tensor rois,
                                       Tensor bottom_grad, int pooled_height,
                                       int pooled_width, float spatial_scale,
                                       int num_samples, int num_orientations,
                                       bool clockwise) {
  DISPATCH_DEVICE_IMPL(riroi_align_rotated_backward_impl, top_grad, rois,
                       bottom_grad, pooled_height, pooled_width, spatial_scale,
                       num_samples, num_orientations, clockwise);
}

void riroi_align_rotated_forward(Tensor features, Tensor rois, Tensor output,
                                 int pooled_height, int pooled_width,
                                 float spatial_scale, int num_samples,
                                 int num_orientations, bool clockwise) {
  riroi_align_rotated_forward_impl(features, rois, output, pooled_height,
                                   pooled_width, spatial_scale, num_samples,
                                   num_orientations, clockwise);
}

void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
                                  Tensor bottom_grad, int pooled_height,
                                  int pooled_width, float spatial_scale,
                                  int num_samples, int num_orientations,
                                  bool clockwise) {
  riroi_align_rotated_backward_impl(top_grad, rois, bottom_grad, pooled_height,
                                    pooled_width, spatial_scale, num_samples,
                                    num_orientations, clockwise);
}
@@ -0,0 +1,119 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from torch.autograd import Function

from ..utils import ext_loader, is_tuple_of

ext_module = ext_loader.load_ext(
    '_ext', ['riroi_align_rotated_forward', 'riroi_align_rotated_backward'])


class RiRoIAlignRotatedFunction(Function):

    @staticmethod
    def forward(ctx,
                features,
                rois,
                out_size,
                spatial_scale,
                num_samples=0,
                num_orientations=8,
                clockwise=False):
        if isinstance(out_size, int):
            out_h = out_size
            out_w = out_size
        elif is_tuple_of(out_size, int):
            assert len(out_size) == 2
            out_h, out_w = out_size
        else:
            raise TypeError(
                f'"out_size" should be an integer or tuple of integers,'
                f' but got {out_size}')
        ctx.spatial_scale = spatial_scale
        ctx.num_samples = num_samples
        ctx.num_orientations = num_orientations
        ctx.clockwise = clockwise
        ctx.save_for_backward(rois)
        ctx.feature_size = features.size()

        batch_size, num_channels, _, _ = features.size()
        num_rois = rois.size(0)

        output = features.new_zeros(num_rois, num_channels, out_h, out_w)

        ext_module.riroi_align_rotated_forward(features, rois, output, out_h,
                                               out_w, spatial_scale,
                                               num_samples, num_orientations,
                                               clockwise)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        feature_size = ctx.feature_size
        spatial_scale = ctx.spatial_scale
        num_orientations = ctx.num_orientations
        clockwise = ctx.clockwise
        num_samples = ctx.num_samples
        rois = ctx.saved_tensors[0]
        assert feature_size is not None
        batch_size, num_channels, feature_h, feature_w = feature_size

        out_w = grad_output.size(3)
        out_h = grad_output.size(2)

        grad_input = grad_rois = None

        if ctx.needs_input_grad[0]:
            grad_input = rois.new_zeros(batch_size, num_channels, feature_h,
                                        feature_w)
            ext_module.riroi_align_rotated_backward(
                grad_output.contiguous(), rois, grad_input, out_h, out_w,
                spatial_scale, num_samples, num_orientations, clockwise)

        return grad_input, grad_rois, None, None, None, None, None


riroi_align_rotated = RiRoIAlignRotatedFunction.apply


class RiRoIAlignRotated(nn.Module):
    """Rotation-invariant RoI align pooling layer for rotated proposals.

    It accepts a feature map of shape (N, C, H, W) and rois with shape
    (n, 6) with each roi decoded as (batch_index, center_x, center_y,
    w, h, angle). The angle is in radians.

    The details are described in the paper `ReDet: A Rotation-equivariant
    Detector for Aerial Object Detection <https://arxiv.org/abs/2103.07733>`_.

    Args:
        out_size (tuple): fixed dimensional RoI output with shape (h, w).
        spatial_scale (float): scale the input boxes by this number.
        num_samples (int): number of input samples to take for each
            output sample. 0 to take samples densely.
        num_orientations (int): number of oriented channels.
        clockwise (bool): If True, the angle in each proposal follows a
            clockwise fashion in image space, otherwise, the angle is
            counterclockwise. Default: False.
    """

    def __init__(self,
                 out_size,
                 spatial_scale,
                 num_samples=0,
                 num_orientations=8,
                 clockwise=False):
        super(RiRoIAlignRotated, self).__init__()

        self.out_size = out_size
        self.spatial_scale = float(spatial_scale)
        self.num_samples = int(num_samples)
        self.num_orientations = int(num_orientations)
        self.clockwise = clockwise

    def forward(self, features, rois):
        return RiRoIAlignRotatedFunction.apply(features, rois, self.out_size,
                                               self.spatial_scale,
                                               self.num_samples,
                                               self.num_orientations,
                                               self.clockwise)
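For reference, a minimal usage sketch of the new module (it needs a CUDA-enabled build; the values loosely follow the unit test below, and the feature map's channel dimension is channels_per_orientation * num_orientations):

import numpy as np
import torch

from mmcv.ops import RiRoIAlignRotated

# 1 image, 8 feature channels (1 per orientation), a 2x2 map, and two rotated
# RoIs given as (batch_index, center_x, center_y, w, h, angle in radians).
feats = torch.rand(1, 8, 2, 2, device='cuda', requires_grad=True)
rois = torch.tensor([[0., 0.5, 0.5, 1., 1., np.pi / 3],
                     [0., 1., 1., 3., 3., np.pi / 2]], device='cuda')

layer = RiRoIAlignRotated(
    out_size=(2, 2), spatial_scale=1.0, num_samples=2,
    num_orientations=8, clockwise=False)
out = layer(feats, rois)   # shape: (num_rois, 8, 2, 2)
out.sum().backward()       # gradients flow back into `feats`
print(out.shape, feats.grad.shape)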
@@ -0,0 +1,73 @@
import numpy as np
import pytest
import torch
from torch.autograd import gradcheck

from mmcv.ops import RiRoIAlignRotated

np_feature = np.array([[[[1, 2], [3, 4]], [[1, 2], [4, 3]], [[4, 3], [2, 1]],
                        [[1, 2], [5, 6]], [[3, 4], [7, 8]],
                        [[9, 10], [13, 14]], [[11, 12], [15, 16]],
                        [[1, 1], [2, 2]]]])
np_rois = np.array([[0., 0.5, 0.5, 1., 1., np.pi / 3],
                    [0., 1., 1., 3., 3., np.pi / 2]])
expect_output = np.array([[[[1.8425, 1.3516], [2.3151, 1.8241]],
                           [[2.4779, 1.7416], [3.2173, 2.5632]],
                           [[2.7149, 2.2638], [2.6540, 2.3673]],
                           [[2.9461, 2.8638], [2.8028, 2.7205]],
                           [[4.1943, 2.7214], [5.6119, 4.1391]],
                           [[7.5276, 6.0547], [8.9453, 7.4724]],
                           [[12.1943, 10.7214], [13.6119, 12.1391]],
                           [[9.5489, 8.4237], [10.5763, 9.4511]]],
                          [[[7.6562, 12.5625], [4.0000, 6.6250]],
                           [[1.0000, 1.3125], [0.5000, 0.6562]],
                           [[1.6562, 1.9375], [1.0000, 1.3125]],
                           [[1.8438, 2.0547], [0.7500, 1.1562]],
                           [[0.8438, 3.0625], [0.2500, 1.1875]],
                           [[2.6562, 2.5625], [1.5000, 1.6250]],
                           [[3.6562, 4.5625], [2.0000, 2.6250]],
                           [[6.6562, 10.5625], [3.5000, 5.6250]]]])

expect_grad = np.array([[[[1.4727, 1.5586], [1.5586, 1.6602]],
                         [[1.4727, 1.5586], [1.5586, 1.6602]],
                         [[1.4727, 1.5586], [1.5586, 1.6602]],
                         [[1.4727, 1.5586], [1.5586, 1.6602]],
                         [[1.4727, 1.5586], [1.5586, 1.6602]],
                         [[1.4727, 1.5586], [1.5586, 1.6602]],
                         [[1.4727, 1.5586], [1.5586, 1.6602]],
                         [[1.4727, 1.5586], [1.5586, 1.6602]]]])

pool_h = 2
pool_w = 2
spatial_scale = 1.0
num_samples = 2
sampling_ratio = 2
num_orientations = 8
clockwise = False


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_roialign_rotated_gradcheck():
    x = torch.tensor(
        np_feature, dtype=torch.float, device='cuda', requires_grad=True)
    rois = torch.tensor(np_rois, dtype=torch.float, device='cuda')
    froipool = RiRoIAlignRotated((pool_h, pool_w), spatial_scale, num_samples,
                                 num_orientations, clockwise)
    gradcheck(froipool, (x, rois), eps=1e-3, atol=1e-3)


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_roialign_rotated_allclose():
    x = torch.tensor(
        np_feature, dtype=torch.float, device='cuda', requires_grad=True)
    rois = torch.tensor(np_rois, dtype=torch.float, device='cuda')
    froipool = RiRoIAlignRotated((pool_h, pool_w), spatial_scale, num_samples,
                                 num_orientations, clockwise)
    output = froipool(x, rois)
    output.backward(torch.ones_like(output))
    assert np.allclose(
        output.data.type(torch.float).cpu().numpy(), expect_output, atol=1e-3)
    assert np.allclose(
        x.grad.data.type(torch.float).cpu().numpy(), expect_grad, atol=1e-3)