From dd5b415d2acfdc2ebd346f23c849556bfd2d0199 Mon Sep 17 00:00:00 2001 From: CokeDong <408244909@qq.com> Date: Tue, 11 Oct 2022 15:21:00 +0800 Subject: [PATCH] [Feature] Support PrRoI op for Parrots (#2280) * Support parrots extension for op PrRoI * Fix lint * Fix cpp lint * Fix testcase failure by false requires_grad in self-defined autograd Funtion * Fix issues * Fix flake8 * Fix isort * Adaption for typechecking for PrRoIPoolFunction * Fix lint * Support only float32 * bugfix --- mmcv/ops/csrc/parrots/cudabind.cpp | 51 ++++++++++ mmcv/ops/csrc/parrots/prroi_pool.cpp | 47 ++++++++++ mmcv/ops/csrc/parrots/prroi_pool_parrots.cpp | 97 ++++++++++++++++++++ mmcv/ops/csrc/parrots/prroi_pool_pytorch.h | 19 ++++ mmcv/ops/prroi_pool.py | 43 ++++++--- tests/test_ops/test_prroi_pool.py | 3 +- 6 files changed, 245 insertions(+), 15 deletions(-) create mode 100644 mmcv/ops/csrc/parrots/prroi_pool.cpp create mode 100644 mmcv/ops/csrc/parrots/prroi_pool_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/prroi_pool_pytorch.h diff --git a/mmcv/ops/csrc/parrots/cudabind.cpp b/mmcv/ops/csrc/parrots/cudabind.cpp index ed297260f..9627e26f4 100644 --- a/mmcv/ops/csrc/parrots/cudabind.cpp +++ b/mmcv/ops/csrc/parrots/cudabind.cpp @@ -1624,3 +1624,54 @@ REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA, chamfer_distance_forward_cuda); REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA, chamfer_distance_backward_cuda); + +void PrROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale); + +void PrROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale); + +void PrROIPoolCoorBackwardCUDAKernelLauncher( + Tensor output, Tensor grad_output, Tensor input, Tensor rois, + Tensor grad_rois, int pooled_height, int pooled_width, float spatial_scale); + +void prroi_pool_forward_cuda(Tensor input, Tensor rois, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale) { + PrROIPoolForwardCUDAKernelLauncher(input, rois, output, pooled_height, + pooled_width, spatial_scale); +} + +void prroi_pool_backward_cuda(Tensor grad_output, Tensor rois, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale) { + PrROIPoolBackwardCUDAKernelLauncher(grad_output, rois, grad_input, + pooled_height, pooled_width, + spatial_scale); +} + +void prroi_pool_coor_backward_cuda(Tensor output, Tensor grad_output, + Tensor input, Tensor rois, Tensor grad_rois, + int pooled_height, int pooled_width, + float spatial_scale) { + PrROIPoolCoorBackwardCUDAKernelLauncher(output, grad_output, input, rois, + grad_rois, pooled_height, + pooled_width, spatial_scale); +} + +void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale); +void prroi_pool_backward_impl(Tensor grad_output, Tensor rois, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale); +void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output, + Tensor input, Tensor rois, Tensor grad_rois, + int pooled_height, int pooled_width, + float spatial_scale); +REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, CUDA, prroi_pool_forward_cuda); +REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, CUDA, prroi_pool_backward_cuda); +REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, CUDA, + prroi_pool_coor_backward_cuda); diff --git a/mmcv/ops/csrc/parrots/prroi_pool.cpp b/mmcv/ops/csrc/parrots/prroi_pool.cpp new file mode 100644 index 000000000..00db84a15 --- /dev/null +++ b/mmcv/ops/csrc/parrots/prroi_pool.cpp @@ -0,0 +1,47 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale) { + DISPATCH_DEVICE_IMPL(prroi_pool_forward_impl, input, rois, output, + pooled_height, pooled_width, spatial_scale); +} + +void prroi_pool_backward_impl(Tensor grad_output, Tensor rois, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale) { + DISPATCH_DEVICE_IMPL(prroi_pool_backward_impl, grad_output, rois, grad_input, + pooled_height, pooled_width, spatial_scale); +} + +void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output, + Tensor input, Tensor rois, Tensor grad_rois, + int pooled_height, int pooled_width, + float spatial_scale) { + DISPATCH_DEVICE_IMPL(prroi_pool_coor_backward_impl, output, grad_output, + input, rois, grad_rois, pooled_height, pooled_width, + spatial_scale); +} + +void prroi_pool_forward(Tensor input, Tensor rois, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale) { + prroi_pool_forward_impl(input, rois, output, pooled_height, pooled_width, + spatial_scale); +} + +void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input, + int pooled_height, int pooled_width, + float spatial_scale) { + prroi_pool_backward_impl(grad_output, rois, grad_input, pooled_height, + pooled_width, spatial_scale); +} + +void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input, + Tensor rois, Tensor grad_rois, int pooled_height, + int pooled_width, float spatial_scale) { + prroi_pool_coor_backward_impl(output, grad_output, input, rois, grad_rois, + pooled_height, pooled_width, spatial_scale); +} diff --git a/mmcv/ops/csrc/parrots/prroi_pool_parrots.cpp b/mmcv/ops/csrc/parrots/prroi_pool_parrots.cpp new file mode 100644 index 000000000..4e8295581 --- /dev/null +++ b/mmcv/ops/csrc/parrots/prroi_pool_parrots.cpp @@ -0,0 +1,97 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "prroi_pool_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void prroi_pool_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int pooled_height; + int pooled_width; + float spatial_scale; + SSAttrs(attr) + .get("pooled_height", pooled_height) + .get("pooled_width", pooled_width) + .get("spatial_scale", spatial_scale) + .done(); + + const auto& input = buildATensor(ctx, ins[0]); + const auto& rois = buildATensor(ctx, ins[1]); + auto output = buildATensor(ctx, outs[0]); + prroi_pool_forward(input, rois, output, pooled_height, pooled_width, + spatial_scale); +} + +void prroi_pool_backward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int pooled_height; + int pooled_width; + float spatial_scale; + SSAttrs(attr) + .get("pooled_height", pooled_height) + .get("pooled_width", pooled_width) + .get("spatial_scale", spatial_scale) + .done(); + + const auto& grad_output = buildATensor(ctx, ins[0]); + const auto& rois = buildATensor(ctx, ins[1]); + auto grad_input = buildATensor(ctx, outs[0]); + prroi_pool_backward(grad_output, rois, grad_input, pooled_height, + pooled_width, spatial_scale); +} + +void prroi_pool_coor_backward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int pooled_height; + int pooled_width; + float spatial_scale; + SSAttrs(attr) + .get("pooled_height", pooled_height) + .get("pooled_width", pooled_width) + .get("spatial_scale", spatial_scale) + .done(); + + const auto& output = buildATensor(ctx, ins[0]); + const auto& grad_output = buildATensor(ctx, ins[1]); + const auto& input = buildATensor(ctx, ins[2]); + const auto& rois = buildATensor(ctx, ins[3]); + auto grad_rois = buildATensor(ctx, outs[0]); + prroi_pool_coor_backward(output, grad_output, input, rois, grad_rois, + pooled_height, pooled_width, spatial_scale); +} + +PARROTS_EXTENSION_REGISTER(prroi_pool_forward) + .attr("pooled_height") + .attr("pooled_width") + .attr("spatial_scale") + .input(2) + .output(1) + .apply(prroi_pool_forward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(prroi_pool_backward) + .attr("pooled_height") + .attr("pooled_width") + .attr("spatial_scale") + .input(2) + .output(1) + .apply(prroi_pool_backward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(prroi_pool_coor_backward) + .attr("pooled_height") + .attr("pooled_width") + .attr("spatial_scale") + .input(4) + .output(1) + .apply(prroi_pool_coor_backward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/prroi_pool_pytorch.h b/mmcv/ops/csrc/parrots/prroi_pool_pytorch.h new file mode 100644 index 000000000..451b01dd5 --- /dev/null +++ b/mmcv/ops/csrc/parrots/prroi_pool_pytorch.h @@ -0,0 +1,19 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef PRROI_POOL_PYTORCH_H +#define PRROI_POOL_PYTORCH_H +#include +using namespace at; + +void prroi_pool_forward(Tensor input, Tensor rois, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale); + +void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input, + int pooled_height, int pooled_width, + float spatial_scale); + +void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input, + Tensor rois, Tensor grad_rois, int pooled_height, + int pooled_width, float spatial_scale); + +#endif // PRROI_POOL_PYTORCH_H diff --git a/mmcv/ops/prroi_pool.py b/mmcv/ops/prroi_pool.py index 47c223aa5..b8a99c11e 100644 --- a/mmcv/ops/prroi_pool.py +++ b/mmcv/ops/prroi_pool.py @@ -7,7 +7,7 @@ from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair -from ..utils import ext_loader +from ..utils import TORCH_VERSION, ext_loader ext_module = ext_loader.load_ext( '_ext', @@ -32,11 +32,10 @@ class PrRoIPoolFunction(Function): rois: torch.Tensor, output_size: Tuple, spatial_scale: float = 1.0) -> torch.Tensor: - if 'FloatTensor' not in features.type( - ) or 'FloatTensor' not in rois.type(): - raise ValueError( - 'Precise RoI Pooling only takes float input, got ' - f'{features.type()} for features and {rois.type()} for rois.') + if features.dtype != torch.float32 or rois.dtype != torch.float32: + raise ValueError('Precise RoI Pooling only takes float input, got ' + f'{features.dtype()} for features and' + f'{rois.dtype()} for rois.') pooled_height = int(output_size[0]) pooled_width = int(output_size[1]) @@ -49,7 +48,13 @@ class PrRoIPoolFunction(Function): output = features.new_zeros(output_shape) params = (pooled_height, pooled_width, spatial_scale) - ext_module.prroi_pool_forward(features, rois, output, *params) + ext_module.prroi_pool_forward( + features, + rois, + output, + pooled_height=params[0], + pooled_width=params[1], + spatial_scale=params[2]) ctx.params = params # everything here is contiguous. ctx.save_for_backward(features, rois, output) @@ -65,14 +70,26 @@ class PrRoIPoolFunction(Function): grad_input = grad_output.new_zeros(*features.shape) grad_coor = grad_output.new_zeros(*rois.shape) - if features.requires_grad: + if features.requires_grad or TORCH_VERSION == 'parrots': grad_output = grad_output.contiguous() - ext_module.prroi_pool_backward(grad_output, rois, grad_input, - *ctx.params) - if rois.requires_grad: + ext_module.prroi_pool_backward( + grad_output, + rois, + grad_input, + pooled_height=ctx.params[0], + pooled_width=ctx.params[1], + spatial_scale=ctx.params[2]) + if rois.requires_grad or TORCH_VERSION == 'parrots': grad_output = grad_output.contiguous() - ext_module.prroi_pool_coor_backward(output, grad_output, features, - rois, grad_coor, *ctx.params) + ext_module.prroi_pool_coor_backward( + output, + grad_output, + features, + rois, + grad_coor, + pooled_height=ctx.params[0], + pooled_width=ctx.params[1], + spatial_scale=ctx.params[2]) return grad_input, grad_coor, None, None, None diff --git a/tests/test_ops/test_prroi_pool.py b/tests/test_ops/test_prroi_pool.py index 6ee471e82..0535dfbe2 100644 --- a/tests/test_ops/test_prroi_pool.py +++ b/tests/test_ops/test_prroi_pool.py @@ -59,8 +59,7 @@ class TestPrRoiPool: froipool = PrRoIPool((pool_h, pool_w), spatial_scale) if _USING_PARROTS: - pass - # gradcheck(froipool, (x, rois), no_grads=[rois]) + gradcheck(froipool, (x, rois), no_grads=[rois]) else: gradcheck(froipool, (x, rois), eps=1e-2, atol=1e-2)