mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add active rotated filter op for rotated detection. (#1598)

* add active_rotated_filter
* fix lint
* fix lint
* renaming nRotation and nOrientation
* Update active_rotated_filter_cuda_kernel.cuh
* Update active_rotated_filter_cuda.cu
* fix bug
* fix lint
* Update test_active_rotated_filter.py
* fix lint
* Update active_rotated_filter_cuda_kernel.cuh
* renaming
* Update mmcv/ops/active_rotated_filter.py
* Update mmcv/ops/active_rotated_filter.py
* Update mmcv/ops/active_rotated_filter.py
* Update mmcv/ops/active_rotated_filter.py
* Update mmcv/ops/active_rotated_filter.py
* Update mmcv/ops/active_rotated_filter.py
* fix lint
* Resolving conflicts
* fix lint.
* Update __init__.py
* Update mmcv/ops/csrc/pytorch/cuda/active_rotated_filter_cuda.cu
* Update active_rotated_filter.cpp
* fix lint
* Update mmcv/ops/csrc/pytorch/cpu/active_rotated_filter.cpp
* Update mmcv/ops/csrc/pytorch/cpu/active_rotated_filter.cpp
* Update active_rotated_filter.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

parent 40518322b6
commit 9acc892a44

@@ -2,6 +2,7 @@

We implement common CUDA ops used in detection, segmentation, etc.

- ActiveRotatedFilter
- AssignScoreWithK
- BallQuery
- BBoxOverlaps

@@ -2,6 +2,7 @@

MMCV provides commonly used CUDA ops for tasks such as detection and segmentation.

- ActiveRotatedFilter
- AssignScoreWithK
- BallQuery
- BBoxOverlaps

@@ -1,4 +1,5 @@

# Copyright (c) OpenMMLab. All rights reserved.
from .active_rotated_filter import active_rotated_filter
from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
from .bbox import bbox_overlaps

@@ -83,5 +84,5 @@ __all__ = [
    'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
    'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
    'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
    'points_in_polygons', 'min_area_polygons'
    'points_in_polygons', 'min_area_polygons', 'active_rotated_filter'
]

@@ -0,0 +1,61 @@

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext',
    ['active_rotated_filter_forward', 'active_rotated_filter_backward'])


class ActiveRotatedFilterFunction(Function):
    """Encoding the orientation information and generating orientation-
    sensitive features.

    The details are described in the paper `Align Deep Features for Oriented
    Object Detection <https://arxiv.org/abs/2008.09397>`_.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """
        Args:
            input (torch.Tensor): Input features with shape
                [num_output_planes, num_input_planes, num_orientations, H, W].
            indices (torch.Tensor): Indices with shape
                [num_orientations, H, W, num_rotations].

        Returns:
            torch.Tensor: Refined features with shape [num_output_planes *
            num_rotations, num_input_planes * num_orientations, H, W].
        """
        ctx.save_for_backward(input, indices)
        op, ip, o, h, w = input.size()
        o, h, w, r = indices.size()
        output = input.new_zeros((op * r, ip * o, h, w))
        ext_module.active_rotated_filter_forward(input, indices, output)

        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        """
        Args:
            grad_out (torch.Tensor): The gradient of output features
                with shape [num_output_planes * num_rotations,
                num_input_planes * num_orientations, H, W].

        Returns:
            torch.Tensor: The gradient of input features with shape
            [num_output_planes, num_input_planes, num_orientations, H, W].
        """
        input, indices = ctx.saved_tensors
        grad_in = torch.zeros_like(input)
        ext_module.active_rotated_filter_backward(grad_out, indices, grad_in)
        return grad_in, None


active_rotated_filter = ActiveRotatedFilterFunction.apply

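For orientation, a minimal usage sketch of the new op follows (the sizes and the random index tensor are purely illustrative and assume an mmcv build that includes this op; real index tensors encode rotations of the kernel grid, as in the test added later in this diff):

import torch
from mmcv.ops import active_rotated_filter

# Hypothetical sizes, for illustration only.
cout, cin, n_orientations, n_rotations, k = 4, 2, 8, 8, 3

# Orientation-sensitive weights: [Cout, Cin, nOrientation, kH, kW].
weight = torch.randn(cout, cin, n_orientations, k, k)
# Indices: [nOrientation, kH, kW, nRotation], 1-based positions into the
# flattened (nOrientation, kH, kW) block of each rotated copy.
indices = torch.randint(
    1, n_orientations * k * k + 1, (n_orientations, k, k, n_rotations),
    dtype=torch.int32)

out = active_rotated_filter(weight, indices)
# One filter bank per rotation: [Cout * nRotation, Cin * nOrientation, kH, kW].
assert out.shape == (cout * n_rotations, cin * n_orientations, k, k)
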
@@ -0,0 +1,59 @@

// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/cuda/ActiveRotatingFilter_cuda.cu
#ifndef ACTIVE_ROTATED_FILTER_CUDA_KERNEL_CUH
#define ACTIVE_ROTATED_FILTER_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename scalar_t>
__global__ void active_rotated_filter_forward_cuda_kernel(
    const int nthreads, const scalar_t* weight_data, const int* indices_data,
    const int num_input_planes, const int num_output_planes,
    const int num_orientations, const int num_rotations, const int nEntry,
    scalar_t* output_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int l = index % nEntry;
    int j = (index / nEntry) % num_input_planes;
    int i = index / nEntry / num_input_planes;
    int k;
    scalar_t val = *(weight_data + index);
    for (k = 0; k < num_rotations; k++) {
      int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
      scalar_t* target = output_data +
                         i * (num_rotations * num_input_planes * nEntry) +
                         k * (num_input_planes * nEntry) + j * (nEntry) + idx;
      *target = val;
    }
  }
}

template <typename scalar_t>
__global__ void active_rotated_filter_backward_cuda_kernel(
    const int nthreads, const scalar_t* gradWeight_data,
    const int* indices_data, const int num_input_planes,
    const int num_output_planes, const int num_orientations,
    const int num_rotations, const int nEntry, scalar_t* weight_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int l = index % nEntry;
    int j = (index / nEntry) % num_input_planes;
    int i = index / nEntry / num_input_planes;
    int k;
    scalar_t* val = weight_data + index;
    *val = 0;
    scalar_t tmp = 0;
    for (k = 0; k < num_rotations; k++) {
      int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
      scalar_t target =
          *(gradWeight_data + i * (num_rotations * num_input_planes * nEntry) +
            k * (num_input_planes * nEntry) + j * (nEntry) + idx);
      tmp = tmp + target;
    }
    *val = tmp;
  }
}
#endif  // ACTIVE_ROTATED_FILTER_CUDA_KERNEL_CUH

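As a reading aid (not part of the patch), the forward kernel's pointer arithmetic amounts to the scatter restated below as a plain NumPy sketch; the names are illustrative:

import numpy as np


def active_rotated_filter_forward_ref(weight, indices):
    # weight:  [Cout, Cin, nOrientation, kH, kW]
    # indices: [nOrientation, kH, kW, nRotation], 1-based entries
    # returns: [Cout * nRotation, Cin * nOrientation, kH, kW]
    cout, cin, n_ori, kh, kw = weight.shape
    n_rot = indices.shape[-1]
    n_entry = n_ori * kh * kw
    w = weight.reshape(cout, cin, n_entry)     # flatten (o, h, w) -> l
    idx = indices.reshape(n_entry, n_rot) - 1  # 0-based target positions
    out = np.zeros((cout, n_rot, cin, n_entry), dtype=weight.dtype)
    for i in range(cout):
        for j in range(cin):
            for l in range(n_entry):
                for k in range(n_rot):
                    # Entry l of filter (i, j) lands at entry idx[l, k] of
                    # rotation k, mirroring the CUDA index computation.
                    out[i, k, j, idx[l, k]] = w[i, j, l]
    return out.reshape(cout * n_rot, cin * n_ori, kh, kw)

The backward kernel is the adjoint of this scatter: each weight entry accumulates the output gradients from all nRotation positions it was copied to.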
@@ -0,0 +1,28 @@

// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/ActiveRotatingFilter.h

#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

void active_rotated_filter_forward_impl(const Tensor input,
                                        const Tensor indices, Tensor output) {
  DISPATCH_DEVICE_IMPL(active_rotated_filter_forward_impl, input, indices,
                       output);
}

void active_rotated_filter_backward_impl(const Tensor grad_out,
                                         const Tensor indices, Tensor grad_in) {
  DISPATCH_DEVICE_IMPL(active_rotated_filter_backward_impl, grad_out, indices,
                       grad_in);
}

void active_rotated_filter_forward(const Tensor input, const Tensor indices,
                                   Tensor output) {
  active_rotated_filter_forward_impl(input, indices, output);
}

void active_rotated_filter_backward(const Tensor grad_out, const Tensor indices,
                                    Tensor grad_in) {
  active_rotated_filter_backward_impl(grad_out, indices, grad_in);
}

@@ -0,0 +1,120 @@

// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/cpu/ActiveRotatingFilter_cpu.cpp
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

template <typename T>
void active_rotated_filter_forward_cpu_kernel(
    const T* weightData, const int* indicesData, const int num_output_planes,
    const int num_input_planes, const int num_orientations, const int kH,
    const int kW, const int num_rotations, T* outputData) {
  const int nEntry = num_orientations * kH * kW;
  int i, j, l;
  int k;

#pragma omp parallel for private(i, j, l, k)
  for (i = 0; i < num_output_planes; i++) {
    for (j = 0; j < num_input_planes; j++) {
      for (l = 0; l < nEntry; l++) {
        int weightIndex = i * num_input_planes * nEntry + j * nEntry + l;
        T val = *(weightData + weightIndex);
        for (k = 0; k < num_rotations; k++) {
          int index = (int)(*(indicesData + l * num_rotations + k)) - 1;
          T* target = outputData +
                      i * (num_rotations * num_input_planes * nEntry) +
                      k * (num_input_planes * nEntry) + j * (nEntry) + index;
          *target = val;
        }
      }
    }
  }
}

template <typename T>
void active_rotated_filter_backward_cpu_kernel(
    const T* gradOutputData, const int* indicesData,
    const int num_output_planes, const int num_input_planes,
    const int num_orientations, const int kH, const int kW,
    const int num_rotations, T* gradInputData) {
  const int nEntry = num_orientations * kH * kW;
  int i, j, l;
  int k;

#pragma omp parallel for private(i, j, l, k)
  for (i = 0; i < num_output_planes; i++) {
    for (j = 0; j < num_input_planes; j++) {
      for (l = 0; l < nEntry; l++) {
        int gradInputIndex = i * num_input_planes * nEntry + j * nEntry + l;
        T* val = gradInputData + gradInputIndex;
        *val = 0;
        for (k = 0; k < num_rotations; k++) {
          int index = (int)(*(indicesData + l * num_rotations + k)) - 1;
          const T* target =
              gradOutputData + i * (num_rotations * num_input_planes * nEntry) +
              k * (num_input_planes * nEntry) + j * (nEntry) + index;
          *val = *val + *target;
        }
      }
    }
  }
}

void ActiveRotatedFilterForwardCPULauncher(const Tensor input,
                                           const Tensor indices,
                                           Tensor output) {
  const int num_output_planes = input.size(0);
  const int num_input_planes = input.size(1);
  const int num_orientations = input.size(2);
  const int kH = input.size(3);
  const int kW = input.size(4);
  const int num_rotations = indices.size(3);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "active_rotated_filter_forward_cpu_kernel", [&] {
        active_rotated_filter_forward_cpu_kernel<scalar_t>(
            input.data_ptr<scalar_t>(), indices.data_ptr<int>(),
            num_output_planes, num_input_planes, num_orientations, kH, kW,
            num_rotations, output.data_ptr<scalar_t>());
      });
}

void ActiveRotatedFilterBackwardCPULauncher(const Tensor grad_out,
                                            const Tensor indices,
                                            Tensor grad_in) {
  const int num_orientations = indices.size(0);
  const int kH = indices.size(1);
  const int kW = indices.size(2);
  const int num_rotations = indices.size(3);
  const int num_output_planes = grad_out.size(0) / num_rotations;
  const int num_input_planes = grad_out.size(1) / num_orientations;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_out.scalar_type(), "active_rotated_filter_backward_cpu_kernel", [&] {
        active_rotated_filter_backward_cpu_kernel<scalar_t>(
            grad_out.data_ptr<scalar_t>(), indices.data_ptr<int>(),
            num_output_planes, num_input_planes, num_orientations, kH, kW,
            num_rotations, grad_in.data_ptr<scalar_t>());
      });
}

void active_rotated_filter_forward_cpu(const Tensor input, const Tensor indices,
                                       Tensor output) {
  ActiveRotatedFilterForwardCPULauncher(input, indices, output);
}

void active_rotated_filter_backward_cpu(const Tensor grad_out,
                                        const Tensor indices, Tensor grad_in) {
  ActiveRotatedFilterBackwardCPULauncher(grad_out, indices, grad_in);
}

void active_rotated_filter_forward_impl(const Tensor input,
                                        const Tensor indices, Tensor output);

void active_rotated_filter_backward_impl(const Tensor grad_out,
                                         const Tensor indices, Tensor grad_in);

REGISTER_DEVICE_IMPL(active_rotated_filter_forward_impl, CPU,
                     active_rotated_filter_forward_cpu);
REGISTER_DEVICE_IMPL(active_rotated_filter_backward_impl, CPU,
                     active_rotated_filter_backward_cpu);

@@ -0,0 +1,58 @@

// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/cuda/ActiveRotatingFilter_cuda.cu
#include "active_rotated_filter_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

void ActiveRotatedFilterForwardCUDAKernelLauncher(const Tensor input,
                                                  const Tensor indices,
                                                  Tensor output) {
  int num_output_planes = input.size(0);
  int num_input_planes = input.size(1);
  int num_orientations = input.size(2);
  int kH = input.size(3);
  int kW = input.size(4);
  int num_rotations = indices.size(3);
  int nEntry = num_orientations * kH * kW;
  int output_size = output.numel();

  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "active_rotated_filter_forward_cuda_kernel", [&] {
        active_rotated_filter_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, input.data_ptr<scalar_t>(),
                indices.data_ptr<int>(), num_input_planes, num_output_planes,
                num_orientations, num_rotations, nEntry,
                output.data_ptr<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
}

void ActiveRotatedFilterBackwardCUDAKernelLauncher(const Tensor grad_out,
                                                   const Tensor indices,
                                                   Tensor grad_in) {
  int num_orientations = indices.size(0);
  int kH = indices.size(1);
  int kW = indices.size(2);
  int num_rotations = indices.size(3);
  int num_output_planes = grad_out.size(0) / num_rotations;
  int num_input_planes = grad_out.size(1) / num_orientations;
  int nEntry = num_orientations * kH * kW;
  int output_size = grad_in.numel();

  at::cuda::CUDAGuard device_guard(indices.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_out.scalar_type(), "active_rotated_filter_backward_cuda_kernel",
      [&] {
        active_rotated_filter_backward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, grad_out.data_ptr<scalar_t>(),
                indices.data_ptr<int>(), num_input_planes, num_output_planes,
                num_orientations, num_rotations, nEntry,
                grad_in.data_ptr<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
}

@@ -1508,4 +1508,34 @@ void min_area_polygons_cuda(const Tensor pointsets, Tensor polygons) {
}

void min_area_polygons_impl(const Tensor pointsets, Tensor polygons);

REGISTER_DEVICE_IMPL(min_area_polygons_impl, CUDA, min_area_polygons_cuda);

void ActiveRotatedFilterForwardCUDAKernelLauncher(const Tensor input,
                                                  const Tensor indices,
                                                  Tensor output);

void ActiveRotatedFilterBackwardCUDAKernelLauncher(const Tensor grad_out,
                                                   const Tensor indices,
                                                   Tensor grad_in);

void active_rotated_filter_forward_cuda(const Tensor input,
                                        const Tensor indices, Tensor output) {
  ActiveRotatedFilterForwardCUDAKernelLauncher(input, indices, output);
};

void active_rotated_filter_backward_cuda(const Tensor grad_out,
                                         const Tensor indices, Tensor grad_in) {
  ActiveRotatedFilterBackwardCUDAKernelLauncher(grad_out, indices, grad_in);
};

void active_rotated_filter_forward_impl(const Tensor input,
                                        const Tensor indices, Tensor output);

void active_rotated_filter_backward_impl(const Tensor grad_out,
                                         const Tensor indices, Tensor grad_in);

REGISTER_DEVICE_IMPL(active_rotated_filter_forward_impl, CUDA,
                     active_rotated_filter_forward_cuda);
REGISTER_DEVICE_IMPL(active_rotated_filter_backward_impl, CUDA,
                     active_rotated_filter_backward_cuda);

@@ -365,6 +365,12 @@ void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);

void min_area_polygons(const Tensor pointsets, Tensor polygons);

void active_rotated_filter_forward(const Tensor input, const Tensor indices,
                                   Tensor output);

void active_rotated_filter_backward(const Tensor grad_out, const Tensor indices,
                                    Tensor grad_in);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
        py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),

@@ -735,4 +741,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("output"));
  m.def("min_area_polygons", &min_area_polygons, "min_area_polygons",
        py::arg("pointsets"), py::arg("polygons"));
  m.def("active_rotated_filter_forward", &active_rotated_filter_forward,
        "active_rotated_filter_forward", py::arg("input"), py::arg("indices"),
        py::arg("output"));
  m.def("active_rotated_filter_backward", &active_rotated_filter_backward,
        "active_rotated_filter_backward", py::arg("grad_out"),
        py::arg("indices"), py::arg("grad_in"));
}

@@ -0,0 +1,257 @@

import numpy as np
import pytest
import torch

from mmcv.ops import active_rotated_filter

np_feature = np.array(
    [[[[[-1.4934e-01, 1.1341e+00, -1.6241e-01],
        [-1.0986e+00, -1.1463e+00, -1.3176e+00],
        [1.4808e+00, 7.6572e-01, -1.4548e+00]]]],
     [[[[1.9370e+00, 6.2799e-01, 2.5834e-02],
        [-1.4242e+00, 7.6566e-01, 1.0015e+00],
        [9.8669e-01, 4.1356e-01, 6.1068e-01]]]],
     [[[[1.4565e+00, 1.4960e+00, 2.4339e-01],
        [-2.2484e-01, 7.5942e-01, -8.1184e-01],
        [-1.7077e+00, 1.0658e+00, 3.8311e-01]]]],
     [[[[8.4734e-01, 1.0904e+00, 2.4356e+00],
        [9.5822e-01, 2.2260e-01, -2.4450e-01],
        [-1.5078e+00, 7.0902e-02, -1.5921e+00]]]],
     [[[[2.1173e+00, -7.3524e-01, 1.8888e+00],
        [1.0169e+00, 4.7033e-01, -1.0875e+00],
        [-1.0736e+00, -5.2245e-01, -2.8733e-01]]]],
     [[[[-5.6433e-01, 1.5835e+00, -1.5826e+00],
        [-8.8974e-01, -4.3128e-01, -2.2423e-01],
        [1.6552e-03, -1.7292e+00, 2.6639e-01]]]],
     [[[[-1.2951e-01, 1.3493e+00, -1.9329e+00],
        [5.6248e-01, -5.1189e-01, 1.3614e+00],
        [3.3680e-01, -8.7148e-01, 5.0592e-01]]]],
     [[[[1.6781e-02, -8.3929e-01, 1.2060e+00],
        [-1.0764e+00, 4.7821e-01, 1.5342e+00],
        [-4.4542e-01, -1.8606e+00, 3.0827e-01]]]]])

np_indices = np.array(
    [[[[1, 2, 3, 6, 9, 8, 7, 4], [2, 3, 6, 9, 8, 7, 4, 1],
       [3, 6, 9, 8, 7, 4, 1, 2]],
      [[4, 1, 2, 3, 6, 9, 8, 7], [5, 5, 5, 5, 5, 5, 5, 5],
       [6, 9, 8, 7, 4, 1, 2, 3]],
      [[7, 4, 1, 2, 3, 6, 9, 8], [8, 7, 4, 1, 2, 3, 6, 9],
       [9, 8, 7, 4, 1, 2, 3, 6]]]])

expected_output = np.array(
    [[[[-1.4934e-01, 1.1341e+00, -1.6241e-01],
       [-1.0986e+00, -1.1463e+00, -1.3176e+00],
       [1.4808e+00, 7.6572e-01, -1.4548e+00]]],
     [[[-1.0986e+00, -1.4934e-01, 1.1341e+00],
       [1.4808e+00, -1.1463e+00, -1.6241e-01],
       [7.6572e-01, -1.4548e+00, -1.3176e+00]]],
     [[[1.4808e+00, -1.0986e+00, -1.4934e-01],
       [7.6572e-01, -1.1463e+00, 1.1341e+00],
       [-1.4548e+00, -1.3176e+00, -1.6241e-01]]],
     [[[7.6572e-01, 1.4808e+00, -1.0986e+00],
       [-1.4548e+00, -1.1463e+00, -1.4934e-01],
       [-1.3176e+00, -1.6241e-01, 1.1341e+00]]],
     [[[-1.4548e+00, 7.6572e-01, 1.4808e+00],
       [-1.3176e+00, -1.1463e+00, -1.0986e+00],
       [-1.6241e-01, 1.1341e+00, -1.4934e-01]]],
     [[[-1.3176e+00, -1.4548e+00, 7.6572e-01],
       [-1.6241e-01, -1.1463e+00, 1.4808e+00],
       [1.1341e+00, -1.4934e-01, -1.0986e+00]]],
     [[[-1.6241e-01, -1.3176e+00, -1.4548e+00],
       [1.1341e+00, -1.1463e+00, 7.6572e-01],
       [-1.4934e-01, -1.0986e+00, 1.4808e+00]]],
     [[[1.1341e+00, -1.6241e-01, -1.3176e+00],
       [-1.4934e-01, -1.1463e+00, -1.4548e+00],
       [-1.0986e+00, 1.4808e+00, 7.6572e-01]]],
     [[[1.9370e+00, 6.2799e-01, 2.5834e-02],
       [-1.4242e+00, 7.6566e-01, 1.0015e+00],
       [9.8669e-01, 4.1356e-01, 6.1068e-01]]],
     [[[-1.4242e+00, 1.9370e+00, 6.2799e-01],
       [9.8669e-01, 7.6566e-01, 2.5834e-02],
       [4.1356e-01, 6.1068e-01, 1.0015e+00]]],
     [[[9.8669e-01, -1.4242e+00, 1.9370e+00],
       [4.1356e-01, 7.6566e-01, 6.2799e-01],
       [6.1068e-01, 1.0015e+00, 2.5834e-02]]],
     [[[4.1356e-01, 9.8669e-01, -1.4242e+00],
       [6.1068e-01, 7.6566e-01, 1.9370e+00],
       [1.0015e+00, 2.5834e-02, 6.2799e-01]]],
     [[[6.1068e-01, 4.1356e-01, 9.8669e-01],
       [1.0015e+00, 7.6566e-01, -1.4242e+00],
       [2.5834e-02, 6.2799e-01, 1.9370e+00]]],
     [[[1.0015e+00, 6.1068e-01, 4.1356e-01],
       [2.5834e-02, 7.6566e-01, 9.8669e-01],
       [6.2799e-01, 1.9370e+00, -1.4242e+00]]],
     [[[2.5834e-02, 1.0015e+00, 6.1068e-01],
       [6.2799e-01, 7.6566e-01, 4.1356e-01],
       [1.9370e+00, -1.4242e+00, 9.8669e-01]]],
     [[[6.2799e-01, 2.5834e-02, 1.0015e+00],
       [1.9370e+00, 7.6566e-01, 6.1068e-01],
       [-1.4242e+00, 9.8669e-01, 4.1356e-01]]],
     [[[1.4565e+00, 1.4960e+00, 2.4339e-01],
       [-2.2484e-01, 7.5942e-01, -8.1184e-01],
       [-1.7077e+00, 1.0658e+00, 3.8311e-01]]],
     [[[-2.2484e-01, 1.4565e+00, 1.4960e+00],
       [-1.7077e+00, 7.5942e-01, 2.4339e-01],
       [1.0658e+00, 3.8311e-01, -8.1184e-01]]],
     [[[-1.7077e+00, -2.2484e-01, 1.4565e+00],
       [1.0658e+00, 7.5942e-01, 1.4960e+00],
       [3.8311e-01, -8.1184e-01, 2.4339e-01]]],
     [[[1.0658e+00, -1.7077e+00, -2.2484e-01],
       [3.8311e-01, 7.5942e-01, 1.4565e+00],
       [-8.1184e-01, 2.4339e-01, 1.4960e+00]]],
     [[[3.8311e-01, 1.0658e+00, -1.7077e+00],
       [-8.1184e-01, 7.5942e-01, -2.2484e-01],
       [2.4339e-01, 1.4960e+00, 1.4565e+00]]],
     [[[-8.1184e-01, 3.8311e-01, 1.0658e+00],
       [2.4339e-01, 7.5942e-01, -1.7077e+00],
       [1.4960e+00, 1.4565e+00, -2.2484e-01]]],
     [[[2.4339e-01, -8.1184e-01, 3.8311e-01],
       [1.4960e+00, 7.5942e-01, 1.0658e+00],
       [1.4565e+00, -2.2484e-01, -1.7077e+00]]],
     [[[1.4960e+00, 2.4339e-01, -8.1184e-01],
       [1.4565e+00, 7.5942e-01, 3.8311e-01],
       [-2.2484e-01, -1.7077e+00, 1.0658e+00]]],
     [[[8.4734e-01, 1.0904e+00, 2.4356e+00],
       [9.5822e-01, 2.2260e-01, -2.4450e-01],
       [-1.5078e+00, 7.0902e-02, -1.5921e+00]]],
     [[[9.5822e-01, 8.4734e-01, 1.0904e+00],
       [-1.5078e+00, 2.2260e-01, 2.4356e+00],
       [7.0902e-02, -1.5921e+00, -2.4450e-01]]],
     [[[-1.5078e+00, 9.5822e-01, 8.4734e-01],
       [7.0902e-02, 2.2260e-01, 1.0904e+00],
       [-1.5921e+00, -2.4450e-01, 2.4356e+00]]],
     [[[7.0902e-02, -1.5078e+00, 9.5822e-01],
       [-1.5921e+00, 2.2260e-01, 8.4734e-01],
       [-2.4450e-01, 2.4356e+00, 1.0904e+00]]],
     [[[-1.5921e+00, 7.0902e-02, -1.5078e+00],
       [-2.4450e-01, 2.2260e-01, 9.5822e-01],
       [2.4356e+00, 1.0904e+00, 8.4734e-01]]],
     [[[-2.4450e-01, -1.5921e+00, 7.0902e-02],
       [2.4356e+00, 2.2260e-01, -1.5078e+00],
       [1.0904e+00, 8.4734e-01, 9.5822e-01]]],
     [[[2.4356e+00, -2.4450e-01, -1.5921e+00],
       [1.0904e+00, 2.2260e-01, 7.0902e-02],
       [8.4734e-01, 9.5822e-01, -1.5078e+00]]],
     [[[1.0904e+00, 2.4356e+00, -2.4450e-01],
       [8.4734e-01, 2.2260e-01, -1.5921e+00],
       [9.5822e-01, -1.5078e+00, 7.0902e-02]]],
     [[[2.1173e+00, -7.3524e-01, 1.8888e+00],
       [1.0169e+00, 4.7033e-01, -1.0875e+00],
       [-1.0736e+00, -5.2245e-01, -2.8733e-01]]],
     [[[1.0169e+00, 2.1173e+00, -7.3524e-01],
       [-1.0736e+00, 4.7033e-01, 1.8888e+00],
       [-5.2245e-01, -2.8733e-01, -1.0875e+00]]],
     [[[-1.0736e+00, 1.0169e+00, 2.1173e+00],
       [-5.2245e-01, 4.7033e-01, -7.3524e-01],
       [-2.8733e-01, -1.0875e+00, 1.8888e+00]]],
     [[[-5.2245e-01, -1.0736e+00, 1.0169e+00],
       [-2.8733e-01, 4.7033e-01, 2.1173e+00],
       [-1.0875e+00, 1.8888e+00, -7.3524e-01]]],
     [[[-2.8733e-01, -5.2245e-01, -1.0736e+00],
       [-1.0875e+00, 4.7033e-01, 1.0169e+00],
       [1.8888e+00, -7.3524e-01, 2.1173e+00]]],
     [[[-1.0875e+00, -2.8733e-01, -5.2245e-01],
       [1.8888e+00, 4.7033e-01, -1.0736e+00],
       [-7.3524e-01, 2.1173e+00, 1.0169e+00]]],
     [[[1.8888e+00, -1.0875e+00, -2.8733e-01],
       [-7.3524e-01, 4.7033e-01, -5.2245e-01],
       [2.1173e+00, 1.0169e+00, -1.0736e+00]]],
     [[[-7.3524e-01, 1.8888e+00, -1.0875e+00],
       [2.1173e+00, 4.7033e-01, -2.8733e-01],
       [1.0169e+00, -1.0736e+00, -5.2245e-01]]],
     [[[-5.6433e-01, 1.5835e+00, -1.5826e+00],
       [-8.8974e-01, -4.3128e-01, -2.2423e-01],
       [1.6552e-03, -1.7292e+00, 2.6639e-01]]],
     [[[-8.8974e-01, -5.6433e-01, 1.5835e+00],
       [1.6552e-03, -4.3128e-01, -1.5826e+00],
       [-1.7292e+00, 2.6639e-01, -2.2423e-01]]],
     [[[1.6552e-03, -8.8974e-01, -5.6433e-01],
       [-1.7292e+00, -4.3128e-01, 1.5835e+00],
       [2.6639e-01, -2.2423e-01, -1.5826e+00]]],
     [[[-1.7292e+00, 1.6552e-03, -8.8974e-01],
       [2.6639e-01, -4.3128e-01, -5.6433e-01],
       [-2.2423e-01, -1.5826e+00, 1.5835e+00]]],
     [[[2.6639e-01, -1.7292e+00, 1.6552e-03],
       [-2.2423e-01, -4.3128e-01, -8.8974e-01],
       [-1.5826e+00, 1.5835e+00, -5.6433e-01]]],
     [[[-2.2423e-01, 2.6639e-01, -1.7292e+00],
       [-1.5826e+00, -4.3128e-01, 1.6552e-03],
       [1.5835e+00, -5.6433e-01, -8.8974e-01]]],
     [[[-1.5826e+00, -2.2423e-01, 2.6639e-01],
       [1.5835e+00, -4.3128e-01, -1.7292e+00],
       [-5.6433e-01, -8.8974e-01, 1.6552e-03]]],
     [[[1.5835e+00, -1.5826e+00, -2.2423e-01],
       [-5.6433e-01, -4.3128e-01, 2.6639e-01],
       [-8.8974e-01, 1.6552e-03, -1.7292e+00]]],
     [[[-1.2951e-01, 1.3493e+00, -1.9329e+00],
       [5.6248e-01, -5.1189e-01, 1.3614e+00],
       [3.3680e-01, -8.7148e-01, 5.0592e-01]]],
     [[[5.6248e-01, -1.2951e-01, 1.3493e+00],
       [3.3680e-01, -5.1189e-01, -1.9329e+00],
       [-8.7148e-01, 5.0592e-01, 1.3614e+00]]],
     [[[3.3680e-01, 5.6248e-01, -1.2951e-01],
       [-8.7148e-01, -5.1189e-01, 1.3493e+00],
       [5.0592e-01, 1.3614e+00, -1.9329e+00]]],
     [[[-8.7148e-01, 3.3680e-01, 5.6248e-01],
       [5.0592e-01, -5.1189e-01, -1.2951e-01],
       [1.3614e+00, -1.9329e+00, 1.3493e+00]]],
     [[[5.0592e-01, -8.7148e-01, 3.3680e-01],
       [1.3614e+00, -5.1189e-01, 5.6248e-01],
       [-1.9329e+00, 1.3493e+00, -1.2951e-01]]],
     [[[1.3614e+00, 5.0592e-01, -8.7148e-01],
       [-1.9329e+00, -5.1189e-01, 3.3680e-01],
       [1.3493e+00, -1.2951e-01, 5.6248e-01]]],
     [[[-1.9329e+00, 1.3614e+00, 5.0592e-01],
       [1.3493e+00, -5.1189e-01, -8.7148e-01],
       [-1.2951e-01, 5.6248e-01, 3.3680e-01]]],
     [[[1.3493e+00, -1.9329e+00, 1.3614e+00],
       [-1.2951e-01, -5.1189e-01, 5.0592e-01],
       [5.6248e-01, 3.3680e-01, -8.7148e-01]]],
     [[[1.6781e-02, -8.3929e-01, 1.2060e+00],
       [-1.0764e+00, 4.7821e-01, 1.5342e+00],
       [-4.4542e-01, -1.8606e+00, 3.0827e-01]]],
     [[[-1.0764e+00, 1.6781e-02, -8.3929e-01],
       [-4.4542e-01, 4.7821e-01, 1.2060e+00],
       [-1.8606e+00, 3.0827e-01, 1.5342e+00]]],
     [[[-4.4542e-01, -1.0764e+00, 1.6781e-02],
       [-1.8606e+00, 4.7821e-01, -8.3929e-01],
       [3.0827e-01, 1.5342e+00, 1.2060e+00]]],
     [[[-1.8606e+00, -4.4542e-01, -1.0764e+00],
       [3.0827e-01, 4.7821e-01, 1.6781e-02],
       [1.5342e+00, 1.2060e+00, -8.3929e-01]]],
     [[[3.0827e-01, -1.8606e+00, -4.4542e-01],
       [1.5342e+00, 4.7821e-01, -1.0764e+00],
       [1.2060e+00, -8.3929e-01, 1.6781e-02]]],
     [[[1.5342e+00, 3.0827e-01, -1.8606e+00],
       [1.2060e+00, 4.7821e-01, -4.4542e-01],
       [-8.3929e-01, 1.6781e-02, -1.0764e+00]]],
     [[[1.2060e+00, 1.5342e+00, 3.0827e-01],
       [-8.3929e-01, 4.7821e-01, -1.8606e+00],
       [1.6781e-02, -1.0764e+00, -4.4542e-01]]],
     [[[-8.3929e-01, 1.2060e+00, 1.5342e+00],
       [1.6781e-02, 4.7821e-01, 3.0827e-01],
       [-1.0764e+00, -4.4542e-01, -1.8606e+00]]]])

expected_grad = np.array(
    [[[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
     [[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
     [[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
     [[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
     [[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
     [[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
     [[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]],
     [[[[8., 8., 8.], [8., 8., 8.], [8., 8., 8.]]]]])


@pytest.mark.parametrize('device', [
    'cpu',
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason='requires CUDA support')),
])
def test_active_rotated_filter(device):
    feature = torch.tensor(
        np_feature, dtype=torch.float, device=device, requires_grad=True)
    indices = torch.tensor(np_indices, dtype=torch.int, device=device)
    output = active_rotated_filter(feature, indices)
    output.backward(torch.ones_like(output))
    assert np.allclose(output.data.cpu().numpy(), expected_output, atol=1e-3)
    assert np.allclose(
        feature.grad.data.cpu().numpy(), expected_grad, atol=1e-3)
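
The all-eights expected_grad follows directly from the op being a pure copy/scatter: with eight rotations, every weight entry is duplicated into eight output filters, so a ones-like upstream gradient accumulates to 8 at each entry. As a side note (not part of the patch, and assuming an mmcv build with the new op compiled), the analytic backward can also be checked numerically in double precision using the fixtures above; the variable names here are illustrative:

from torch.autograd import gradcheck

feature64 = torch.tensor(np_feature, dtype=torch.double, requires_grad=True)
indices32 = torch.tensor(np_indices, dtype=torch.int)
# The op is a linear scatter, so the numerical and analytic Jacobians agree.
assert gradcheck(
    active_rotated_filter, (feature64, indices32), eps=1e-6, atol=1e-4)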