[Feature] Add active rotated filter op for rotated detection. (#1598)

* add active_rotated_filter

* fix lint

* fix lint

* renaming nRotation and nOrientation

* Update active_rotated_filter_cuda_kernel.cuh

* Update active_rotated_filter_cuda.cu

* fix bug

* fix lint

* Update test_active_rotated_filter.py

* fix lint

* Update active_rotated_filter_cuda_kernel.cuh

* renaming

* Update mmcv/ops/active_rotated_filter.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/active_rotated_filter.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/active_rotated_filter.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/active_rotated_filter.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/active_rotated_filter.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/active_rotated_filter.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix lint

* Resolving conflicts

* fix lint.

* Update __init__.py

* Update mmcv/ops/csrc/pytorch/cuda/active_rotated_filter_cuda.cu

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update active_rotated_filter.cpp

* fix lint

* Update mmcv/ops/csrc/pytorch/cpu/active_rotated_filter.cpp

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/csrc/pytorch/cpu/active_rotated_filter.cpp

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update active_rotated_filter.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Yue Zhou 2022-01-10 17:35:20 +08:00 committed by GitHub
parent 40518322b6
commit 9acc892a44
11 changed files with 629 additions and 1 deletion


@@ -2,6 +2,7 @@
We implement common CUDA ops used in detection, segmentation, etc.
- ActiveRotatedFilter
- AssignScoreWithK
- BallQuery
- BBoxOverlaps


@@ -2,6 +2,7 @@
MMCV provides commonly used CUDA ops for tasks such as detection and segmentation
- ActiveRotatedFilter
- AssignScoreWithK
- BallQuery
- BBoxOverlaps


@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .active_rotated_filter import active_rotated_filter
from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
from .bbox import bbox_overlaps
@@ -83,5 +84,5 @@ __all__ = [
    'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
    'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
    'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
    'points_in_polygons', 'min_area_polygons'
    'points_in_polygons', 'min_area_polygons', 'active_rotated_filter'
]


@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext',
    ['active_rotated_filter_forward', 'active_rotated_filter_backward'])


class ActiveRotatedFilterFunction(Function):
    """Encoding the orientation information and generating orientation-
    sensitive features.

    The details are described in the paper `Align Deep Features for Oriented
    Object Detection <https://arxiv.org/abs/2008.09397>`_.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """
        Args:
            input (torch.Tensor): Input features with shape
                [num_output_planes, num_input_planes, num_orientations, H, W].
            indices (torch.Tensor): Indices with shape
                [num_orientations, H, W, num_rotations].

        Returns:
            torch.Tensor: Refined features with shape [num_output_planes *
            num_rotations, num_input_planes * num_orientations, H, W].
        """
        ctx.save_for_backward(input, indices)
        op, ip, o, h, w = input.size()
        o, h, w, r = indices.size()
        output = input.new_zeros((op * r, ip * o, h, w))
        ext_module.active_rotated_filter_forward(input, indices, output)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        """
        Args:
            grad_out (torch.Tensor): The gradient of output features
                with shape [num_output_planes * num_rotations,
                num_input_planes * num_orientations, H, W].

        Returns:
            torch.Tensor: The gradient of input features with shape
            [num_output_planes, num_input_planes, num_orientations, H, W].
        """
        input, indices = ctx.saved_tensors
        grad_in = torch.zeros_like(input)
        ext_module.active_rotated_filter_backward(grad_out, indices, grad_in)
        return grad_in, None


active_rotated_filter = ActiveRotatedFilterFunction.apply
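
A minimal usage sketch of the Python interface defined above (the random `indices` below are only a stand-in for the precomputed rotation lookup table that an oriented detector such as S2A-Net would supply):

```python
import torch

from mmcv.ops import active_rotated_filter

num_output_planes, num_input_planes = 2, 4
num_orientations, num_rotations, k = 8, 8, 3

# Orientation-sensitive kernel weights.
weight = torch.randn(num_output_planes, num_input_planes, num_orientations, k, k)
# 1-based lookup table: for each kernel entry and each rotation, the entry's
# destination among the num_orientations * k * k positions (random here,
# purely to illustrate shapes and dtypes).
indices = torch.randint(
    1, num_orientations * k * k + 1,
    (num_orientations, k, k, num_rotations), dtype=torch.int)

out = active_rotated_filter(weight, indices)
assert out.shape == (num_output_planes * num_rotations,
                     num_input_planes * num_orientations, k, k)
```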


@@ -0,0 +1,59 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/cuda/ActiveRotatingFilter_cuda.cu
#ifndef ACTIVE_ROTATED_FILTER_CUDA_KERNEL_CUH
#define ACTIVE_ROTATED_FILTER_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename scalar_t>
__global__ void active_rotated_filter_forward_cuda_kernel(
    const int nthreads, const scalar_t* weight_data, const int* indices_data,
    const int num_input_planes, const int num_output_planes,
    const int num_orientations, const int num_rotations, const int nEntry,
    scalar_t* output_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int l = index % nEntry;
    int j = (index / nEntry) % num_input_planes;
    int i = index / nEntry / num_input_planes;
    int k;
    scalar_t val = *(weight_data + index);
    for (k = 0; k < num_rotations; k++) {
      int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
      scalar_t* target = output_data +
                         i * (num_rotations * num_input_planes * nEntry) +
                         k * (num_input_planes * nEntry) + j * (nEntry) + idx;
      *target = val;
    }
  }
}

template <typename scalar_t>
__global__ void active_rotated_filter_backward_cuda_kernel(
    const int nthreads, const scalar_t* gradWeight_data,
    const int* indices_data, const int num_input_planes,
    const int num_output_planes, const int num_orientations,
    const int num_rotations, const int nEntry, scalar_t* weight_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int l = index % nEntry;
    int j = (index / nEntry) % num_input_planes;
    int i = index / nEntry / num_input_planes;
    int k;
    scalar_t* val = weight_data + index;
    *val = 0;
    scalar_t tmp = 0;
    for (k = 0; k < num_rotations; k++) {
      int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
      scalar_t target =
          *(gradWeight_data + i * (num_rotations * num_input_planes * nEntry) +
            k * (num_input_planes * nEntry) + j * (nEntry) + idx);
      tmp = tmp + target;
    }
    *val = tmp;
  }
}

#endif  // ACTIVE_ROTATED_FILTER_CUDA_KERNEL_CUH
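
The index arithmetic in these kernels is easier to see in a plain NumPy restatement of the forward pass (a sketch for illustration only, not part of the PR): viewing the output as `[num_output_planes, num_rotations, num_input_planes, nEntry]`, entry `l` of every filter is copied to position `indices[l, k] - 1` for each rotation `k`.

```python
import numpy as np


def active_rotated_filter_forward_ref(weight, indices):
    """NumPy restatement of the forward kernels above.

    weight:  (op, ip, o, kH, kW) orientation-sensitive filters.
    indices: (o, kH, kW, r) 1-based destination table, values in [1, o*kH*kW].
    returns: (op * r, ip * o, kH, kW).
    """
    op, ip, o, kh, kw = weight.shape
    r = indices.shape[-1]
    n_entry = o * kh * kw
    w_flat = weight.reshape(op, ip, n_entry)
    idx = indices.reshape(n_entry, r) - 1  # 0-based destinations
    out = np.zeros((op, r, ip, n_entry), dtype=weight.dtype)
    for k in range(r):
        for e in range(n_entry):
            # out[i, k, j, idx[e, k]] = weight[i, j, e] for all i, j
            out[:, k, :, idx[e, k]] = w_flat[:, :, e]
    return out.reshape(op * r, ip * o, kh, kw)
```

The backward kernel is the transpose of this scatter: each weight entry gathers (sums) the gradient from the `num_rotations` output positions it was copied to.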


@@ -0,0 +1,28 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/ActiveRotatingFilter.h
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

void active_rotated_filter_forward_impl(const Tensor input,
                                        const Tensor indices, Tensor output) {
  DISPATCH_DEVICE_IMPL(active_rotated_filter_forward_impl, input, indices,
                       output);
}

void active_rotated_filter_backward_impl(const Tensor grad_out,
                                         const Tensor indices, Tensor grad_in) {
  DISPATCH_DEVICE_IMPL(active_rotated_filter_backward_impl, grad_out, indices,
                       grad_in);
}

void active_rotated_filter_forward(const Tensor input, const Tensor indices,
                                   Tensor output) {
  active_rotated_filter_forward_impl(input, indices, output);
}

void active_rotated_filter_backward(const Tensor grad_out, const Tensor indices,
                                    Tensor grad_in) {
  active_rotated_filter_backward_impl(grad_out, indices, grad_in);
}


@@ -0,0 +1,120 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/cpu/ActiveRotatingFilter_cpu.cpp
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

template <typename T>
void active_rotated_filter_forward_cpu_kernel(
    const T* weightData, const int* indicesData, const int num_output_planes,
    const int num_input_planes, const int num_orientations, const int kH,
    const int kW, const int num_rotations, T* outputData) {
  const int nEntry = num_orientations * kH * kW;
  int i, j, l;
  int k;

#pragma omp parallel for private(i, j, l, k)
  for (i = 0; i < num_output_planes; i++) {
    for (j = 0; j < num_input_planes; j++) {
      for (l = 0; l < nEntry; l++) {
        int weightIndex = i * num_input_planes * nEntry + j * nEntry + l;
        T val = *(weightData + weightIndex);
        for (k = 0; k < num_rotations; k++) {
          int index = (int)(*(indicesData + l * num_rotations + k)) - 1;
          T* target = outputData +
                      i * (num_rotations * num_input_planes * nEntry) +
                      k * (num_input_planes * nEntry) + j * (nEntry) + index;
          *target = val;
        }
      }
    }
  }
}

template <typename T>
void active_rotated_filter_backward_cpu_kernel(
    const T* gradOutputData, const int* indicesData,
    const int num_output_planes, const int num_input_planes,
    const int num_orientations, const int kH, const int kW,
    const int num_rotations, T* gradInputData) {
  const int nEntry = num_orientations * kH * kW;
  int i, j, l;
  int k;

#pragma omp parallel for private(i, j, l, k)
  for (i = 0; i < num_output_planes; i++) {
    for (j = 0; j < num_input_planes; j++) {
      for (l = 0; l < nEntry; l++) {
        int gradInputIndex = i * num_input_planes * nEntry + j * nEntry + l;
        T* val = gradInputData + gradInputIndex;
        *val = 0;
        for (k = 0; k < num_rotations; k++) {
          int index = (int)(*(indicesData + l * num_rotations + k)) - 1;
          const T* target =
              gradOutputData + i * (num_rotations * num_input_planes * nEntry) +
              k * (num_input_planes * nEntry) + j * (nEntry) + index;
          *val = *val + *target;
        }
      }
    }
  }
}

void ActiveRotatedFilterForwardCPULauncher(const Tensor input,
                                           const Tensor indices,
                                           Tensor output) {
  const int num_output_planes = input.size(0);
  const int num_input_planes = input.size(1);
  const int num_orientations = input.size(2);
  const int kH = input.size(3);
  const int kW = input.size(4);
  const int num_rotations = indices.size(3);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "active_rotated_filter_forward_cpu_kernel", [&] {
        active_rotated_filter_forward_cpu_kernel<scalar_t>(
            input.data_ptr<scalar_t>(), indices.data_ptr<int>(),
            num_output_planes, num_input_planes, num_orientations, kH, kW,
            num_rotations, output.data_ptr<scalar_t>());
      });
}

void ActiveRotatedFilterBackwardCPULauncher(const Tensor grad_out,
                                            const Tensor indices,
                                            Tensor grad_in) {
  const int num_orientations = indices.size(0);
  const int kH = indices.size(1);
  const int kW = indices.size(2);
  const int num_rotations = indices.size(3);
  const int num_output_planes = grad_out.size(0) / num_rotations;
  const int num_input_planes = grad_out.size(1) / num_orientations;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_out.scalar_type(), "active_rotated_filter_backward_cpu_kernel", [&] {
        active_rotated_filter_backward_cpu_kernel<scalar_t>(
            grad_out.data_ptr<scalar_t>(), indices.data_ptr<int>(),
            num_output_planes, num_input_planes, num_orientations, kH, kW,
            num_rotations, grad_in.data_ptr<scalar_t>());
      });
}

void active_rotated_filter_forward_cpu(const Tensor input, const Tensor indices,
                                       Tensor output) {
  ActiveRotatedFilterForwardCPULauncher(input, indices, output);
}

void active_rotated_filter_backward_cpu(const Tensor grad_out,
                                        const Tensor indices, Tensor grad_in) {
  ActiveRotatedFilterBackwardCPULauncher(grad_out, indices, grad_in);
}

void active_rotated_filter_forward_impl(const Tensor input,
                                        const Tensor indices, Tensor output);

void active_rotated_filter_backward_impl(const Tensor grad_out,
                                         const Tensor indices, Tensor grad_in);

REGISTER_DEVICE_IMPL(active_rotated_filter_forward_impl, CPU,
                     active_rotated_filter_forward_cpu);
REGISTER_DEVICE_IMPL(active_rotated_filter_backward_impl, CPU,
                     active_rotated_filter_backward_cpu);


@@ -0,0 +1,58 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/cuda/ActiveRotatingFilter_cuda.cu
#include "active_rotated_filter_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

void ActiveRotatedFilterForwardCUDAKernelLauncher(const Tensor input,
                                                  const Tensor indices,
                                                  Tensor output) {
  int num_output_planes = input.size(0);
  int num_input_planes = input.size(1);
  int num_orientations = input.size(2);
  int kH = input.size(3);
  int kW = input.size(4);
  int num_rotations = indices.size(3);
  int nEntry = num_orientations * kH * kW;
  // The kernel is parallelized over the entries of the (un-rotated) weight
  // tensor, so the launch size must be input.numel(), not output.numel().
  int output_size = input.numel();

  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "active_rotated_filter_forward_cuda_kernel", [&] {
        active_rotated_filter_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, input.data_ptr<scalar_t>(),
                indices.data_ptr<int>(), num_input_planes, num_output_planes,
                num_orientations, num_rotations, nEntry,
                output.data_ptr<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
}

void ActiveRotatedFilterBackwardCUDAKernelLauncher(const Tensor grad_out,
                                                   const Tensor indices,
                                                   Tensor grad_in) {
  int num_orientations = indices.size(0);
  int kH = indices.size(1);
  int kW = indices.size(2);
  int num_rotations = indices.size(3);
  int num_output_planes = grad_out.size(0) / num_rotations;
  int num_input_planes = grad_out.size(1) / num_orientations;
  int nEntry = num_orientations * kH * kW;
  int output_size = grad_in.numel();

  at::cuda::CUDAGuard device_guard(indices.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_out.scalar_type(), "active_rotated_filter_backward_cuda_kernel",
      [&] {
        active_rotated_filter_backward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, grad_out.data_ptr<scalar_t>(),
                indices.data_ptr<int>(), num_input_planes, num_output_planes,
                num_orientations, num_rotations, nEntry,
                grad_in.data_ptr<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
}


@@ -1508,4 +1508,34 @@ void min_area_polygons_cuda(const Tensor pointsets, Tensor polygons) {
}

void min_area_polygons_impl(const Tensor pointsets, Tensor polygons);
REGISTER_DEVICE_IMPL(min_area_polygons_impl, CUDA, min_area_polygons_cuda);

void ActiveRotatedFilterForwardCUDAKernelLauncher(const Tensor input,
                                                  const Tensor indices,
                                                  Tensor output);
void ActiveRotatedFilterBackwardCUDAKernelLauncher(const Tensor grad_out,
                                                   const Tensor indices,
                                                   Tensor grad_in);

void active_rotated_filter_forward_cuda(const Tensor input,
                                        const Tensor indices, Tensor output) {
  ActiveRotatedFilterForwardCUDAKernelLauncher(input, indices, output);
};

void active_rotated_filter_backward_cuda(const Tensor grad_out,
                                         const Tensor indices, Tensor grad_in) {
  ActiveRotatedFilterBackwardCUDAKernelLauncher(grad_out, indices, grad_in);
};

void active_rotated_filter_forward_impl(const Tensor input,
                                        const Tensor indices, Tensor output);
void active_rotated_filter_backward_impl(const Tensor grad_out,
                                         const Tensor indices, Tensor grad_in);

REGISTER_DEVICE_IMPL(active_rotated_filter_forward_impl, CUDA,
                     active_rotated_filter_forward_cuda);
REGISTER_DEVICE_IMPL(active_rotated_filter_backward_impl, CUDA,
                     active_rotated_filter_backward_cuda);


@@ -365,6 +365,12 @@ void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);

void min_area_polygons(const Tensor pointsets, Tensor polygons);

void active_rotated_filter_forward(const Tensor input, const Tensor indices,
                                   Tensor output);

void active_rotated_filter_backward(const Tensor grad_out, const Tensor indices,
                                    Tensor grad_in);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
        py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@@ -735,4 +741,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("output"));
  m.def("min_area_polygons", &min_area_polygons, "min_area_polygons",
        py::arg("pointsets"), py::arg("polygons"));
  m.def("active_rotated_filter_forward", &active_rotated_filter_forward,
        "active_rotated_filter_forward", py::arg("input"), py::arg("indices"),
        py::arg("output"));
  m.def("active_rotated_filter_backward", &active_rotated_filter_backward,
        "active_rotated_filter_backward", py::arg("grad_out"),
        py::arg("indices"), py::arg("grad_in"));
}
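
These bindings are what `ext_loader.load_ext('_ext', ...)` in `mmcv/ops/active_rotated_filter.py` resolves at import time. A hedged sketch of exercising the raw extension directly (normally you would go through `mmcv.ops.active_rotated_filter`, which also provides autograd; the shapes below are only for illustration):

```python
import torch

from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext',
    ['active_rotated_filter_forward', 'active_rotated_filter_backward'])

op, ip, o, r, k = 2, 4, 8, 8, 3
weight = torch.randn(op, ip, o, k, k)
# 1-based rotation lookup table; random values only for illustration.
indices = torch.randint(1, o * k * k + 1, (o, k, k, r), dtype=torch.int)
output = weight.new_zeros((op * r, ip * o, k, k))

# The binding fills `output` in place, mirroring Function.forward above.
ext_module.active_rotated_filter_forward(weight, indices, output)
```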


@@ -0,0 +1,257 @@
import numpy as np
import pytest
import torch

from mmcv.ops import active_rotated_filter

np_feature = np.array([[[[[-1.4934e-01, 1.1341e+00, -1.6241e-01],
[-1.0986e+00, -1.1463e+00, -1.3176e+00],
[1.4808e+00, 7.6572e-01, -1.4548e+00]]]],
[[[[1.9370e+00, 6.2799e-01, 2.5834e-02],
[-1.4242e+00, 7.6566e-01, 1.0015e+00],
[9.8669e-01, 4.1356e-01, 6.1068e-01]]]],
[[[[1.4565e+00, 1.4960e+00, 2.4339e-01],
[-2.2484e-01, 7.5942e-01, -8.1184e-01],
[-1.7077e+00, 1.0658e+00, 3.8311e-01]]]],
[[[[8.4734e-01, 1.0904e+00, 2.4356e+00],
[9.5822e-01, 2.2260e-01, -2.4450e-01],
[-1.5078e+00, 7.0902e-02, -1.5921e+00]]]],
[[[[2.1173e+00, -7.3524e-01, 1.8888e+00],
[1.0169e+00, 4.7033e-01, -1.0875e+00],
[-1.0736e+00, -5.2245e-01, -2.8733e-01]]]],
[[[[-5.6433e-01, 1.5835e+00, -1.5826e+00],
[-8.8974e-01, -4.3128e-01, -2.2423e-01],
[1.6552e-03, -1.7292e+00, 2.6639e-01]]]],
[[[[-1.2951e-01, 1.3493e+00, -1.9329e+00],
[5.6248e-01, -5.1189e-01, 1.3614e+00],
[3.3680e-01, -8.7148e-01, 5.0592e-01]]]],
[[[[1.6781e-02, -8.3929e-01, 1.2060e+00],
[-1.0764e+00, 4.7821e-01, 1.5342e+00],
[-4.4542e-01, -1.8606e+00, 3.0827e-01]]]]])
np_indices = np.array([[[[1, 2, 3, 6, 9, 8, 7, 4], [2, 3, 6, 9, 8, 7, 4, 1],
[3, 6, 9, 8, 7, 4, 1, 2]],
[[4, 1, 2, 3, 6, 9, 8, 7], [5, 5, 5, 5, 5, 5, 5, 5],
[6, 9, 8, 7, 4, 1, 2, 3]],
[[7, 4, 1, 2, 3, 6, 9, 8], [8, 7, 4, 1, 2, 3, 6, 9],
[9, 8, 7, 4, 1, 2, 3, 6]]]])
expected_output = np.array([[[[-1.4934e-01, 1.1341e+00, -1.6241e-01],
[-1.0986e+00, -1.1463e+00, -1.3176e+00],
[1.4808e+00, 7.6572e-01, -1.4548e+00]]],
[[[-1.0986e+00, -1.4934e-01, 1.1341e+00],
[1.4808e+00, -1.1463e+00, -1.6241e-01],
[7.6572e-01, -1.4548e+00, -1.3176e+00]]],
[[[1.4808e+00, -1.0986e+00, -1.4934e-01],
[7.6572e-01, -1.1463e+00, 1.1341e+00],
[-1.4548e+00, -1.3176e+00, -1.6241e-01]]],
[[[7.6572e-01, 1.4808e+00, -1.0986e+00],
[-1.4548e+00, -1.1463e+00, -1.4934e-01],
[-1.3176e+00, -1.6241e-01, 1.1341e+00]]],
[[[-1.4548e+00, 7.6572e-01, 1.4808e+00],
[-1.3176e+00, -1.1463e+00, -1.0986e+00],
[-1.6241e-01, 1.1341e+00, -1.4934e-01]]],
[[[-1.3176e+00, -1.4548e+00, 7.6572e-01],
[-1.6241e-01, -1.1463e+00, 1.4808e+00],
[1.1341e+00, -1.4934e-01, -1.0986e+00]]],
[[[-1.6241e-01, -1.3176e+00, -1.4548e+00],
[1.1341e+00, -1.1463e+00, 7.6572e-01],
[-1.4934e-01, -1.0986e+00, 1.4808e+00]]],
[[[1.1341e+00, -1.6241e-01, -1.3176e+00],
[-1.4934e-01, -1.1463e+00, -1.4548e+00],
[-1.0986e+00, 1.4808e+00, 7.6572e-01]]],
[[[1.9370e+00, 6.2799e-01, 2.5834e-02],
[-1.4242e+00, 7.6566e-01, 1.0015e+00],
[9.8669e-01, 4.1356e-01, 6.1068e-01]]],
[[[-1.4242e+00, 1.9370e+00, 6.2799e-01],
[9.8669e-01, 7.6566e-01, 2.5834e-02],
[4.1356e-01, 6.1068e-01, 1.0015e+00]]],
[[[9.8669e-01, -1.4242e+00, 1.9370e+00],
[4.1356e-01, 7.6566e-01, 6.2799e-01],
[6.1068e-01, 1.0015e+00, 2.5834e-02]]],
[[[4.1356e-01, 9.8669e-01, -1.4242e+00],
[6.1068e-01, 7.6566e-01, 1.9370e+00],
[1.0015e+00, 2.5834e-02, 6.2799e-01]]],
[[[6.1068e-01, 4.1356e-01, 9.8669e-01],
[1.0015e+00, 7.6566e-01, -1.4242e+00],
[2.5834e-02, 6.2799e-01, 1.9370e+00]]],
[[[1.0015e+00, 6.1068e-01, 4.1356e-01],
[2.5834e-02, 7.6566e-01, 9.8669e-01],
[6.2799e-01, 1.9370e+00, -1.4242e+00]]],
[[[2.5834e-02, 1.0015e+00, 6.1068e-01],
[6.2799e-01, 7.6566e-01, 4.1356e-01],
[1.9370e+00, -1.4242e+00, 9.8669e-01]]],
[[[6.2799e-01, 2.5834e-02, 1.0015e+00],
[1.9370e+00, 7.6566e-01, 6.1068e-01],
[-1.4242e+00, 9.8669e-01, 4.1356e-01]]],
[[[1.4565e+00, 1.4960e+00, 2.4339e-01],
[-2.2484e-01, 7.5942e-01, -8.1184e-01],
[-1.7077e+00, 1.0658e+00, 3.8311e-01]]],
[[[-2.2484e-01, 1.4565e+00, 1.4960e+00],
[-1.7077e+00, 7.5942e-01, 2.4339e-01],
[1.0658e+00, 3.8311e-01, -8.1184e-01]]],
[[[-1.7077e+00, -2.2484e-01, 1.4565e+00],
[1.0658e+00, 7.5942e-01, 1.4960e+00],
[3.8311e-01, -8.1184e-01, 2.4339e-01]]],
[[[1.0658e+00, -1.7077e+00, -2.2484e-01],
[3.8311e-01, 7.5942e-01, 1.4565e+00],
[-8.1184e-01, 2.4339e-01, 1.4960e+00]]],
[[[3.8311e-01, 1.0658e+00, -1.7077e+00],
[-8.1184e-01, 7.5942e-01, -2.2484e-01],
[2.4339e-01, 1.4960e+00, 1.4565e+00]]],
[[[-8.1184e-01, 3.8311e-01, 1.0658e+00],
[2.4339e-01, 7.5942e-01, -1.7077e+00],
[1.4960e+00, 1.4565e+00, -2.2484e-01]]],
[[[2.4339e-01, -8.1184e-01, 3.8311e-01],
[1.4960e+00, 7.5942e-01, 1.0658e+00],
[1.4565e+00, -2.2484e-01, -1.7077e+00]]],
[[[1.4960e+00, 2.4339e-01, -8.1184e-01],
[1.4565e+00, 7.5942e-01, 3.8311e-01],
[-2.2484e-01, -1.7077e+00, 1.0658e+00]]],
[[[8.4734e-01, 1.0904e+00, 2.4356e+00],
[9.5822e-01, 2.2260e-01, -2.4450e-01],
[-1.5078e+00, 7.0902e-02, -1.5921e+00]]],
[[[9.5822e-01, 8.4734e-01, 1.0904e+00],
[-1.5078e+00, 2.2260e-01, 2.4356e+00],
[7.0902e-02, -1.5921e+00, -2.4450e-01]]],
[[[-1.5078e+00, 9.5822e-01, 8.4734e-01],
[7.0902e-02, 2.2260e-01, 1.0904e+00],
[-1.5921e+00, -2.4450e-01, 2.4356e+00]]],
[[[7.0902e-02, -1.5078e+00, 9.5822e-01],
[-1.5921e+00, 2.2260e-01, 8.4734e-01],
[-2.4450e-01, 2.4356e+00, 1.0904e+00]]],
[[[-1.5921e+00, 7.0902e-02, -1.5078e+00],
[-2.4450e-01, 2.2260e-01, 9.5822e-01],
[2.4356e+00, 1.0904e+00, 8.4734e-01]]],
[[[-2.4450e-01, -1.5921e+00, 7.0902e-02],
[2.4356e+00, 2.2260e-01, -1.5078e+00],
[1.0904e+00, 8.4734e-01, 9.5822e-01]]],
[[[2.4356e+00, -2.4450e-01, -1.5921e+00],
[1.0904e+00, 2.2260e-01, 7.0902e-02],
[8.4734e-01, 9.5822e-01, -1.5078e+00]]],
[[[1.0904e+00, 2.4356e+00, -2.4450e-01],
[8.4734e-01, 2.2260e-01, -1.5921e+00],
[9.5822e-01, -1.5078e+00, 7.0902e-02]]],
[[[2.1173e+00, -7.3524e-01, 1.8888e+00],
[1.0169e+00, 4.7033e-01, -1.0875e+00],
[-1.0736e+00, -5.2245e-01, -2.8733e-01]]],
[[[1.0169e+00, 2.1173e+00, -7.3524e-01],
[-1.0736e+00, 4.7033e-01, 1.8888e+00],
[-5.2245e-01, -2.8733e-01, -1.0875e+00]]],
[[[-1.0736e+00, 1.0169e+00, 2.1173e+00],
[-5.2245e-01, 4.7033e-01, -7.3524e-01],
[-2.8733e-01, -1.0875e+00, 1.8888e+00]]],
[[[-5.2245e-01, -1.0736e+00, 1.0169e+00],
[-2.8733e-01, 4.7033e-01, 2.1173e+00],
[-1.0875e+00, 1.8888e+00, -7.3524e-01]]],
[[[-2.8733e-01, -5.2245e-01, -1.0736e+00],
[-1.0875e+00, 4.7033e-01, 1.0169e+00],
[1.8888e+00, -7.3524e-01, 2.1173e+00]]],
[[[-1.0875e+00, -2.8733e-01, -5.2245e-01],
[1.8888e+00, 4.7033e-01, -1.0736e+00],
[-7.3524e-01, 2.1173e+00, 1.0169e+00]]],
[[[1.8888e+00, -1.0875e+00, -2.8733e-01],
[-7.3524e-01, 4.7033e-01, -5.2245e-01],
[2.1173e+00, 1.0169e+00, -1.0736e+00]]],
[[[-7.3524e-01, 1.8888e+00, -1.0875e+00],
[2.1173e+00, 4.7033e-01, -2.8733e-01],
[1.0169e+00, -1.0736e+00, -5.2245e-01]]],
[[[-5.6433e-01, 1.5835e+00, -1.5826e+00],
[-8.8974e-01, -4.3128e-01, -2.2423e-01],
[1.6552e-03, -1.7292e+00, 2.6639e-01]]],
[[[-8.8974e-01, -5.6433e-01, 1.5835e+00],
[1.6552e-03, -4.3128e-01, -1.5826e+00],
[-1.7292e+00, 2.6639e-01, -2.2423e-01]]],
[[[1.6552e-03, -8.8974e-01, -5.6433e-01],
[-1.7292e+00, -4.3128e-01, 1.5835e+00],
[2.6639e-01, -2.2423e-01, -1.5826e+00]]],
[[[-1.7292e+00, 1.6552e-03, -8.8974e-01],
[2.6639e-01, -4.3128e-01, -5.6433e-01],
[-2.2423e-01, -1.5826e+00, 1.5835e+00]]],
[[[2.6639e-01, -1.7292e+00, 1.6552e-03],
[-2.2423e-01, -4.3128e-01, -8.8974e-01],
[-1.5826e+00, 1.5835e+00, -5.6433e-01]]],
[[[-2.2423e-01, 2.6639e-01, -1.7292e+00],
[-1.5826e+00, -4.3128e-01, 1.6552e-03],
[1.5835e+00, -5.6433e-01, -8.8974e-01]]],
[[[-1.5826e+00, -2.2423e-01, 2.6639e-01],
[1.5835e+00, -4.3128e-01, -1.7292e+00],
[-5.6433e-01, -8.8974e-01, 1.6552e-03]]],
[[[1.5835e+00, -1.5826e+00, -2.2423e-01],
[-5.6433e-01, -4.3128e-01, 2.6639e-01],
[-8.8974e-01, 1.6552e-03, -1.7292e+00]]],
[[[-1.2951e-01, 1.3493e+00, -1.9329e+00],
[5.6248e-01, -5.1189e-01, 1.3614e+00],
[3.3680e-01, -8.7148e-01, 5.0592e-01]]],
[[[5.6248e-01, -1.2951e-01, 1.3493e+00],
[3.3680e-01, -5.1189e-01, -1.9329e+00],
[-8.7148e-01, 5.0592e-01, 1.3614e+00]]],
[[[3.3680e-01, 5.6248e-01, -1.2951e-01],
[-8.7148e-01, -5.1189e-01, 1.3493e+00],
[5.0592e-01, 1.3614e+00, -1.9329e+00]]],
[[[-8.7148e-01, 3.3680e-01, 5.6248e-01],
[5.0592e-01, -5.1189e-01, -1.2951e-01],
[1.3614e+00, -1.9329e+00, 1.3493e+00]]],
[[[5.0592e-01, -8.7148e-01, 3.3680e-01],
[1.3614e+00, -5.1189e-01, 5.6248e-01],
[-1.9329e+00, 1.3493e+00, -1.2951e-01]]],
[[[1.3614e+00, 5.0592e-01, -8.7148e-01],
[-1.9329e+00, -5.1189e-01, 3.3680e-01],
[1.3493e+00, -1.2951e-01, 5.6248e-01]]],
[[[-1.9329e+00, 1.3614e+00, 5.0592e-01],
[1.3493e+00, -5.1189e-01, -8.7148e-01],
[-1.2951e-01, 5.6248e-01, 3.3680e-01]]],
[[[1.3493e+00, -1.9329e+00, 1.3614e+00],
[-1.2951e-01, -5.1189e-01, 5.0592e-01],
[5.6248e-01, 3.3680e-01, -8.7148e-01]]],
[[[1.6781e-02, -8.3929e-01, 1.2060e+00],
[-1.0764e+00, 4.7821e-01, 1.5342e+00],
[-4.4542e-01, -1.8606e+00, 3.0827e-01]]],
[[[-1.0764e+00, 1.6781e-02, -8.3929e-01],
[-4.4542e-01, 4.7821e-01, 1.2060e+00],
[-1.8606e+00, 3.0827e-01, 1.5342e+00]]],
[[[-4.4542e-01, -1.0764e+00, 1.6781e-02],
[-1.8606e+00, 4.7821e-01, -8.3929e-01],
[3.0827e-01, 1.5342e+00, 1.2060e+00]]],
[[[-1.8606e+00, -4.4542e-01, -1.0764e+00],
[3.0827e-01, 4.7821e-01, 1.6781e-02],
[1.5342e+00, 1.2060e+00, -8.3929e-01]]],
[[[3.0827e-01, -1.8606e+00, -4.4542e-01],
[1.5342e+00, 4.7821e-01, -1.0764e+00],
[1.2060e+00, -8.3929e-01, 1.6781e-02]]],
[[[1.5342e+00, 3.0827e-01, -1.8606e+00],
[1.2060e+00, 4.7821e-01, -4.4542e-01],
[-8.3929e-01, 1.6781e-02, -1.0764e+00]]],
[[[1.2060e+00, 1.5342e+00, 3.0827e-01],
[-8.3929e-01, 4.7821e-01, -1.8606e+00],
[1.6781e-02, -1.0764e+00, -4.4542e-01]]],
[[[-8.3929e-01, 1.2060e+00, 1.5342e+00],
[1.6781e-02, 4.7821e-01, 3.0827e-01],
[-1.0764e+00, -4.4542e-01, -1.8606e+00]]]])
expected_grad = np.array([[[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
[[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
[[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
[[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
[[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
[[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
[[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
[[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]]])
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')),
])
def test_active_rotated_filter(device):
    feature = torch.tensor(
        np_feature, dtype=torch.float, device=device, requires_grad=True)
    indices = torch.tensor(np_indices, dtype=torch.int, device=device)
    output = active_rotated_filter(feature, indices)
    output.backward(torch.ones_like(output))
    assert np.allclose(output.data.cpu().numpy(), expected_output, atol=1e-3)
    assert np.allclose(
        feature.grad.data.cpu().numpy(), expected_grad, atol=1e-3)
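
A note on the expected values: every rotation column of `np_indices` is a permutation of 1..9, so each of the 9 kernel entries is copied to exactly `num_rotations = 8` output positions, and backpropagating an all-ones gradient accumulates to 8 for every weight entry, which is where the all-8 `expected_grad` comes from. A quick sanity check (a sketch, reusing `np_indices` defined above):

```python
# Each rotation column of np_indices permutes the 9 kernel entries, so every
# weight entry contributes to exactly 8 output entries (one per rotation).
assert all(
    sorted(np_indices.reshape(9, 8)[:, k].tolist()) == list(range(1, 10))
    for k in range(8))
```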