[Feature] Support PrRoI op for Parrots (#2280)

* Support parrots extension for op PrRoI

* Fix lint

* Fix cpp lint

* Fix testcase failure by false requires_grad in self-defined autograd Funtion

* Fix issues

* Fix flake8

* Fix isort

* Adaption for typechecking for PrRoIPoolFunction

* Fix lint

* Support only float32

* bugfix
pull/2205/head^2
CokeDong 2022-10-11 15:21:00 +08:00 committed by GitHub
parent c001e2fcba
commit dd5b415d2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 245 additions and 15 deletions

View File

@ -1624,3 +1624,54 @@ REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
chamfer_distance_forward_cuda);
REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA,
chamfer_distance_backward_cuda);
void PrROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois,
Tensor output, int pooled_height,
int pooled_width, float spatial_scale);
void PrROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale);
void PrROIPoolCoorBackwardCUDAKernelLauncher(
Tensor output, Tensor grad_output, Tensor input, Tensor rois,
Tensor grad_rois, int pooled_height, int pooled_width, float spatial_scale);
void prroi_pool_forward_cuda(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale) {
PrROIPoolForwardCUDAKernelLauncher(input, rois, output, pooled_height,
pooled_width, spatial_scale);
}
void prroi_pool_backward_cuda(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale) {
PrROIPoolBackwardCUDAKernelLauncher(grad_output, rois, grad_input,
pooled_height, pooled_width,
spatial_scale);
}
void prroi_pool_coor_backward_cuda(Tensor output, Tensor grad_output,
Tensor input, Tensor rois, Tensor grad_rois,
int pooled_height, int pooled_width,
float spatial_scale) {
PrROIPoolCoorBackwardCUDAKernelLauncher(output, grad_output, input, rois,
grad_rois, pooled_height,
pooled_width, spatial_scale);
}
void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale);
void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale);
void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
Tensor input, Tensor rois, Tensor grad_rois,
int pooled_height, int pooled_width,
float spatial_scale);
REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, CUDA, prroi_pool_forward_cuda);
REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, CUDA, prroi_pool_backward_cuda);
REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, CUDA,
prroi_pool_coor_backward_cuda);

View File

@ -0,0 +1,47 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale) {
DISPATCH_DEVICE_IMPL(prroi_pool_forward_impl, input, rois, output,
pooled_height, pooled_width, spatial_scale);
}
void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
Tensor grad_input, int pooled_height,
int pooled_width, float spatial_scale) {
DISPATCH_DEVICE_IMPL(prroi_pool_backward_impl, grad_output, rois, grad_input,
pooled_height, pooled_width, spatial_scale);
}
void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
Tensor input, Tensor rois, Tensor grad_rois,
int pooled_height, int pooled_width,
float spatial_scale) {
DISPATCH_DEVICE_IMPL(prroi_pool_coor_backward_impl, output, grad_output,
input, rois, grad_rois, pooled_height, pooled_width,
spatial_scale);
}
void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale) {
prroi_pool_forward_impl(input, rois, output, pooled_height, pooled_width,
spatial_scale);
}
void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
int pooled_height, int pooled_width,
float spatial_scale) {
prroi_pool_backward_impl(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale);
}
void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input,
Tensor rois, Tensor grad_rois, int pooled_height,
int pooled_width, float spatial_scale) {
prroi_pool_coor_backward_impl(output, grad_output, input, rois, grad_rois,
pooled_height, pooled_width, spatial_scale);
}

View File

