Supports cuda version BorderAlign module (#1021)

* support cuda version `BorderAlign` module

* fix symbolic

* merge

* fix cpp lint error

* add unit test

* add comments for code references

* reformat doc

* fix lint error

* fix conflict
pull/1050/head
v-qjqs 2021-05-25 15:23:11 +08:00 committed by GitHub
parent 732ff5093e
commit 6c7d6c32ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 549 additions and 1 deletions

View File

@ -1,4 +1,5 @@
from .bbox import bbox_overlaps
from .border_align import BorderAlign, border_align
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
@ -48,5 +49,6 @@ __all__ = [
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand'
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'BorderAlign', 'border_align'
]

View File

@ -0,0 +1,108 @@
# modified from
# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ..utils import ext_loader
# Handle returned by `ext_loader.load_ext` for the `_ext` extension. The two
# names listed are the ops this module calls on it: the forward and backward
# passes of border_align.
ext_module = ext_loader.load_ext(
    '_ext', ['border_align_forward', 'border_align_backward'])
class BorderAlignFunction(Function):
    """Autograd function that wraps the ``border_align`` extension ops.

    ``forward`` allocates the result buffers and calls
    ``ext_module.border_align_forward``; ``backward`` routes the incoming
    gradient back to the input features through
    ``ext_module.border_align_backward``. No gradient is produced for
    ``boxes`` or ``pool_size``.
    """

    @staticmethod
    def symbolic(g, input, boxes, pool_size):
        """Build the ``mmcv::MMCVBorderAlign`` graph node via ``g.op``.

        ``pool_size`` is attached as the integer attribute ``pool_size``
        (the ``_i`` suffix of the keyword).
        """
        return g.op(
            'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)

    @staticmethod
    def forward(ctx, input, boxes, pool_size):
        """Pool border features for every box.

        Args:
            ctx: autograd context; receives ``pool_size``, the input shape
                and the tensors needed by ``backward``.
            input (Tensor): features whose channel count (dim 1) is
                divisible by 4.
            boxes (Tensor): boxes with shape [B, H*W, 4] in
                (x1, y1, x2, y2) format.
            pool_size (int): forwarded unchanged to the extension op.

        Returns:
            Tensor: pooled features with shape [B, C//4, H*W, 4].

        Raises:
            AssertionError: if ``boxes`` is not 3-D, its last dimension is
                not 4, or the channel count of ``input`` is not divisible
                by 4.
        """
        ctx.pool_size = pool_size
        ctx.input_shape = input.size()

        assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
        assert boxes.size(2) == 4, \
            'the last dimension of boxes must be (x1, y1, x2, y2)'
        assert input.size(1) % 4 == 0, \
            'the channel for input feature must be divisible by factor 4'

        # [B, C//4, H*W, 4]
        output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
        output = input.new_zeros(output_shape)
        # `argmax_idx` is only used for backward. Allocate it as int32
        # directly instead of creating a floating tensor of the same size
        # and converting it afterwards.
        argmax_idx = input.new_zeros(output_shape, dtype=torch.int)

        ext_module.border_align_forward(
            input, boxes, output, argmax_idx, pool_size=ctx.pool_size)

        ctx.save_for_backward(boxes, argmax_idx)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to the input features.

        Returns:
            tuple: ``(grad_input, None, None)`` where ``grad_input`` has the
            shape recorded from the forward input; ``boxes`` and
            ``pool_size`` receive no gradient.
        """
        boxes, argmax_idx = ctx.saved_tensors
        # The extension op writes into this buffer, so it starts from zero.
        grad_input = grad_output.new_zeros(ctx.input_shape)
        # complex head architecture may cause grad_output uncontiguous
        grad_output = grad_output.contiguous()
        ext_module.border_align_backward(
            grad_output,
            boxes,
            argmax_idx,
            grad_input,
            pool_size=ctx.pool_size)
        return grad_input, None, None


# Functional entry point: border_align(input, boxes, pool_size).
border_align = BorderAlignFunction.apply
class BorderAlign(nn.Module):
    r"""Module wrapper around :func:`border_align`.

    Pools features along the borders of predicted boxes, as described in
    `BorderDet: Border Feature for Dense Object Detection
    <https://arxiv.org/abs/2007.11056>`_.

    For every border line (top, left, bottom, right) of every box:

    1. `pool_size`+1 positions are sampled uniformly on the line, start and
       end points included.
    2. the feature at each position is obtained by bilinear interpolation.
    3. the pooled feature is the maximum over these `pool_size`+1 positions.

    Args:
        pool_size (int): number of positions sampled over the boxes' borders
            (e.g. top, bottom, left, right).
    """

    def __init__(self, pool_size):
        super().__init__()
        self.pool_size = pool_size

    def forward(self, input, boxes):
        """Pool border features of ``input`` for the given ``boxes``.

        Args:
            input: Features with shape [N,4C,H,W]. Channel ranges [0,C),
                [C,2C), [2C,3C), [3C,4C) hold the top, left, bottom and
                right features respectively.
            boxes: Boxes with shape [N,H*W,4] in (x1,y1,x2,y2) format.

        Returns:
            Tensor: Pooled features with shape [N,C,H*W,4]; the last
            dimension is ordered (top,left,bottom,right).
        """
        pooled = border_align(input, boxes, self.pool_size)
        return pooled

    def __repr__(self):
        return f'{self.__class__.__name__}(pool_size={self.pool_size})'

View File

@ -0,0 +1,199 @@
// modified from
// https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/csrc/border_align/border_align_kernel.cu.
// The main differences are: (1) `argmax_idx` is used for fast computation of
// the gradient during the backward pass; (2) `wh` is computed directly from
// `boxes` rather than being passed as an argument to the forward or backward
// functions.
#ifndef BORDER_ALIGN_CUDA_KERNEL_CUH
#define BORDER_ALIGN_CUDA_KERNEL_CUH
#include <float.h>
#ifdef MMCV_WITH_TRT
#include "common_cuda_helper.hpp"
#else // MMCV_WITH_TRT
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
#endif // MMCV_USE_PARROTS
#endif // MMCV_WITH_TRT
// Which border of a box a thread handles. The kernels below compare this with
// `extreme_idx` (taken from threadIdx.y) and also use that value as the index
// along the last output dimension and as the channel-group index of the input.
enum BorderMode { Top = 0, Left = 1, Bottom = 2, Right = 3 };
/*** Forward ***/
// Forward pass of border_align.
//
// For each `index` in [0, nthreads) -- one (batch, channel, box) triple -- and
// for the border selected by threadIdx.y (expected to be in [0, 3]), walks
// `pool_size` + 1 positions along that border, interpolates the feature at
// each one and keeps the maximum together with the step at which it occurred.
//
//   input      : (N, 4C, height, width); channel groups [0,C), [C,2C),
//                [2C,3C), [3C,4C) feed the top, left, bottom, right borders.
//   boxes      : (N, box_size, 4) as (x1, y1, x2, y2).
//   output     : (N, C, box_size, 4), receives the pooled value.
//   argmax_idx : (N, C, box_size, 4), receives the winning step index.
//   channels   : C, i.e. a quarter of the input channel count.
template <typename T>
__global__ void border_align_forward_cuda_kernel(
    const int nthreads, const T* input, const T* boxes, T* output,
    int* argmax_idx, const int channels, const int box_size, const int height,
    const int width, const int pool_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Border handled by this thread.
    const int border = threadIdx.y;

    // Decompose the flat index over (N, C, box_size).
    const int n = index / channels / box_size;
    const int c = (index / box_size) % channels;
    // Row of this box inside the flattened (N * box_size, 4) boxes array.
    const int box_row = index % box_size + n * box_size;

    const T* box = boxes + box_row * 4;
    const T box_w = *(box + 2) - *box;
    const T box_h = *(box + 3) - *(box + 1);

    // Feature plane that belongs to this border and channel.
    const T* plane =
        input + (n * channels * 4 + border * channels + c) * height * width;

    // Step between consecutive samples. Top/left walk forward from (x1,y1),
    // bottom/right walk backward from (x2,y2).
    T step, step_x, step_y;
    switch (border) {
      case BorderMode::Top:
        step = box_w / pool_size;
        step_x = step;
        step_y = 0;
        break;
      case BorderMode::Left:
        step = box_h / pool_size;
        step_x = 0;
        step_y = step;
        break;
      case BorderMode::Bottom:
        step = box_w / pool_size;
        step_x = -step;
        step_y = 0;
        break;
      case BorderMode::Right:
        step = box_h / pool_size;
        step_x = 0;
        step_y = -step;
        break;
    }

    // Borders 0/1 start at (x1,y1); borders 2/3 start at (x2,y2).
    const T* start = box + border / 2 * 2;
    T px = *start;
    T py = *(start + 1);

    // Seed the running maximum with the start position, then scan the
    // remaining `pool_size` positions.
    T best_val = bilinear_interpolate(plane, height, width, py, px, index);
    int best_step = 0;
    for (int k = 1; k <= pool_size; k++) {
      px += step_x;
      py += step_y;
      const T cur = bilinear_interpolate(plane, height, width, py, px, index);
      if (cur > best_val) {
        best_val = cur;
        best_step = k;
      }
    }

    // Slot of this (index, border) pair in the (N, C, box_size, 4) buffers.
    const int slot = index * 4 + border;
    output[slot] = best_val;
    argmax_idx[slot] = best_step;
  }
}
/*** Backward ***/
// Backward pass of border_align.
//
// For each `index` in [0, nthreads) -- one (batch, channel, box) triple -- and
// for the border selected by threadIdx.y (expected to be in [0, 3]),
// recomputes the sample position that won the max in the forward pass (step
// `argmax_idx` along the border) and scatters the incoming gradient onto the
// four neighbouring pixels of `grad_input` with the interpolation weights.
//
//   grad_output : (N, C, box_size, 4)
//   boxes       : (N, box_size, 4) as (x1, y1, x2, y2)
//   argmax_idx  : (N, C, box_size, 4), step index stored by the forward kernel
//   grad_input  : (N, 4C, height, width), accumulated into with atomicAdd
//   channels    : C, i.e. a quarter of the grad_input channel count
template <typename T>
__global__ void border_align_backward_cuda_kernel(
    const int nthreads, const T* grad_output, const T* boxes,
    const int* argmax_idx, T* grad_input, const int channels,
    const int box_size, const int height, const int width,
    const int pool_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (batch_idx, c_idx, box_idx) is an element paralleled for computing
    // output, and `extreme_idx` is in range [0,3]
    int batch_idx, c_idx, box_idx, extreme_idx;
    const int* offset_argmax_idx;
    const T *offset_grad_output, *offset_box, *offset_box_x;
    T *offset_grad_input, box_width, box_height, stride, x_stride, y_stride, x,
        y;

    extreme_idx = threadIdx.y;
    batch_idx = index / channels / box_size;
    box_idx = index % box_size + batch_idx * box_size;
    c_idx = (index / box_size) % channels;

    offset_box = boxes + box_idx * 4;
    box_width = *(offset_box + 2) - *offset_box;
    box_height = *(offset_box + 3) - *(offset_box + 1);
    offset_grad_output = grad_output + index * 4 + extreme_idx;
    offset_argmax_idx = argmax_idx + index * 4 + extreme_idx;
    // [0,C) for top feature grad, [C,2C) for left feature grad,
    // [2C,3C) for bottom feature grad, [3C,4C) for right feature grad
    offset_grad_input = grad_input + (batch_idx * channels * 4 +
                                      extreme_idx * channels + c_idx) *
                                         height * width;

    // extreme_idx in [0,1] -> offset_box_x indexed at x1
    // extreme_idx in [2,3] -> offset_box_x indexed at x2
    offset_box_x = offset_box + extreme_idx / 2 * 2;

    switch (extreme_idx) {
      // top
      case BorderMode::Top:
        stride = box_width / pool_size;
        x_stride = stride;
        y_stride = 0;
        break;
      // left
      case BorderMode::Left:
        stride = box_height / pool_size;
        x_stride = 0;
        y_stride = stride;
        break;
      // bottom
      case BorderMode::Bottom:
        stride = box_width / pool_size;
        x_stride = -stride;
        y_stride = 0;
        break;
      // right
      case BorderMode::Right:
        stride = box_height / pool_size;
        x_stride = 0;
        y_stride = -stride;
        break;
    }

    // get position (x,y) which has maximum value during forward
    x = *offset_box_x;
    y = *(offset_box_x + 1);
    x += x_stride * (T)(*offset_argmax_idx);
    y += y_stride * (T)(*offset_argmax_idx);

    T w1, w2, w3, w4;
    int x_low, x_high, y_low, y_high;
    bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, x_low,
                                  x_high, y_low, y_high, index);

    // update grad_input
    // NOTE(review): bilinear_interpolate_gradient is defined in the CUDA
    // helper header, not here. This assumes it flags a sample outside the
    // feature map by returning negative indices -- confirm. Without the
    // guard such indices would make atomicAdd address memory before this
    // feature plane (before the whole buffer for the first plane). For
    // non-negative indices the guard changes nothing.
    if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
      atomicAdd(offset_grad_input + y_low * width + x_low,
                *offset_grad_output * w1);
      atomicAdd(offset_grad_input + y_low * width + x_high,
                *offset_grad_output * w2);
      atomicAdd(offset_grad_input + y_high * width + x_low,
                *offset_grad_output * w3);
      atomicAdd(offset_grad_input + y_high * width + x_high,
                *offset_grad_output * w4);
    }
  }
}
#endif // BORDER_ALIGN_CUDA_KERNEL_CUH

View File

@ -0,0 +1,67 @@
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
// Declarations of the CUDA launchers, visible only when MMCV_WITH_CUDA is
// defined. They are implemented in the accompanying CUDA source, not in this
// translation unit.
void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
                                          const Tensor &boxes, Tensor output,
                                          Tensor argmax_idx,
                                          const int pool_size);
void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
                                           const Tensor &boxes,
                                           const Tensor &argmax_idx,
                                           Tensor grad_input,
                                           const int pool_size);
// CUDA entry point for the forward pass: passes every argument through to
// BorderAlignForwardCUDAKernelLauncher unchanged.
void border_align_forward_cuda(const Tensor &input, const Tensor &boxes,
                               Tensor output, Tensor argmax_idx,
                               const int pool_size) {
  BorderAlignForwardCUDAKernelLauncher(input, boxes, output, argmax_idx,
                                       pool_size);
}
// CUDA entry point for the backward pass: passes every argument through to
// BorderAlignBackwardCUDAKernelLauncher unchanged.
void border_align_backward_cuda(const Tensor &grad_output, const Tensor &boxes,
                                const Tensor &argmax_idx, Tensor grad_input,
                                const int pool_size) {
  BorderAlignBackwardCUDAKernelLauncher(grad_output, boxes, argmax_idx,
                                        grad_input, pool_size);
}
#endif
// Device dispatcher for the border_align forward pass.
//
// Only CUDA tensors are supported: a non-CUDA `input` raises an error, and so
// does a CUDA `input` when the library was built without MMCV_WITH_CUDA.
// Otherwise all four tensors are validated with CHECK_CUDA_INPUT and the work
// is handed to border_align_forward_cuda, which fills `output` and
// `argmax_idx`.
void border_align_forward(const Tensor &input, const Tensor &boxes,
                          Tensor output, Tensor argmax_idx,
                          const int pool_size) {
  // Reject the unsupported device first; AT_ERROR throws, so nothing below
  // runs for CPU tensors.
  if (!input.device().is_cuda()) {
    AT_ERROR("BorderAlign is not implemented on CPU");
  }
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA_INPUT(input);
  CHECK_CUDA_INPUT(boxes);
  CHECK_CUDA_INPUT(output);
  CHECK_CUDA_INPUT(argmax_idx);

  border_align_forward_cuda(input, boxes, output, argmax_idx, pool_size);
#else
  AT_ERROR("BorderAlign is not compiled with GPU support");
#endif
}
// Device dispatcher for the border_align backward pass.
//
// Only CUDA tensors are supported: a non-CUDA `grad_output` raises an error,
// and so does a CUDA `grad_output` when the library was built without
// MMCV_WITH_CUDA. Otherwise all four tensors are validated with
// CHECK_CUDA_INPUT and the work is handed to border_align_backward_cuda,
// which writes into `grad_input`.
void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
                           const Tensor &argmax_idx, Tensor grad_input,
                           const int pool_size) {
  // Reject the unsupported device first; AT_ERROR throws, so nothing below
  // runs for CPU tensors.
  if (!grad_output.device().is_cuda()) {
    AT_ERROR("BorderAlign is not implemented on CPU");
  }
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA_INPUT(grad_output);
  CHECK_CUDA_INPUT(boxes);
  CHECK_CUDA_INPUT(argmax_idx);
  CHECK_CUDA_INPUT(grad_input);

  border_align_backward_cuda(grad_output, boxes, argmax_idx, grad_input,
                             pool_size);
#else
  AT_ERROR("BorderAlign is not compiled with GPU support");
#endif
}

View File

@ -0,0 +1,67 @@
#include "border_align_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
// Launches border_align_forward_cuda_kernel on the current CUDA stream.
//
//   input      : 4-D features; dim 1 holds 4 * channels feature maps.
//   boxes      : 3-D boxes [N, box_size, 4] in (x1, y1, x2, y2) format; must
//                have the same scalar type as `input`.
//   output     : [N, channels, box_size, 4], written by the kernel.
//   argmax_idx : int tensor of the same shape, written by the kernel.
//   pool_size  : forwarded to the kernel.
//
// One kernel index is created per (batch, channel, box) triple; the second
// block dimension (4) selects the box border.
void BorderAlignForwardCUDAKernelLauncher(const Tensor &input,
                                          const Tensor &boxes, Tensor output,
                                          Tensor argmax_idx,
                                          const int pool_size) {
  // shape assertion
  AT_ASSERTM(input.ndimension() == 4,
             "non-empty 4D(batch mode) tensor expected for input feature");
  AT_ASSERTM(boxes.ndimension() == 3,
             "boxes must be 3D tensor with size of [B, H*W, 4]");

  int batch_size = input.size(0);
  int feat_channels = input.size(1);
  int channels = feat_channels / 4;
  int height = input.size(2);
  int width = input.size(3);
  // shape [N, box_size, 4] for boxes. (x1, y1, x2, y2) format
  int box_size = boxes.size(1);
  // shape [N, channels, box_size, 4] for output
  int nthreads = batch_size * channels * box_size;

  // Nothing to compute for an empty batch, channel group or box list.
  // Returning here also avoids requesting a launch with no work.
  // NOTE(review): GET_BLOCKS is defined in the helper header; if it yields 0
  // for nthreads == 0 the launch below would be rejected as an invalid
  // configuration -- confirm.
  if (nthreads == 0) {
    return;
  }

  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // block.y == 4: one thread row per box border (the kernel reads
  // threadIdx.y as the border index).
  dim3 block(128, 4);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "border_align_forward_cuda_kernel", [&] {
        border_align_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(nthreads), block, 0, stream>>>(
                nthreads, input.data_ptr<scalar_t>(),
                boxes.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
                argmax_idx.data_ptr<int>(), channels, box_size, height, width,
                pool_size);
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
// Launches border_align_backward_cuda_kernel on the current CUDA stream.
//
//   grad_output : gradient w.r.t. the forward output,
//                 [N, channels, box_size, 4].
//   boxes       : [N, box_size, 4] in (x1, y1, x2, y2) format; must have the
//                 same scalar type as `grad_output`.
//   argmax_idx  : int tensor stored by the forward pass.
//   grad_input  : 4-D buffer [N, 4 * channels, height, width]; the kernel
//                 accumulates into it, so the caller must pass it
//                 zero-filled (the Python wrapper allocates it with
//                 new_zeros).
//   pool_size   : forwarded to the kernel.
void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output,
                                           const Tensor &boxes,
                                           const Tensor &argmax_idx,
                                           Tensor grad_input,
                                           const int pool_size) {
  int batch_size = grad_input.size(0);
  int feat_channels = grad_input.size(1);
  int channels = feat_channels / 4;
  int height = grad_input.size(2);
  int width = grad_input.size(3);
  int box_size = boxes.size(1);
  int nthreads = batch_size * channels * box_size;

  // No gradient to scatter for an empty batch, channel group or box list;
  // grad_input is left exactly as the caller provided it. Returning here
  // also avoids requesting a launch with no work.
  // NOTE(review): GET_BLOCKS is defined in the helper header; if it yields 0
  // for nthreads == 0 the launch below would be rejected as an invalid
  // configuration -- confirm.
  if (nthreads == 0) {
    return;
  }

  at::cuda::CUDAGuard device_guard(grad_output.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // block.y == 4: one thread row per box border (the kernel reads
  // threadIdx.y as the border index).
  dim3 block(128, 4);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "border_align_backward_cuda_kernel", [&] {
        border_align_backward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(nthreads), block, 0, stream>>>(
                nthreads, grad_output.data_ptr<scalar_t>(),
                boxes.data_ptr<scalar_t>(), argmax_idx.data_ptr<int>(),
                grad_input.data_ptr<scalar_t>(), channels, box_size, height,
                width, pool_size);
      });

  AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -222,6 +222,14 @@ void roi_align_rotated_backward(Tensor grad_output, Tensor rois,
int pooled_width, float spatial_scale,
int sample_num, bool aligned, bool clockwise);
void border_align_forward(const Tensor &input, const Tensor &boxes,
Tensor output, Tensor argmax_idx,
const int pool_size);
void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input,
const int pool_size);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@ -447,4 +455,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("attention_weights"), py::arg("grad_output"),
py::arg("grad_value"), py::arg("grad_sampling_loc"),
py::arg("grad_attn_weight"), py::arg("im2col_step"));
m.def("border_align_forward", &border_align_forward,
"forward function of border_align", py::arg("input"), py::arg("boxes"),
py::arg("output"), py::arg("argmax_idx"), py::arg("pool_size"));
m.def("border_align_backward", &border_align_backward,
"backward function of border_align", py::arg("grad_output"),
py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"),
py::arg("pool_size"));
}

View File

@ -0,0 +1,90 @@
import copy
import numpy as np
import pytest
import torch
# Input feature map with shape [1,4c,h,w]; here c=1, h=3, w=4, so the four
# channels are the top, left, bottom and right feature planes.
input_arr = [[[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]],
              [[6, 7, 5, 8], [2, 1, 3, 4], [12, 9, 11, 10]],
              [[-2, -3, 2, 0], [-4, -5, 1, -1], [-1, -1, -1, -1]],
              [[0, -1, 2, 1], [-4, -3, -2, -1], [-1, -2, -3, -4]]]]
# Boxes with shape [1,h*w,4]: twelve (x1, y1, x2, y2) boxes.
boxes_arr = [[[0, 0, 2, 1], [1, 0, 3, 1], [1, 0, 2, 1], [0, 0, 3, 1],
              [0, 0, 1, 2], [0, 0, 2, 2], [1, 0, 2, 1], [1, 0, 3, 1],
              [0, 1, 1, 2], [0, 0, 3, 2], [1, 0, 3, 2], [2, 0, 3, 2]]]
# Expected forward outputs, keyed by pool_size.
output_dict = {
    # [1,c,h*w,4] for each value,
    # the output was manually checked for its correctness
    # pool_size=1
    1: [[[[3., 6., 1., 2.], [4., 7., -1., 1.], [3., 7., 1., 2.],
          [4., 6., -1., 1.], [2., 12., -1., -1.], [3., 12., -1., 2.],
          [3., 7., 1., 2.], [4., 7., -1., 1.], [6., 12., -1., -2.],
          [4., 12., -1., 1.], [4., 9., -1., 1.], [4., 11., -1., 1.]]]],
    # pool_size=2
    2: [[[[3., 6., 1., 2.], [4., 7., 1., 1.], [3., 7., 1., 2.],
          [4., 6., -1., 1.], [2., 12., -1., -1.], [3., 12., -1., 2.],
          [3., 7., 1., 2.], [4., 7., 1., 1.], [6., 12., -1., -2.],
          [4., 12., -1., 1.], [4., 9., -1., 1.], [4., 11., -1., 1.]]]],
}
# Expected gradients w.r.t. the input when the output gradient is all ones,
# keyed by pool_size.
input_grad_dict = {
    # [1,4c,h,w] for each value
    # the grad was manually checked for its correctness
    # pool_size=1
    1: [[[[0., 1., 4., 6.], [0., 1., 0., 0.], [0., 0., 0., 0.]],
         [[2., 4., 0., 0.], [0., 0., 0., 0.], [4., 1., 1., 0.]],
         [[0., 0., 0., 0.], [0., 0., 3., 3.], [0., 2., 1., 3.]],
         [[0., 1., 4., 6.], [0., 0., 0., 0.], [0., 1., 0., 0.]]]],
    # pool_size=2
    2: [[[[0., 1., 4., 6.], [0., 1., 0., 0.], [0., 0., 0., 0.]],
         [[2., 4., 0., 0.], [0., 0., 0., 0.], [4., 1., 1., 0.]],
         [[0., 0., 0., 0.], [0., 0., 5., 1.], [0., 2., 1., 3.]],
         [[0., 1., 4., 6.], [0., 0., 0., 0.], [0., 1., 0., 0.]]]],
}
def _test_border_align_allclose(device, dtype, pool_size):
    """Compare border_align / BorderAlign against the stored references.

    Runs the functional op and the module on the fixture input and boxes,
    back-propagates an all-ones gradient, and checks both the pooled output
    and the input gradient against ``output_dict[pool_size]`` and
    ``input_grad_dict[pool_size]``. The test is skipped when CUDA is
    requested but unavailable, or when the op cannot be imported.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('test requires GPU')
    try:
        from mmcv.ops import border_align, BorderAlign
    except ModuleNotFoundError:
        pytest.skip('BorderAlign op is not successfully compiled')

    feats_np = np.array(input_arr)
    boxes_np = np.array(boxes_arr)
    expected_output = np.array(output_dict[pool_size])
    expected_grad = np.array(input_grad_dict[pool_size])

    feats = torch.tensor(
        feats_np, dtype=dtype, device=device, requires_grad=True)
    boxes = torch.tensor(boxes_np, dtype=dtype, device=device)

    def _to_numpy(tensor):
        # Detached copy on the host, cast to the dtype under test.
        return tensor.data.type(dtype).cpu().numpy()

    # Functional interface, run on an independent copy so that its gradient
    # does not accumulate into the tensor used for the module check below.
    feats_copy = copy.deepcopy(feats)
    pooled = border_align(feats_copy, boxes, pool_size)
    pooled.backward(torch.ones_like(pooled))
    assert np.allclose(_to_numpy(pooled), expected_output, atol=1e-5)
    assert np.allclose(_to_numpy(feats_copy.grad), expected_grad, atol=1e-5)

    # Module interface.
    pool_module = BorderAlign(pool_size)
    pooled = pool_module(feats, boxes)
    pooled.backward(torch.ones_like(pooled))
    assert np.allclose(_to_numpy(pooled), expected_output, atol=1e-5)
    assert np.allclose(_to_numpy(feats.grad), expected_grad, atol=1e-5)
@pytest.mark.parametrize('device', ['cuda'])
@pytest.mark.parametrize('dtype', [torch.float, torch.half, torch.double])
@pytest.mark.parametrize('pool_size', [1, 2])
def test_border_align(device, dtype, pool_size):
    """Run the BorderAlign output/gradient check for each parametrized
    combination of device, dtype and pool_size."""
    _test_border_align_allclose(device, dtype, pool_size)