mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add feature refine CUDA op for rotated detection. (#1603)
* re PR * replace all the feature_refine with rotated_feature_align * replace all the FR with RotatedFeatureAlign * Update mmcv/ops/rotated_feature_align.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/ops/rotated_feature_align.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * Update mmcv/ops/csrc/pytorch/cuda/rotated_feature_align_cuda.cu Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/ops/csrc/pytorch/cuda/cudabind.cpp Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/ops/csrc/pytorch/cuda/rotated_feature_align_cuda.cu Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/ops/csrc/pytorch/cuda/rotated_feature_align_cuda.cu Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/ops/csrc/pytorch/cuda/rotated_feature_align_cuda.cu Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * Update cudabind.cpp * Update cudabind.cpp * fix bug in test. & add backward test * fix lint Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/1599/head
parent
9b49fcc6c1
commit
2475dc3452
|
@ -21,6 +21,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
|
|||
- MaskedConv
|
||||
- NMS
|
||||
- PSAMask
|
||||
- RotatedFeatureAlign
|
||||
- RoIPointPool3d
|
||||
- RoIPool
|
||||
- RoIAlign
|
||||
|
|
|
@ -20,6 +20,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
|
|||
- MaskedConv
|
||||
- NMS
|
||||
- PSAMask
|
||||
- RotatedFeatureAlign
|
||||
- RoIPointPool3d
|
||||
- RoIPool
|
||||
- RoIAlign
|
||||
|
|
|
@ -45,6 +45,7 @@ from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
|
|||
from .roi_pool import RoIPool, roi_pool
|
||||
from .roiaware_pool3d import RoIAwarePool3d
|
||||
from .roipoint_pool3d import RoIPointPool3d
|
||||
from .rotated_feature_align import rotated_feature_align
|
||||
from .saconv import SAConv2d
|
||||
from .scatter_points import DynamicScatter, dynamic_scatter
|
||||
from .sync_bn import SyncBatchNorm
|
||||
|
@ -70,10 +71,11 @@ __all__ = [
|
|||
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
|
||||
'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
|
||||
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
|
||||
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
|
||||
'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
|
||||
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
|
||||
'border_align', 'gather_points', 'furthest_point_sample',
|
||||
'rotated_feature_align', 'RoIAlignRotated', 'roi_align_rotated',
|
||||
'pixel_group', 'QueryAndGroup', 'GroupAll', 'grouping_operation',
|
||||
'contour_expand', 'three_nn', 'three_interpolate',
|
||||
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
|
||||
'gather_points', 'furthest_point_sample',
|
||||
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
|
||||
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
|
||||
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
// Modified from
|
||||
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
|
||||
#ifndef ROTATED_FEATURE_ALIGN_CUDA_KERNEL_CUH
|
||||
#define ROTATED_FEATURE_ALIGN_CUDA_KERNEL_CUH
|
||||
|
||||
#ifdef MMCV_USE_PARROTS
|
||||
#include "parrots_cuda_helper.hpp"
|
||||
#else
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#endif
|
||||
|
||||
// Forward pass of rotated feature alignment.
//
// For every element of the input feature map (indexed as a dense
// [n][c][h][w] array) the kernel reads the 5-value box stored for the same
// (n, h, w) location in `best_bboxes` (indexed as a dense [n][h][w][5]
// array), bilinearly samples the feature plane of channel `c` at `points`
// locations derived from that box, and writes
//   top_data[index] = bottom_data[index] + sum of the sampled values.
//
// Sample 0 is the scaled box centre. When `points > 1` samples 1..4 are the
// centre offset by the rotated half-width / half-height vectors (the four
// box corners). The sample arrays hold 5 entries, so `points` must not
// exceed 5.
//
//   nthreads       number of elements to process (N * C * H * W)
//   points         number of sample locations per element
//   bottom_data    input features
//   best_bboxes    per-location boxes, read as (y, x, w, h, angle)
//   spatial_scale  factor applied to y, x, w and h (not to the angle)
//   top_data       output, same indexing as bottom_data
template <typename scalar_t>
__global__ void rotated_feature_align_forward_kernel(
    const int nthreads, const int points, const scalar_t* bottom_data,
    const scalar_t* best_bboxes, const scalar_t spatial_scale,
    const int channels, const int height, const int width, scalar_t* top_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Split the flat index into (batch, chan, row, col).
    const int col = index % width;
    const int row = (index / width) % height;
    const int chan = (index / width / height) % channels;
    const int batch = index / width / height / channels;

    // Box associated with this spatial location.
    const scalar_t* box =
        best_bboxes + ((batch * height + row) * width + col) * 5;
    const scalar_t center_y = box[0] * spatial_scale;
    const scalar_t center_x = box[1] * spatial_scale;

    // Entry 0 is always the centre; entries 1..4 stay zero unless corner
    // sampling is requested.
    scalar_t sample_x[5] = {center_x, 0, 0, 0, 0};
    scalar_t sample_y[5] = {center_y, 0, 0, 0, 0};

    if (points > 1) {
      const scalar_t box_w = box[2] * spatial_scale;
      const scalar_t box_h = box[3] * spatial_scale;
      const scalar_t angle = box[4];

      const scalar_t half_w = box_w / 2;
      const scalar_t half_h = box_h / 2;
      const scalar_t cos_a = cosf(angle);
      const scalar_t sin_a = sinf(angle);
      // Half-extent vectors of the box after rotation by `angle`.
      const scalar_t wx = cos_a * half_w;
      const scalar_t wy = sin_a * half_w;
      const scalar_t hx = -sin_a * half_h;
      const scalar_t hy = cos_a * half_h;

      sample_x[1] = center_x + wx + hx;
      sample_y[1] = center_y + wy + hy;
      sample_x[2] = center_x - wx + hx;
      sample_y[2] = center_y - wy + hy;
      sample_x[3] = center_x - wx - hx;
      sample_y[3] = center_y - wy - hy;
      sample_x[4] = center_x + wx - hx;
      sample_y[4] = center_y + wy - hy;
    }

    // Feature plane (batch, chan) that all samples are taken from.
    const scalar_t* feature_plane =
        bottom_data + (batch * channels + chan) * height * width;

    // Start from the original feature value and add every sample on top.
    scalar_t accum = bottom_data[index];
    for (int k = 0; k < points; ++k) {
      accum += bilinear_interpolate<scalar_t>(feature_plane, height, width,
                                              sample_y[k], sample_x[k], k);
    }
    top_data[index] = accum;
  }
}
|
||||
|
||||
// Backward pass of rotated feature alignment.
//
// For every element of `top_diff` (indexed as a dense [n][c][h][w] array)
// the kernel recomputes the same sample locations as the forward kernel and
// accumulates into `bottom_diff`:
//   * the incoming gradient itself at the element's own position, and
//   * the incoming gradient times each of the four bilinear weights at the
//     neighbouring pixels of every sample location.
// All accumulation uses atomicAdd because different elements can scatter
// into the same destination. `bottom_diff` is only added to, never
// overwritten, so the caller has to provide it pre-initialised.
//
//   nthreads       number of elements to process (N * C * H * W)
//   points         number of sample locations per element (at most 5)
//   top_diff       gradient w.r.t. the forward output
//   best_bboxes    per-location boxes, read as (y, x, w, h, angle)
//   spatial_scale  factor applied to y, x, w and h (not to the angle)
//   bottom_diff    gradient w.r.t. the forward input (accumulated)
template <typename scalar_t>
__global__ void rotated_feature_align_backward_kernel(
    const int nthreads, const int points, const scalar_t* top_diff,
    const scalar_t* best_bboxes, const scalar_t spatial_scale,
    const int channels, const int height, const int width,
    scalar_t* bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Split the flat index into (batch, chan, row, col).
    const int col = index % width;
    const int row = (index / width) % height;
    const int chan = (index / width / height) % channels;
    const int batch = index / width / height / channels;

    // Box associated with this spatial location.
    const scalar_t* box =
        best_bboxes + ((batch * height + row) * width + col) * 5;
    const scalar_t center_y = box[0] * spatial_scale;
    const scalar_t center_x = box[1] * spatial_scale;

    // Entry 0 is always the centre; entries 1..4 stay zero unless corner
    // sampling is requested.
    scalar_t sample_x[5] = {center_x, 0, 0, 0, 0};
    scalar_t sample_y[5] = {center_y, 0, 0, 0, 0};

    if (points > 1) {
      const scalar_t box_w = box[2] * spatial_scale;
      const scalar_t box_h = box[3] * spatial_scale;
      const scalar_t angle = box[4];

      const scalar_t half_w = box_w / 2;
      const scalar_t half_h = box_h / 2;
      const scalar_t cos_a = cosf(angle);
      const scalar_t sin_a = sinf(angle);
      // Half-extent vectors of the box after rotation by `angle`.
      const scalar_t wx = cos_a * half_w;
      const scalar_t wy = sin_a * half_w;
      const scalar_t hx = -sin_a * half_h;
      const scalar_t hy = cos_a * half_h;

      sample_x[1] = center_x + wx + hx;
      sample_y[1] = center_y + wy + hy;
      sample_x[2] = center_x - wx + hx;
      sample_y[2] = center_y - wy + hy;
      sample_x[3] = center_x - wx - hx;
      sample_y[3] = center_y - wy - hy;
      sample_x[4] = center_x + wx - hx;
      sample_y[4] = center_y + wy - hy;
    }

    // Gradient plane (batch, chan) that the samples scatter into.
    scalar_t* grad_plane =
        bottom_diff + (batch * channels + chan) * height * width;
    const scalar_t upstream = top_diff[index];

    // Identity term: the forward output contains the input value itself.
    atomicAdd(bottom_diff + index, upstream);

    for (int k = 0; k < points; ++k) {
      scalar_t w1, w2, w3, w4;
      int x_low, x_high, y_low, y_high;

      bilinear_interpolate_gradient<scalar_t>(
          height, width, sample_y[k], sample_x[k], w1, w2, w3, w4, x_low,
          x_high, y_low, y_high, k);

      // Only scatter when all four neighbour indices are non-negative
      // (presumably the helper marks out-of-range samples with negative
      // indices -- confirm against its definition).
      if (x_low < 0 || x_high < 0 || y_low < 0 || y_high < 0) {
        continue;
      }

      const scalar_t g1 = upstream * w1;
      const scalar_t g2 = upstream * w2;
      const scalar_t g3 = upstream * w3;
      const scalar_t g4 = upstream * w4;
      atomicAdd(grad_plane + y_low * width + x_low, g1);
      atomicAdd(grad_plane + y_low * width + x_high, g2);
      atomicAdd(grad_plane + y_high * width + x_low, g3);
      atomicAdd(grad_plane + y_high * width + x_high, g4);
    }
  }
}
|
||||
#endif // ROTATED_FEATURE_ALIGN_CUDA_KERNEL_CUH
|
|
@ -1358,7 +1358,51 @@ void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors,
|
|||
const std::vector<float> voxel_size,
|
||||
const std::vector<float> coors_range,
|
||||
const int NDim);
|
||||
|
||||
REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, CUDA,
|
||||
hard_voxelize_forward_cuda);
|
||||
REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, CUDA,
|
||||
dynamic_voxelize_forward_cuda);
|
||||
|
||||
void RotatedFeatureAlignForwardCUDAKernelLauncher(const Tensor features,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points,
|
||||
Tensor output);
|
||||
|
||||
void RotatedFeatureAlignBackwardCUDAKernelLauncher(const Tensor top_grad,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points,
|
||||
Tensor bottom_grad);
|
||||
|
||||
void rotated_feature_align_forward_cuda(const Tensor features,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor output) {
|
||||
RotatedFeatureAlignForwardCUDAKernelLauncher(features, best_bboxes,
|
||||
spatial_scale, points, output);
|
||||
};
|
||||
|
||||
void rotated_feature_align_backward_cuda(const Tensor top_grad,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor bottom_grad) {
|
||||
RotatedFeatureAlignBackwardCUDAKernelLauncher(
|
||||
top_grad, best_bboxes, spatial_scale, points, bottom_grad);
|
||||
};
|
||||
|
||||
void rotated_feature_align_forward_impl(const Tensor features,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor output);
|
||||
|
||||
void rotated_feature_align_backward_impl(const Tensor top_grad,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor bottom_grad);
|
||||
|
||||
REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, CUDA,
|
||||
rotated_feature_align_forward_cuda);
|
||||
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, CUDA,
|
||||
rotated_feature_align_backward_cuda);
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
// Modified from
|
||||
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#include "rotated_feature_align_cuda_kernel.cuh"
|
||||
|
||||
// Launches rotated_feature_align_forward_kernel on the current CUDA stream
// of the device that owns `features`.
//
//   features       input features; sizes 1, 2 and 3 are passed to the kernel
//                  as channels, height and width
//   best_bboxes    per-location boxes; must have the same scalar type as
//                  `features` (it is accessed through the same scalar_t)
//   spatial_scale  scale factor forwarded to the kernel
//   points         number of sample locations per element
//   output         destination tensor, written by the kernel
//
// The kernel indexes the raw data pointers linearly, so all tensors are
// expected to be densely laid out in their default order.
void RotatedFeatureAlignForwardCUDAKernelLauncher(const Tensor features,
                                                  const Tensor best_bboxes,
                                                  const float spatial_scale,
                                                  const int points,
                                                  Tensor output) {
  const int output_size = features.numel();
  // Nothing to compute for an empty tensor. Returning here also keeps a
  // zero element count from reaching the launch configuration, where the
  // block count derived from it could be an invalid (zero-sized) grid.
  if (output_size == 0) {
    return;
  }

  at::cuda::CUDAGuard device_guard(features.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features.scalar_type(), "rotated_feature_align_forward_cuda_kernel",
      ([&] {
        const scalar_t* bottom_data = features.data_ptr<scalar_t>();
        const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
        scalar_t* top_data = output.data_ptr<scalar_t>();

        rotated_feature_align_forward_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, points, bottom_data, bboxes_data,
                scalar_t(spatial_scale), features.size(1), features.size(2),
                features.size(3), top_data);
      }));
  // Surfaces launch-configuration errors; asynchronous execution errors
  // only appear at a later synchronising call.
  AT_CUDA_CHECK(cudaGetLastError());
}
|
||||
|
||||
// Launches rotated_feature_align_backward_kernel on the current CUDA stream
// of the device that owns `top_grad`.
//
//   top_grad       gradient w.r.t. the forward output; sizes 1, 2 and 3 are
//                  passed to the kernel as channels, height and width
//   best_bboxes    per-location boxes; must have the same scalar type as
//                  `top_grad` (it is accessed through the same scalar_t)
//   spatial_scale  scale factor forwarded to the kernel
//   points         number of sample locations per element
//   bottom_grad    destination tensor; the kernel only accumulates into it
//                  with atomicAdd, so it must be initialised by the caller
//
// The kernel indexes the raw data pointers linearly, so all tensors are
// expected to be densely laid out in their default order.
void RotatedFeatureAlignBackwardCUDAKernelLauncher(const Tensor top_grad,
                                                   const Tensor best_bboxes,
                                                   const float spatial_scale,
                                                   const int points,
                                                   Tensor bottom_grad) {
  const int output_size = top_grad.numel();
  // Nothing to accumulate for an empty tensor. Returning here also keeps a
  // zero element count from reaching the launch configuration, where the
  // block count derived from it could be an invalid (zero-sized) grid.
  if (output_size == 0) {
    return;
  }

  at::cuda::CUDAGuard device_guard(top_grad.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "rotated_feature_align_backward_cuda_kernel",
      ([&] {
        const scalar_t* top_diff = top_grad.data_ptr<scalar_t>();
        const scalar_t* bboxes_data = best_bboxes.data_ptr<scalar_t>();
        scalar_t* bottom_diff = bottom_grad.data_ptr<scalar_t>();

        rotated_feature_align_backward_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, points, top_diff, bboxes_data,
                scalar_t(spatial_scale), top_grad.size(1), top_grad.size(2),
                top_grad.size(3), bottom_diff);
      }));
  // Surfaces launch-configuration errors; asynchronous execution errors
  // only appear at a later synchronising call.
  AT_CUDA_CHECK(cudaGetLastError());
}
|
|
@ -340,6 +340,16 @@ void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
|
|||
int dilationH, int dilationW, int dilation_patchH,
|
||||
int dilation_patchW, int dH, int dW);
|
||||
|
||||
void rotated_feature_align_forward(const Tensor features,
|
||||
const Tensor best_bboxes, Tensor output,
|
||||
const float spatial_scale, const int points);
|
||||
|
||||
void rotated_feature_align_backward(const Tensor top_grad,
|
||||
const Tensor best_bboxes,
|
||||
Tensor bottom_grad,
|
||||
const float spatial_scale,
|
||||
const int points);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
|
||||
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
|
||||
|
@ -686,4 +696,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
"roiaware_pool3d_backward", py::arg("pts_idx_of_voxels"),
|
||||
py::arg("argmax"), py::arg("grad_out"), py::arg("grad_in"),
|
||||
py::arg("pool_method"));
|
||||
m.def("rotated_feature_align_forward", &rotated_feature_align_forward,
|
||||
"Feature Refine forward (CUDA)", py::arg("features"),
|
||||
py::arg("best_bboxes"), py::arg("output"), py::arg("spatial_scale"),
|
||||
py::arg("points"));
|
||||
m.def("rotated_feature_align_backward", &rotated_feature_align_backward,
|
||||
"Feature Refine backward (CUDA)", py::arg("top_grad"),
|
||||
py::arg("best_bboxes"), py::arg("bottom_grad"),
|
||||
py::arg("spatial_scale"), py::arg("points"));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
// Modified from
|
||||
// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_cuda.cpp
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
#include "pytorch_device_registry.hpp"
|
||||
|
||||
void rotated_feature_align_forward_impl(const Tensor features,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor output) {
|
||||
DISPATCH_DEVICE_IMPL(rotated_feature_align_forward_impl, features,
|
||||
best_bboxes, spatial_scale, points, output);
|
||||
}
|
||||
|
||||
void rotated_feature_align_backward_impl(const Tensor top_grad,
|
||||
const Tensor best_bboxes,
|
||||
const float spatial_scale,
|
||||
const int points, Tensor bottom_grad) {
|
||||
DISPATCH_DEVICE_IMPL(rotated_feature_align_backward_impl, top_grad,
|
||||
best_bboxes, spatial_scale, points, bottom_grad);
|
||||
}
|
||||
|
||||
void rotated_feature_align_forward(const Tensor features,
|
||||
const Tensor best_bboxes, Tensor output,
|
||||
const float spatial_scale,
|
||||
const int points) {
|
||||
rotated_feature_align_forward_impl(features, best_bboxes, spatial_scale,
|
||||
points, output);
|
||||
}
|
||||
|
||||
void rotated_feature_align_backward(const Tensor top_grad,
|
||||
const Tensor best_bboxes,
|
||||
Tensor bottom_grad,
|
||||
const float spatial_scale,
|
||||
const int points) {
|
||||
rotated_feature_align_backward_impl(top_grad, best_bboxes, spatial_scale,
|
||||
points, bottom_grad);
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext(
|
||||
'_ext',
|
||||
['rotated_feature_align_forward', 'rotated_feature_align_backward'])
|
||||
|
||||
|
||||
class RotatedFeatureAlignFunction(Function):
    """Using the feature interpolation to obtain the position information
    correspond to the refined rotate anchors and reconstruct the feature maps
    in pixel-wise manner to achieve feature alignment.

    The details are described in the paper
    `R3Det: Refined Single-Stage Detector with Feature Refinement for Rotating
    Object <https://arxiv.org/abs/1908.05612>`_.
    """

    @staticmethod
    def forward(ctx, features, best_rbboxes, spatial_scale, points):
        """
        Args:
            features (torch.Tensor): Input features with shape [N,C,H,W].
            best_rbboxes (torch.Tensor): Refined rotate anchors with
                shape [N,H,W,5]. The CUDA kernel reads the five values of
                each anchor as (y, x, w, h, a), scaling the first four by
                ``spatial_scale``. NOTE(review): confirm this order against
                the callers that build the anchors.
            spatial_scale (float): The scale of feature map size and
                input image size.
            points (int): The number of sample points.
                Only 1 and 5 are supported.

        Returns:
            torch.Tensor: Refined features with shape [N,C,H,W].
        """
        assert points in [1, 5]
        ctx.spatial_scale = spatial_scale
        ctx.points = points

        # The extension indexes the raw buffers linearly (dense NCHW for the
        # features, dense NHW5 for the anchors), so strided or channels-last
        # inputs have to be densified first. `.contiguous()` is a no-op for
        # tensors that are already contiguous.
        features = features.contiguous()
        best_rbboxes = best_rbboxes.contiguous()
        ctx.save_for_backward(best_rbboxes)

        # Allocated from the contiguous tensor so the output shares its
        # dense layout.
        output = torch.zeros_like(features)
        ext_module.rotated_feature_align_forward(features, best_rbboxes,
                                                 output, spatial_scale, points)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """
        Args:
            grad_output (torch.Tensor): The gradient of output features
                with shape [N,C,H,W].

        Returns:
            torch.Tensor: The gradient of input features with shape [N,C,H,W].
        """
        best_rbboxes = ctx.saved_tensors[0]
        points = ctx.points
        spatial_scale = ctx.spatial_scale
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Densify before allocating so that `grad_input` gets the same
            # dense layout the kernel writes with.
            grad_output = grad_output.contiguous()
            # Must start at zero: the kernel only accumulates into it.
            grad_input = torch.zeros_like(grad_output)
            ext_module.rotated_feature_align_backward(grad_output,
                                                      best_rbboxes, grad_input,
                                                      spatial_scale, points)
        # Only `features` receives a gradient; the anchors, scale and point
        # count are treated as non-differentiable.
        return grad_input, None, None, None
|
||||
|
||||
|
||||
def rotated_feature_align(features,
                          best_rbboxes,
                          spatial_scale=1 / 8,
                          points=1):
    """Functional interface to :class:`RotatedFeatureAlignFunction`.

    Args:
        features (torch.Tensor): Input features.
        best_rbboxes (torch.Tensor): Refined rotate anchors.
        spatial_scale (float): The scale of feature map size and input
            image size. Defaults to 1 / 8.
        points (int): The number of sample points. Defaults to 1.

    Returns:
        torch.Tensor: The result of applying the autograd function to the
        four arguments.
    """
    call_args = (features, best_rbboxes, spatial_scale, points)
    return RotatedFeatureAlignFunction.apply(*call_args)
|
|
@ -0,0 +1,129 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.ops import rotated_feature_align
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_rotated_feature_align():
    """Check forward output and input gradient of ``rotated_feature_align``
    against stored reference values for a [2,3,4,4] feature map sampled with
    a single point per location."""
    feats = torch.tensor(
        [[[[1.2924, -0.2172, -0.5222, 0.1172],
           [0.9144, 1.2248, 1.3115, -0.9690],
           [-0.8949, -1.1797, -0.9093, -0.3961],
           [-0.4586, 0.5062, -0.7947, -0.7397]],
          [[-1.0943, -0.7495, 1.3461, -1.1652],
           [0.2034, 0.6763, -1.2357, 0.5231],
           [-1.0062, 1.2592, 1.4225, -0.3951],
           [-0.1242, -1.6240, 0.1932, 2.7181]],
          [[-1.6271, -1.0276, 0.0578, -0.2997],
           [-0.9684, -1.6946, -1.3188, -1.1938],
           [-1.6744, -0.8917, -0.6556, 1.0073],
           [-0.1205, 0.3671, -0.3731, -0.5347]]],
         [[[0.7035, 0.2089, -0.1774, 3.4670],
           [-0.8505, -0.9278, 1.4714, 0.1644],
           [0.0898, 0.3531, -0.4007, 0.1927],
           [1.2569, -0.2636, -0.5223, 0.0616]],
          [[0.1760, -0.7639, -0.4600, -1.3260],
           [-0.9921, -0.2970, -0.8955, 1.0508],
           [1.3515, -0.1641, 1.9679, 1.1986],
           [-0.3616, 0.6287, 0.4933, 0.3360]],
          [[-0.5860, 0.2124, -0.8700, 2.4200],
           [-0.0551, -1.5103, -1.6779, 0.8399],
           [0.8431, 1.2414, -1.1243, -0.3887],
           [-2.1254, 0.6047, -0.3515, 0.7254]]]],
        device='cuda',
        requires_grad=True)

    rbboxes = torch.tensor(
        [[[[1.3080e+01, 1.2688e+01, 1.1214e+01, 9.3944e+01, -9.1905e-01],
           [3.8104e+01, 1.0134e+01, 1.4659e+02, 9.0306e+01, -9.8211e-01],
           [-5.3213e+01, 4.9508e+01, 5.1513e+01, 3.2055e+01, -3.1954e-01],
           [2.6974e+01, 2.5248e+01, 5.4495e+01, 3.1083e+00, -6.2127e-01]],
          [[-1.5604e+01, -5.1908e+01, 2.3998e+02, 1.5008e+01, -1.2546e+00],
           [3.1354e+01, -7.3635e+00, 6.7879e+01, 3.5081e+01, -3.3851e-01],
           [-5.3292e+00, 9.1946e+00, 1.2834e+01, 1.0485e+01, -1.3039e+00],
           [-2.3925e+01, 3.6623e+01, 3.9875e+01, 7.2009e+01, -6.5934e-01]],
          [[7.2114e+01, -2.3781e+01, 2.9106e+01, 8.4501e+01, -1.1340e+00],
           [2.6258e+01, -7.7034e+00, 1.7629e+02, 1.0615e+02, -1.2156e+00],
           [3.8057e+01, 4.6016e+01, 1.2965e+01, 6.9384e+00, -1.0855e+00],
           [2.4428e+01, -1.6189e+01, 2.0572e+02, 3.1622e+01, -1.5719e-01]],
          [[3.8226e+00, 2.9608e+01, 1.4457e+01, 6.8179e+01, -9.1997e-01],
           [2.5003e+01, -4.2490e+01, 9.6007e+01, 4.9086e+01, -1.4786e+00],
           [8.5983e+01, 5.4980e+01, 7.8080e+01, 1.0003e+02, -1.0926e+00],
           [9.9065e+00, 4.1457e+01, 5.9799e+00, 1.7973e+01, -5.6313e-01]]],
         [[[-1.8244e+01, 4.6309e+00, 5.3010e+01, 2.4310e+01, -7.0345e-01],
           [1.9419e+01, 3.6704e+01, 5.2390e+01, 5.4133e+01, -3.7730e-01],
           [5.6387e+01, 2.3752e+01, 9.0441e+00, 1.7792e+01, -1.5583e+00],
           [3.6303e+01, 1.6396e+01, 2.0283e+01, 1.9148e+01, -8.3419e-01]],
          [[3.2169e+01, 3.0521e+01, 2.6283e+01, 1.9680e+02, -3.0454e-01],
           [2.5788e+01, -3.2189e+01, 8.8882e+01, 1.0207e+02, -1.5328e+00],
           [8.4676e+00, -1.6668e+01, 2.4657e+01, 1.1275e+02, -4.0388e-01],
           [-1.0799e+01, 6.0422e+00, 9.5807e+00, 3.3677e+01, -3.5438e-01]],
          [[6.9363e+01, 1.0850e+01, 2.5968e+01, 2.2311e+01, -1.6408e-01],
           [2.8140e+00, 4.6843e+00, 3.1289e+00, 2.1480e+01, -6.7583e-01],
           [2.6661e+01, 4.5290e+01, 6.1679e+00, 3.0005e+01, -8.9806e-01],
           [5.0871e+00, 1.3234e+01, 9.2087e+01, 4.9622e+01, -2.8020e-01]],
          [[-1.2643e+01, 2.5176e+01, 5.0488e+01, 5.4246e+01, -4.4840e-01],
           [-3.4521e+01, 9.8435e-01, 5.2413e+01, 9.7996e+00, -8.4218e-01],
           [4.9829e+01, -1.0808e+01, 2.9848e+01, 7.3579e+01, -6.2672e-01],
           [8.0446e+01, 2.8064e+01, 4.5273e+01, 5.3809e+01, -1.2359e+00]]]],
        device='cuda',
        requires_grad=True)

    reference_output = torch.tensor(
        [[[[1.1095, -0.2172, -0.5222, -0.6225],
           [0.9144, 0.7662, 1.0487, -0.9690],
           [-0.8949, -1.6384, -0.9093, -0.3961],
           [-0.8604, 0.5062, -0.7947, -0.7397]],
          [[-0.3961, -0.7495, 1.3461, 1.5528],
           [0.2034, 0.5522, -1.6722, 0.5231],
           [-1.0062, 1.1350, 1.4225, -0.3951],
           [-0.4826, -1.6240, 0.1932, 2.7181]],
          [[-2.6436, -1.0276, 0.0578, -0.8344],
           [-0.9684, -1.8151, -2.1843, -1.1938],
           [-1.6744, -1.0121, -0.6556, 1.0073],
           [-0.8474, 0.3671, -0.3731, -0.5347]]],
         [[[0.7035, 0.2089, -0.1774, 3.4670],
           [-0.8505, -0.9278, 1.4714, 0.1644],
           [0.0898, 0.3064, -0.4007, 0.5849],
           [1.2569, -0.2636, -0.5223, 0.0616]],
          [[0.1760, -0.7639, -0.4600, -1.3260],
           [-0.9921, -0.2970, -0.8955, 1.0508],
           [1.3515, -0.6125, 1.9679, 0.5550],
           [-0.3616, 0.6287, 0.4933, 0.3360]],
          [[-0.5860, 0.2124, -0.8700, 2.4200],
           [-0.0551, -1.5103, -1.6779, 0.8399],
           [0.8431, 0.8455, -1.1243, -1.5994],
           [-2.1254, 0.6047, -0.3515, 0.7254]]]],
        device='cuda')

    # The reference gradient is identical for all three channels of a
    # sample, so each sample is described by a single 4x4 plane.
    grad_plane_sample0 = [[1.0000, 1.8507, 1.1493, 1.5222],
                          [1.0000, 1.1511, 1.2139, 1.4778],
                          [1.0000, 1.2629, 1.3721, 1.0000],
                          [3.0000, 1.0000, 1.0000, 2.0000]]
    grad_plane_sample1 = [[1.2687, 1.5055, 1.2382, 1.0000],
                          [1.1458, 1.4258, 1.4160, 1.0000],
                          [1.0000, 1.0000, 1.0000, 1.0000],
                          [1.0000, 1.0000, 1.0000, 1.0000]]
    reference_grad = torch.tensor(
        [[grad_plane_sample0] * 3, [grad_plane_sample1] * 3], device='cuda')

    aligned = rotated_feature_align(
        feats, rbboxes, spatial_scale=1 / 8, points=1)
    aligned.backward(torch.ones_like(aligned))

    assert torch.allclose(aligned, reference_output, rtol=1e-2)
    assert torch.allclose(feats.grad, reference_grad, rtol=1e-2)
|
Loading…
Reference in New Issue