@ -0,0 +1,97 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "prroi_pool_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
void prroi_pool_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int pooled_height;
int pooled_width;
float spatial_scale;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.done();
const auto& input = buildATensor(ctx, ins[0]);
const auto& rois = buildATensor(ctx, ins[1]);
auto output = buildATensor(ctx, outs[0]);
prroi_pool_forward(input, rois, output, pooled_height, pooled_width,
spatial_scale);
}
void prroi_pool_backward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int pooled_height;
int pooled_width;
float spatial_scale;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.done();
const auto& grad_output = buildATensor(ctx, ins[0]);
const auto& rois = buildATensor(ctx, ins[1]);
auto grad_input = buildATensor(ctx, outs[0]);
prroi_pool_backward(grad_output, rois, grad_input, pooled_height,
pooled_width, spatial_scale);
}
void prroi_pool_coor_backward_cuda_parrots(CudaContext& ctx,
const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int pooled_height;
int pooled_width;
float spatial_scale;
SSAttrs(attr)
.get<int>("pooled_height", pooled_height)
.get<int>("pooled_width", pooled_width)
.get<float>("spatial_scale", spatial_scale)
.done();
const auto& output = buildATensor(ctx, ins[0]);
const auto& grad_output = buildATensor(ctx, ins[1]);
const auto& input = buildATensor(ctx, ins[2]);
const auto& rois = buildATensor(ctx, ins[3]);
auto grad_rois = buildATensor(ctx, outs[0]);
prroi_pool_coor_backward(output, grad_output, input, rois, grad_rois,
pooled_height, pooled_width, spatial_scale);
}
PARROTS_EXTENSION_REGISTER(prroi_pool_forward)
.attr("pooled_height")
.attr("pooled_width")
.attr("spatial_scale")
.input(2)
.output(1)
.apply(prroi_pool_forward_cuda_parrots)
.done();
PARROTS_EXTENSION_REGISTER(prroi_pool_backward)
.attr("pooled_height")
.attr("pooled_width")
.attr("spatial_scale")
.input(2)
.output(1)
.apply(prroi_pool_backward_cuda_parrots)
.done();
PARROTS_EXTENSION_REGISTER(prroi_pool_coor_backward)
.attr("pooled_height")
.attr("pooled_width")
.attr("spatial_scale")
.input(4)
.output(1)
.apply(prroi_pool_coor_backward_cuda_parrots)
.done();
#endif

View File

@ -0,0 +1,19 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef PRROI_POOL_PYTORCH_H
#define PRROI_POOL_PYTORCH_H
#include <torch/extension.h>
using namespace at;
void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale);
void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
int pooled_height, int pooled_width,
float spatial_scale);
void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input,
Tensor rois, Tensor grad_rois, int pooled_height,
int pooled_width, float spatial_scale);
#endif // PRROI_POOL_PYTORCH_H

View File

@ -7,7 +7,7 @@ from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from ..utils import ext_loader
from ..utils import TORCH_VERSION, ext_loader
ext_module = ext_loader.load_ext(
'_ext',
@ -32,11 +32,10 @@ class PrRoIPoolFunction(Function):
rois: torch.Tensor,
output_size: Tuple,
spatial_scale: float = 1.0) -> torch.Tensor:
if 'FloatTensor' not in features.type(
) or 'FloatTensor' not in rois.type():
raise ValueError(
'Precise RoI Pooling only takes float input, got '
f'{features.type()} for features and {rois.type()} for rois.')
if features.dtype != torch.float32 or rois.dtype != torch.float32:
raise ValueError('Precise RoI Pooling only takes float input, got '
f'{features.dtype()} for features and'
f'{rois.dtype()} for rois.')
pooled_height = int(output_size[0])
pooled_width = int(output_size[1])
@ -49,7 +48,13 @@ class PrRoIPoolFunction(Function):
output = features.new_zeros(output_shape)
params = (pooled_height, pooled_width, spatial_scale)
ext_module.prroi_pool_forward(features, rois, output, *params)
ext_module.prroi_pool_forward(
features,
rois,
output,
pooled_height=params[0],
pooled_width=params[1],
spatial_scale=params[2])
ctx.params = params
# everything here is contiguous.
ctx.save_for_backward(features, rois, output)
@ -65,14 +70,26 @@ class PrRoIPoolFunction(Function):
grad_input = grad_output.new_zeros(*features.shape)
grad_coor = grad_output.new_zeros(*rois.shape)
if features.requires_grad:
if features.requires_grad or TORCH_VERSION == 'parrots':
grad_output = grad_output.contiguous()
ext_module.prroi_pool_backward(grad_output, rois, grad_input,
*ctx.params)
if rois.requires_grad:
ext_module.prroi_pool_backward(
grad_output,
rois,
grad_input,
pooled_height=ctx.params[0],
pooled_width=ctx.params[1],
spatial_scale=ctx.params[2])
if rois.requires_grad or TORCH_VERSION == 'parrots':
grad_output = grad_output.contiguous()
ext_module.prroi_pool_coor_backward(output, grad_output, features,
rois, grad_coor, *ctx.params)
ext_module.prroi_pool_coor_backward(
output,
grad_output,
features,
rois,
grad_coor,
pooled_height=ctx.params[0],
pooled_width=ctx.params[1],
spatial_scale=ctx.params[2])
return grad_input, grad_coor, None, None, None

View File

@ -59,8 +59,7 @@ class TestPrRoiPool:
froipool = PrRoIPool((pool_h, pool_w), spatial_scale)
if _USING_PARROTS:
pass
# gradcheck(froipool, (x, rois), no_grads=[rois])
gradcheck(froipool, (x, rois), no_grads=[rois])
else:
gradcheck(froipool, (x, rois), eps=1e-2, atol=1e-2)