mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support PrRoI op for Parrots (#2280)
* Support parrots extension for op PrRoI * Fix lint * Fix cpp lint * Fix testcase failure by false requires_grad in self-defined autograd Funtion * Fix issues * Fix flake8 * Fix isort * Adaption for typechecking for PrRoIPoolFunction * Fix lint * Support only float32 * bugfixpull/2205/head^2
parent
c001e2fcba
commit
dd5b415d2a
|
@ -1624,3 +1624,54 @@ REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
|
|||
chamfer_distance_forward_cuda);
|
||||
REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA,
|
||||
chamfer_distance_backward_cuda);
|
||||
|
||||
void PrROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois,
|
||||
Tensor output, int pooled_height,
|
||||
int pooled_width, float spatial_scale);
|
||||
|
||||
void PrROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois,
|
||||
Tensor grad_input, int pooled_height,
|
||||
int pooled_width, float spatial_scale);
|
||||
|
||||
void PrROIPoolCoorBackwardCUDAKernelLauncher(
|
||||
Tensor output, Tensor grad_output, Tensor input, Tensor rois,
|
||||
Tensor grad_rois, int pooled_height, int pooled_width, float spatial_scale);
|
||||
|
||||
void prroi_pool_forward_cuda(Tensor input, Tensor rois, Tensor output,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale) {
|
||||
PrROIPoolForwardCUDAKernelLauncher(input, rois, output, pooled_height,
|
||||
pooled_width, spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_backward_cuda(Tensor grad_output, Tensor rois,
|
||||
Tensor grad_input, int pooled_height,
|
||||
int pooled_width, float spatial_scale) {
|
||||
PrROIPoolBackwardCUDAKernelLauncher(grad_output, rois, grad_input,
|
||||
pooled_height, pooled_width,
|
||||
spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_coor_backward_cuda(Tensor output, Tensor grad_output,
|
||||
Tensor input, Tensor rois, Tensor grad_rois,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale) {
|
||||
PrROIPoolCoorBackwardCUDAKernelLauncher(output, grad_output, input, rois,
|
||||
grad_rois, pooled_height,
|
||||
pooled_width, spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale);
|
||||
void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
|
||||
Tensor grad_input, int pooled_height,
|
||||
int pooled_width, float spatial_scale);
|
||||
void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
|
||||
Tensor input, Tensor rois, Tensor grad_rois,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale);
|
||||
REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, CUDA, prroi_pool_forward_cuda);
|
||||
REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, CUDA, prroi_pool_backward_cuda);
|
||||
REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, CUDA,
|
||||
prroi_pool_coor_backward_cuda);
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
#include "pytorch_device_registry.hpp"
|
||||
|
||||
void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale) {
|
||||
DISPATCH_DEVICE_IMPL(prroi_pool_forward_impl, input, rois, output,
|
||||
pooled_height, pooled_width, spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
|
||||
Tensor grad_input, int pooled_height,
|
||||
int pooled_width, float spatial_scale) {
|
||||
DISPATCH_DEVICE_IMPL(prroi_pool_backward_impl, grad_output, rois, grad_input,
|
||||
pooled_height, pooled_width, spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
|
||||
Tensor input, Tensor rois, Tensor grad_rois,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale) {
|
||||
DISPATCH_DEVICE_IMPL(prroi_pool_coor_backward_impl, output, grad_output,
|
||||
input, rois, grad_rois, pooled_height, pooled_width,
|
||||
spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale) {
|
||||
prroi_pool_forward_impl(input, rois, output, pooled_height, pooled_width,
|
||||
spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale) {
|
||||
prroi_pool_backward_impl(grad_output, rois, grad_input, pooled_height,
|
||||
pooled_width, spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input,
|
||||
Tensor rois, Tensor grad_rois, int pooled_height,
|
||||
int pooled_width, float spatial_scale) {
|
||||
prroi_pool_coor_backward_impl(output, grad_output, input, rois, grad_rois,
|
||||
pooled_height, pooled_width, spatial_scale);
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "prroi_pool_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void prroi_pool_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
int pooled_height;
|
||||
int pooled_width;
|
||||
float spatial_scale;
|
||||
SSAttrs(attr)
|
||||
.get<int>("pooled_height", pooled_height)
|
||||
.get<int>("pooled_width", pooled_width)
|
||||
.get<float>("spatial_scale", spatial_scale)
|
||||
.done();
|
||||
|
||||
const auto& input = buildATensor(ctx, ins[0]);
|
||||
const auto& rois = buildATensor(ctx, ins[1]);
|
||||
auto output = buildATensor(ctx, outs[0]);
|
||||
prroi_pool_forward(input, rois, output, pooled_height, pooled_width,
|
||||
spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_backward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
int pooled_height;
|
||||
int pooled_width;
|
||||
float spatial_scale;
|
||||
SSAttrs(attr)
|
||||
.get<int>("pooled_height", pooled_height)
|
||||
.get<int>("pooled_width", pooled_width)
|
||||
.get<float>("spatial_scale", spatial_scale)
|
||||
.done();
|
||||
|
||||
const auto& grad_output = buildATensor(ctx, ins[0]);
|
||||
const auto& rois = buildATensor(ctx, ins[1]);
|
||||
auto grad_input = buildATensor(ctx, outs[0]);
|
||||
prroi_pool_backward(grad_output, rois, grad_input, pooled_height,
|
||||
pooled_width, spatial_scale);
|
||||
}
|
||||
|
||||
void prroi_pool_coor_backward_cuda_parrots(CudaContext& ctx,
|
||||
const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
int pooled_height;
|
||||
int pooled_width;
|
||||
float spatial_scale;
|
||||
SSAttrs(attr)
|
||||
.get<int>("pooled_height", pooled_height)
|
||||
.get<int>("pooled_width", pooled_width)
|
||||
.get<float>("spatial_scale", spatial_scale)
|
||||
.done();
|
||||
|
||||
const auto& output = buildATensor(ctx, ins[0]);
|
||||
const auto& grad_output = buildATensor(ctx, ins[1]);
|
||||
const auto& input = buildATensor(ctx, ins[2]);
|
||||
const auto& rois = buildATensor(ctx, ins[3]);
|
||||
auto grad_rois = buildATensor(ctx, outs[0]);
|
||||
prroi_pool_coor_backward(output, grad_output, input, rois, grad_rois,
|
||||
pooled_height, pooled_width, spatial_scale);
|
||||
}
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(prroi_pool_forward)
|
||||
.attr("pooled_height")
|
||||
.attr("pooled_width")
|
||||
.attr("spatial_scale")
|
||||
.input(2)
|
||||
.output(1)
|
||||
.apply(prroi_pool_forward_cuda_parrots)
|
||||
.done();
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(prroi_pool_backward)
|
||||
.attr("pooled_height")
|
||||
.attr("pooled_width")
|
||||
.attr("spatial_scale")
|
||||
.input(2)
|
||||
.output(1)
|
||||
.apply(prroi_pool_backward_cuda_parrots)
|
||||
.done();
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(prroi_pool_coor_backward)
|
||||
.attr("pooled_height")
|
||||
.attr("pooled_width")
|
||||
.attr("spatial_scale")
|
||||
.input(4)
|
||||
.output(1)
|
||||
.apply(prroi_pool_coor_backward_cuda_parrots)
|
||||
.done();
|
||||
#endif
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef PRROI_POOL_PYTORCH_H
|
||||
#define PRROI_POOL_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale);
|
||||
|
||||
void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale);
|
||||
|
||||
void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input,
|
||||
Tensor rois, Tensor grad_rois, int pooled_height,
|
||||
int pooled_width, float spatial_scale);
|
||||
|
||||
#endif // PRROI_POOL_PYTORCH_H
|
|
@ -7,7 +7,7 @@ from torch.autograd import Function
|
|||
from torch.autograd.function import once_differentiable
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from ..utils import ext_loader
|
||||
from ..utils import TORCH_VERSION, ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext(
|
||||
'_ext',
|
||||
|
@ -32,11 +32,10 @@ class PrRoIPoolFunction(Function):
|
|||
rois: torch.Tensor,
|
||||
output_size: Tuple,
|
||||
spatial_scale: float = 1.0) -> torch.Tensor:
|
||||
if 'FloatTensor' not in features.type(
|
||||
) or 'FloatTensor' not in rois.type():
|
||||
raise ValueError(
|
||||
'Precise RoI Pooling only takes float input, got '
|
||||
f'{features.type()} for features and {rois.type()} for rois.')
|
||||
if features.dtype != torch.float32 or rois.dtype != torch.float32:
|
||||
raise ValueError('Precise RoI Pooling only takes float input, got '
|
||||
f'{features.dtype()} for features and'
|
||||
f'{rois.dtype()} for rois.')
|
||||
|
||||
pooled_height = int(output_size[0])
|
||||
pooled_width = int(output_size[1])
|
||||
|
@ -49,7 +48,13 @@ class PrRoIPoolFunction(Function):
|
|||
output = features.new_zeros(output_shape)
|
||||
params = (pooled_height, pooled_width, spatial_scale)
|
||||
|
||||
ext_module.prroi_pool_forward(features, rois, output, *params)
|
||||
ext_module.prroi_pool_forward(
|
||||
features,
|
||||
rois,
|
||||
output,
|
||||
pooled_height=params[0],
|
||||
pooled_width=params[1],
|
||||
spatial_scale=params[2])
|
||||
ctx.params = params
|
||||
# everything here is contiguous.
|
||||
ctx.save_for_backward(features, rois, output)
|
||||
|
@ -65,14 +70,26 @@ class PrRoIPoolFunction(Function):
|
|||
grad_input = grad_output.new_zeros(*features.shape)
|
||||
grad_coor = grad_output.new_zeros(*rois.shape)
|
||||
|
||||
if features.requires_grad:
|
||||
if features.requires_grad or TORCH_VERSION == 'parrots':
|
||||
grad_output = grad_output.contiguous()
|
||||
ext_module.prroi_pool_backward(grad_output, rois, grad_input,
|
||||
*ctx.params)
|
||||
if rois.requires_grad:
|
||||
ext_module.prroi_pool_backward(
|
||||
grad_output,
|
||||
rois,
|
||||
grad_input,
|
||||
pooled_height=ctx.params[0],
|
||||
pooled_width=ctx.params[1],
|
||||
spatial_scale=ctx.params[2])
|
||||
if rois.requires_grad or TORCH_VERSION == 'parrots':
|
||||
grad_output = grad_output.contiguous()
|
||||
ext_module.prroi_pool_coor_backward(output, grad_output, features,
|
||||
rois, grad_coor, *ctx.params)
|
||||
ext_module.prroi_pool_coor_backward(
|
||||
output,
|
||||
grad_output,
|
||||
features,
|
||||
rois,
|
||||
grad_coor,
|
||||
pooled_height=ctx.params[0],
|
||||
pooled_width=ctx.params[1],
|
||||
spatial_scale=ctx.params[2])
|
||||
|
||||
return grad_input, grad_coor, None, None, None
|
||||
|
||||
|
|
|
@ -59,8 +59,7 @@ class TestPrRoiPool:
|
|||
froipool = PrRoIPool((pool_h, pool_w), spatial_scale)
|
||||
|
||||
if _USING_PARROTS:
|
||||
pass
|
||||
# gradcheck(froipool, (x, rois), no_grads=[rois])
|
||||
gradcheck(froipool, (x, rois), no_grads=[rois])
|
||||
else:
|
||||
gradcheck(froipool, (x, rois), eps=1e-2, atol=1e-2)
|
||||
|
||||
|
|
Loading…
Reference in New Issue