mirror of https://github.com/open-mmlab/mmcv.git
Implement corner pool with python for torch<1.5 (#1772)
* implement corner pool with python for torch<1.5 * fix for torch130pull/1792/head
parent
ac52bb3795
commit
8c23bf140a
|
@ -3,17 +3,38 @@ import torch
|
|||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext('_ext', [
|
||||
'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward',
|
||||
'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward',
|
||||
'right_pool_forward', 'right_pool_backward'
|
||||
])
|
||||
|
||||
_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
|
||||
|
||||
|
||||
def _corner_pool(x, dim, flip):
|
||||
size = x.size(dim)
|
||||
output = x.clone()
|
||||
|
||||
ind = 1
|
||||
while ind < size:
|
||||
if flip:
|
||||
cur_start = 0
|
||||
cur_len = size - ind
|
||||
next_start = ind
|
||||
next_len = size - ind
|
||||
else:
|
||||
cur_start = ind
|
||||
cur_len = size - ind
|
||||
next_start = 0
|
||||
next_len = size - ind
|
||||
|
||||
# max_temp should be cloned for backward computation
|
||||
max_temp = output.narrow(dim, cur_start, cur_len).clone()
|
||||
cur_temp = output.narrow(dim, cur_start, cur_len)
|
||||
next_temp = output.narrow(dim, next_start, next_len)
|
||||
|
||||
cur_temp[...] = torch.where(max_temp > next_temp, max_temp, next_temp)
|
||||
|
||||
ind = ind << 1
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TopPoolFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
|
@ -24,15 +45,7 @@ class TopPoolFunction(Function):
|
|||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = ext_module.top_pool_forward(input)
|
||||
ctx.save_for_backward(input)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
output = ext_module.top_pool_backward(input, grad_output)
|
||||
return output
|
||||
return _corner_pool(input, 2, True)
|
||||
|
||||
|
||||
class BottomPoolFunction(Function):
|
||||
|
@ -45,15 +58,7 @@ class BottomPoolFunction(Function):
|
|||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = ext_module.bottom_pool_forward(input)
|
||||
ctx.save_for_backward(input)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
output = ext_module.bottom_pool_backward(input, grad_output)
|
||||
return output
|
||||
return _corner_pool(input, 2, False)
|
||||
|
||||
|
||||
class LeftPoolFunction(Function):
|
||||
|
@ -66,15 +71,7 @@ class LeftPoolFunction(Function):
|
|||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = ext_module.left_pool_forward(input)
|
||||
ctx.save_for_backward(input)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
output = ext_module.left_pool_backward(input, grad_output)
|
||||
return output
|
||||
return _corner_pool(input, 3, True)
|
||||
|
||||
|
||||
class RightPoolFunction(Function):
|
||||
|
@ -87,15 +84,7 @@ class RightPoolFunction(Function):
|
|||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = ext_module.right_pool_forward(input)
|
||||
ctx.save_for_backward(input)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
output = ext_module.right_pool_backward(input, grad_output)
|
||||
return output
|
||||
return _corner_pool(input, 3, False)
|
||||
|
||||
|
||||
class CornerPool(nn.Module):
|
||||
|
@ -160,4 +149,8 @@ class CornerPool(nn.Module):
|
|||
pool_tensor = pool_tensor.flip(dim)
|
||||
return pool_tensor
|
||||
else:
|
||||
return self.corner_pool.apply(x)
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
return self.corner_pool.apply(x)
|
||||
else:
|
||||
dim, flip = self.cummax_dim_flip[self.mode]
|
||||
return _corner_pool(x, dim, flip)
|
||||
|
|
|
@ -1,240 +0,0 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
// Modified from
|
||||
// https://github.com/princeton-vl/CornerNet-Lite/tree/master/core/models/py_utils/_cpools/src
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
Tensor bottom_pool_forward(Tensor input) {
|
||||
// Initialize output
|
||||
Tensor output = at::zeros_like(input);
|
||||
// Get height
|
||||
int64_t height = input.size(2);
|
||||
output.copy_(input);
|
||||
|
||||
for (int64_t ind = 1; ind < height; ind <<= 1) {
|
||||
Tensor max_temp = at::slice(output, 2, ind, height);
|
||||
Tensor cur_temp = at::slice(output, 2, ind, height).clone();
|
||||
Tensor next_temp = at::slice(output, 2, 0, height - ind).clone();
|
||||
at::max_out(max_temp, cur_temp, next_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor bottom_pool_backward(Tensor input, Tensor grad_output) {
|
||||
auto output = at::zeros_like(input);
|
||||
|
||||
int32_t batch = input.size(0);
|
||||
int32_t channel = input.size(1);
|
||||
int32_t height = input.size(2);
|
||||
int32_t width = input.size(3);
|
||||
|
||||
auto max_val = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
auto max_ind = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kLong));
|
||||
|
||||
auto input_temp = input.select(2, 0);
|
||||
max_val.copy_(input_temp);
|
||||
|
||||
max_ind.fill_(0);
|
||||
|
||||
auto output_temp = output.select(2, 0);
|
||||
auto grad_output_temp = grad_output.select(2, 0);
|
||||
output_temp.copy_(grad_output_temp);
|
||||
|
||||
auto un_max_ind = max_ind.unsqueeze(2);
|
||||
auto gt_mask = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kBool));
|
||||
auto max_temp = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
for (int32_t ind = 0; ind < height - 1; ++ind) {
|
||||
input_temp = input.select(2, ind + 1);
|
||||
at::gt_out(gt_mask, input_temp, max_val);
|
||||
|
||||
at::masked_select_out(max_temp, input_temp, gt_mask);
|
||||
max_val.masked_scatter_(gt_mask, max_temp);
|
||||
max_ind.masked_fill_(gt_mask, ind + 1);
|
||||
|
||||
grad_output_temp = grad_output.select(2, ind + 1).unsqueeze(2);
|
||||
output.scatter_add_(2, un_max_ind, grad_output_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor left_pool_forward(Tensor input) {
|
||||
// Initialize output
|
||||
Tensor output = at::zeros_like(input);
|
||||
// Get width
|
||||
int64_t width = input.size(3);
|
||||
output.copy_(input);
|
||||
|
||||
for (int64_t ind = 1; ind < width; ind <<= 1) {
|
||||
Tensor max_temp = at::slice(output, 3, 0, width - ind);
|
||||
Tensor cur_temp = at::slice(output, 3, 0, width - ind).clone();
|
||||
Tensor next_temp = at::slice(output, 3, ind, width).clone();
|
||||
at::max_out(max_temp, cur_temp, next_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor left_pool_backward(Tensor input, Tensor grad_output) {
|
||||
auto output = at::zeros_like(input);
|
||||
|
||||
int32_t batch = input.size(0);
|
||||
int32_t channel = input.size(1);
|
||||
int32_t height = input.size(2);
|
||||
int32_t width = input.size(3);
|
||||
|
||||
auto max_val = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
auto max_ind = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kLong));
|
||||
|
||||
auto input_temp = input.select(3, width - 1);
|
||||
max_val.copy_(input_temp);
|
||||
|
||||
max_ind.fill_(width - 1);
|
||||
|
||||
auto output_temp = output.select(3, width - 1);
|
||||
auto grad_output_temp = grad_output.select(3, width - 1);
|
||||
output_temp.copy_(grad_output_temp);
|
||||
|
||||
auto un_max_ind = max_ind.unsqueeze(3);
|
||||
auto gt_mask = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kBool));
|
||||
auto max_temp = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
for (int32_t ind = 1; ind < width; ++ind) {
|
||||
input_temp = input.select(3, width - ind - 1);
|
||||
at::gt_out(gt_mask, input_temp, max_val);
|
||||
|
||||
at::masked_select_out(max_temp, input_temp, gt_mask);
|
||||
max_val.masked_scatter_(gt_mask, max_temp);
|
||||
max_ind.masked_fill_(gt_mask, width - ind - 1);
|
||||
|
||||
grad_output_temp = grad_output.select(3, width - ind - 1).unsqueeze(3);
|
||||
output.scatter_add_(3, un_max_ind, grad_output_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor right_pool_forward(Tensor input) {
|
||||
// Initialize output
|
||||
Tensor output = at::zeros_like(input);
|
||||
// Get width
|
||||
int64_t width = input.size(3);
|
||||
output.copy_(input);
|
||||
|
||||
for (int64_t ind = 1; ind < width; ind <<= 1) {
|
||||
Tensor max_temp = at::slice(output, 3, ind, width);
|
||||
Tensor cur_temp = at::slice(output, 3, ind, width).clone();
|
||||
Tensor next_temp = at::slice(output, 3, 0, width - ind).clone();
|
||||
at::max_out(max_temp, cur_temp, next_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor right_pool_backward(Tensor input, Tensor grad_output) {
|
||||
Tensor output = at::zeros_like(input);
|
||||
|
||||
int32_t batch = input.size(0);
|
||||
int32_t channel = input.size(1);
|
||||
int32_t height = input.size(2);
|
||||
int32_t width = input.size(3);
|
||||
|
||||
auto max_val = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
auto max_ind = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kLong));
|
||||
|
||||
auto input_temp = input.select(3, 0);
|
||||
max_val.copy_(input_temp);
|
||||
|
||||
max_ind.fill_(0);
|
||||
|
||||
auto output_temp = output.select(3, 0);
|
||||
auto grad_output_temp = grad_output.select(3, 0);
|
||||
output_temp.copy_(grad_output_temp);
|
||||
|
||||
auto un_max_ind = max_ind.unsqueeze(3);
|
||||
auto gt_mask = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kBool));
|
||||
auto max_temp = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
for (int32_t ind = 0; ind < width - 1; ++ind) {
|
||||
input_temp = input.select(3, ind + 1);
|
||||
at::gt_out(gt_mask, input_temp, max_val);
|
||||
|
||||
at::masked_select_out(max_temp, input_temp, gt_mask);
|
||||
max_val.masked_scatter_(gt_mask, max_temp);
|
||||
max_ind.masked_fill_(gt_mask, ind + 1);
|
||||
|
||||
grad_output_temp = grad_output.select(3, ind + 1).unsqueeze(3);
|
||||
output.scatter_add_(3, un_max_ind, grad_output_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor top_pool_forward(Tensor input) {
|
||||
// Initialize output
|
||||
Tensor output = at::zeros_like(input);
|
||||
// Get height
|
||||
int64_t height = input.size(2);
|
||||
output.copy_(input);
|
||||
|
||||
for (int64_t ind = 1; ind < height; ind <<= 1) {
|
||||
Tensor max_temp = at::slice(output, 2, 0, height - ind);
|
||||
Tensor cur_temp = at::slice(output, 2, 0, height - ind).clone();
|
||||
Tensor next_temp = at::slice(output, 2, ind, height).clone();
|
||||
at::max_out(max_temp, cur_temp, next_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor top_pool_backward(Tensor input, Tensor grad_output) {
|
||||
auto output = at::zeros_like(input);
|
||||
|
||||
int32_t batch = input.size(0);
|
||||
int32_t channel = input.size(1);
|
||||
int32_t height = input.size(2);
|
||||
int32_t width = input.size(3);
|
||||
|
||||
auto max_val = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
auto max_ind = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kLong));
|
||||
|
||||
auto input_temp = input.select(2, height - 1);
|
||||
max_val.copy_(input_temp);
|
||||
|
||||
max_ind.fill_(height - 1);
|
||||
|
||||
auto output_temp = output.select(2, height - 1);
|
||||
auto grad_output_temp = grad_output.select(2, height - 1);
|
||||
output_temp.copy_(grad_output_temp);
|
||||
|
||||
auto un_max_ind = max_ind.unsqueeze(2);
|
||||
auto gt_mask = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kBool));
|
||||
auto max_temp = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
for (int32_t ind = 1; ind < height; ++ind) {
|
||||
input_temp = input.select(2, height - ind - 1);
|
||||
at::gt_out(gt_mask, input_temp, max_val);
|
||||
|
||||
at::masked_select_out(max_temp, input_temp, gt_mask);
|
||||
max_val.masked_scatter_(gt_mask, max_temp);
|
||||
max_ind.masked_fill_(gt_mask, height - ind - 1);
|
||||
|
||||
grad_output_temp = grad_output.select(2, height - ind - 1).unsqueeze(2);
|
||||
output.scatter_add_(2, un_max_ind, grad_output_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
|
@ -1,234 +0,0 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "corner_pool_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void bottom_pool_forward_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
auto out = bottom_pool_forward(input);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void bottom_pool_backward_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input, grad_output;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
grad_output = buildATensor(ctx, ins[1]);
|
||||
auto out = bottom_pool_backward(input, grad_output);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void left_pool_forward_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
auto out = left_pool_forward(input);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void left_pool_backward_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input, grad_output;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
grad_output = buildATensor(ctx, ins[1]);
|
||||
auto out = left_pool_backward(input, grad_output);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void right_pool_forward_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
auto out = right_pool_forward(input);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void right_pool_backward_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input, grad_output;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
grad_output = buildATensor(ctx, ins[1]);
|
||||
auto out = right_pool_backward(input, grad_output);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void top_pool_forward_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
auto out = top_pool_forward(input);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void top_pool_backward_parrots(CudaContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input, grad_output;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
grad_output = buildATensor(ctx, ins[1]);
|
||||
auto out = top_pool_backward(input, grad_output);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
#endif
|
||||
|
||||
void bottom_pool_forward_parrots_cpu(HostContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
auto out = bottom_pool_forward(input);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void bottom_pool_backward_parrots_cpu(HostContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input, grad_output;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
grad_output = buildATensor(ctx, ins[1]);
|
||||
auto out = bottom_pool_backward(input, grad_output);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void left_pool_forward_parrots_cpu(HostContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
auto out = left_pool_forward(input);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void left_pool_backward_parrots_cpu(HostContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input, grad_output;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
grad_output = buildATensor(ctx, ins[1]);
|
||||
auto out = left_pool_backward(input, grad_output);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void right_pool_forward_parrots_cpu(HostContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
auto out = right_pool_forward(input);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void right_pool_backward_parrots_cpu(HostContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input, grad_output;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
grad_output = buildATensor(ctx, ins[1]);
|
||||
auto out = right_pool_backward(input, grad_output);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void top_pool_forward_parrots_cpu(HostContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
auto out = top_pool_forward(input);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
void top_pool_backward_parrots_cpu(HostContext& ctx, const SSElement& attr,
|
||||
const OperatorBase::in_list_t& ins,
|
||||
OperatorBase::out_list_t& outs) {
|
||||
at::Tensor input, grad_output;
|
||||
input = buildATensor(ctx, ins[0]);
|
||||
grad_output = buildATensor(ctx, ins[1]);
|
||||
auto out = top_pool_backward(input, grad_output);
|
||||
updateDArray(ctx, out, outs[0]);
|
||||
}
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(bottom_pool_forward)
|
||||
.input(1)
|
||||
.output(1)
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
.apply(bottom_pool_forward_parrots)
|
||||
#endif
|
||||
.apply(bottom_pool_forward_parrots_cpu)
|
||||
.done();
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(bottom_pool_backward)
|
||||
.input(2)
|
||||
.output(1)
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
.apply(bottom_pool_backward_parrots)
|
||||
#endif
|
||||
.apply(bottom_pool_backward_parrots_cpu)
|
||||
.done();
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(top_pool_forward)
|
||||
.input(1)
|
||||
.output(1)
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
.apply(top_pool_forward_parrots)
|
||||
#endif
|
||||
.apply(top_pool_forward_parrots_cpu)
|
||||
.done();
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(top_pool_backward)
|
||||
.input(2)
|
||||
.output(1)
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
.apply(top_pool_backward_parrots)
|
||||
#endif
|
||||
.apply(top_pool_backward_parrots_cpu)
|
||||
.done();
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(left_pool_forward)
|
||||
.input(1)
|
||||
.output(1)
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
.apply(left_pool_forward_parrots)
|
||||
#endif
|
||||
.apply(left_pool_forward_parrots_cpu)
|
||||
.done();
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(left_pool_backward)
|
||||
.input(2)
|
||||
.output(1)
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
.apply(left_pool_backward_parrots)
|
||||
#endif
|
||||
.apply(left_pool_backward_parrots_cpu)
|
||||
.done();
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(right_pool_forward)
|
||||
.input(1)
|
||||
.output(1)
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
.apply(right_pool_forward_parrots)
|
||||
#endif
|
||||
.apply(right_pool_forward_parrots_cpu)
|
||||
.done();
|
||||
|
||||
PARROTS_EXTENSION_REGISTER(right_pool_backward)
|
||||
.input(2)
|
||||
.output(1)
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
.apply(right_pool_backward_parrots)
|
||||
#endif
|
||||
.apply(right_pool_backward_parrots_cpu)
|
||||
.done();
|
|
@ -1,15 +0,0 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef CORNER_POOL_PYTORCH_H
|
||||
#define CORNER_POOL_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
|
||||
at::Tensor bottom_pool_forward(at::Tensor input);
|
||||
at::Tensor bottom_pool_backward(at::Tensor input, at::Tensor grad_output);
|
||||
at::Tensor left_pool_forward(at::Tensor input);
|
||||
at::Tensor left_pool_backward(at::Tensor input, at::Tensor grad_output);
|
||||
at::Tensor right_pool_forward(at::Tensor input);
|
||||
at::Tensor right_pool_backward(at::Tensor input, at::Tensor grad_output);
|
||||
at::Tensor top_pool_forward(at::Tensor input);
|
||||
at::Tensor top_pool_backward(at::Tensor input, at::Tensor grad_output);
|
||||
|
||||
#endif // CORNER_POOL_PYTORCH_H
|
|
@ -1,240 +0,0 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
// Modified from
|
||||
// https://github.com/princeton-vl/CornerNet-Lite/tree/master/core/models/py_utils/_cpools/src
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
Tensor bottom_pool_forward(Tensor input) {
|
||||
// Initialize output
|
||||
Tensor output = at::zeros_like(input);
|
||||
// Get height
|
||||
int64_t height = input.size(2);
|
||||
output.copy_(input);
|
||||
|
||||
for (int64_t ind = 1; ind < height; ind <<= 1) {
|
||||
Tensor max_temp = at::slice(output, 2, ind, height);
|
||||
Tensor cur_temp = at::slice(output, 2, ind, height).clone();
|
||||
Tensor next_temp = at::slice(output, 2, 0, height - ind).clone();
|
||||
at::max_out(max_temp, cur_temp, next_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor bottom_pool_backward(Tensor input, Tensor grad_output) {
|
||||
auto output = at::zeros_like(input);
|
||||
|
||||
int32_t batch = input.size(0);
|
||||
int32_t channel = input.size(1);
|
||||
int32_t height = input.size(2);
|
||||
int32_t width = input.size(3);
|
||||
|
||||
auto max_val = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
auto max_ind = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kLong));
|
||||
|
||||
auto input_temp = input.select(2, 0);
|
||||
max_val.copy_(input_temp);
|
||||
|
||||
max_ind.fill_(0);
|
||||
|
||||
auto output_temp = output.select(2, 0);
|
||||
auto grad_output_temp = grad_output.select(2, 0);
|
||||
output_temp.copy_(grad_output_temp);
|
||||
|
||||
auto un_max_ind = max_ind.unsqueeze(2);
|
||||
auto gt_mask = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kBool));
|
||||
auto max_temp = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
for (int32_t ind = 0; ind < height - 1; ++ind) {
|
||||
input_temp = input.select(2, ind + 1);
|
||||
at::gt_out(gt_mask, input_temp, max_val);
|
||||
|
||||
at::masked_select_out(max_temp, input_temp, gt_mask);
|
||||
max_val.masked_scatter_(gt_mask, max_temp);
|
||||
max_ind.masked_fill_(gt_mask, ind + 1);
|
||||
|
||||
grad_output_temp = grad_output.select(2, ind + 1).unsqueeze(2);
|
||||
output.scatter_add_(2, un_max_ind, grad_output_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor left_pool_forward(Tensor input) {
|
||||
// Initialize output
|
||||
Tensor output = at::zeros_like(input);
|
||||
// Get width
|
||||
int64_t width = input.size(3);
|
||||
output.copy_(input);
|
||||
|
||||
for (int64_t ind = 1; ind < width; ind <<= 1) {
|
||||
Tensor max_temp = at::slice(output, 3, 0, width - ind);
|
||||
Tensor cur_temp = at::slice(output, 3, 0, width - ind).clone();
|
||||
Tensor next_temp = at::slice(output, 3, ind, width).clone();
|
||||
at::max_out(max_temp, cur_temp, next_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor left_pool_backward(Tensor input, Tensor grad_output) {
|
||||
auto output = at::zeros_like(input);
|
||||
|
||||
int32_t batch = input.size(0);
|
||||
int32_t channel = input.size(1);
|
||||
int32_t height = input.size(2);
|
||||
int32_t width = input.size(3);
|
||||
|
||||
auto max_val = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
auto max_ind = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kLong));
|
||||
|
||||
auto input_temp = input.select(3, width - 1);
|
||||
max_val.copy_(input_temp);
|
||||
|
||||
max_ind.fill_(width - 1);
|
||||
|
||||
auto output_temp = output.select(3, width - 1);
|
||||
auto grad_output_temp = grad_output.select(3, width - 1);
|
||||
output_temp.copy_(grad_output_temp);
|
||||
|
||||
auto un_max_ind = max_ind.unsqueeze(3);
|
||||
auto gt_mask = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kBool));
|
||||
auto max_temp = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
for (int32_t ind = 1; ind < width; ++ind) {
|
||||
input_temp = input.select(3, width - ind - 1);
|
||||
at::gt_out(gt_mask, input_temp, max_val);
|
||||
|
||||
at::masked_select_out(max_temp, input_temp, gt_mask);
|
||||
max_val.masked_scatter_(gt_mask, max_temp);
|
||||
max_ind.masked_fill_(gt_mask, width - ind - 1);
|
||||
|
||||
grad_output_temp = grad_output.select(3, width - ind - 1).unsqueeze(3);
|
||||
output.scatter_add_(3, un_max_ind, grad_output_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor right_pool_forward(Tensor input) {
|
||||
// Initialize output
|
||||
Tensor output = at::zeros_like(input);
|
||||
// Get width
|
||||
int64_t width = input.size(3);
|
||||
output.copy_(input);
|
||||
|
||||
for (int64_t ind = 1; ind < width; ind <<= 1) {
|
||||
Tensor max_temp = at::slice(output, 3, ind, width);
|
||||
Tensor cur_temp = at::slice(output, 3, ind, width).clone();
|
||||
Tensor next_temp = at::slice(output, 3, 0, width - ind).clone();
|
||||
at::max_out(max_temp, cur_temp, next_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor right_pool_backward(Tensor input, Tensor grad_output) {
|
||||
Tensor output = at::zeros_like(input);
|
||||
|
||||
int32_t batch = input.size(0);
|
||||
int32_t channel = input.size(1);
|
||||
int32_t height = input.size(2);
|
||||
int32_t width = input.size(3);
|
||||
|
||||
auto max_val = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
auto max_ind = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kLong));
|
||||
|
||||
auto input_temp = input.select(3, 0);
|
||||
max_val.copy_(input_temp);
|
||||
|
||||
max_ind.fill_(0);
|
||||
|
||||
auto output_temp = output.select(3, 0);
|
||||
auto grad_output_temp = grad_output.select(3, 0);
|
||||
output_temp.copy_(grad_output_temp);
|
||||
|
||||
auto un_max_ind = max_ind.unsqueeze(3);
|
||||
auto gt_mask = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kBool));
|
||||
auto max_temp = torch::zeros({batch, channel, height},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
for (int32_t ind = 0; ind < width - 1; ++ind) {
|
||||
input_temp = input.select(3, ind + 1);
|
||||
at::gt_out(gt_mask, input_temp, max_val);
|
||||
|
||||
at::masked_select_out(max_temp, input_temp, gt_mask);
|
||||
max_val.masked_scatter_(gt_mask, max_temp);
|
||||
max_ind.masked_fill_(gt_mask, ind + 1);
|
||||
|
||||
grad_output_temp = grad_output.select(3, ind + 1).unsqueeze(3);
|
||||
output.scatter_add_(3, un_max_ind, grad_output_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor top_pool_forward(Tensor input) {
|
||||
// Initialize output
|
||||
Tensor output = at::zeros_like(input);
|
||||
// Get height
|
||||
int64_t height = input.size(2);
|
||||
output.copy_(input);
|
||||
|
||||
for (int64_t ind = 1; ind < height; ind <<= 1) {
|
||||
Tensor max_temp = at::slice(output, 2, 0, height - ind);
|
||||
Tensor cur_temp = at::slice(output, 2, 0, height - ind).clone();
|
||||
Tensor next_temp = at::slice(output, 2, ind, height).clone();
|
||||
at::max_out(max_temp, cur_temp, next_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor top_pool_backward(Tensor input, Tensor grad_output) {
|
||||
auto output = at::zeros_like(input);
|
||||
|
||||
int32_t batch = input.size(0);
|
||||
int32_t channel = input.size(1);
|
||||
int32_t height = input.size(2);
|
||||
int32_t width = input.size(3);
|
||||
|
||||
auto max_val = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
auto max_ind = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kLong));
|
||||
|
||||
auto input_temp = input.select(2, height - 1);
|
||||
max_val.copy_(input_temp);
|
||||
|
||||
max_ind.fill_(height - 1);
|
||||
|
||||
auto output_temp = output.select(2, height - 1);
|
||||
auto grad_output_temp = grad_output.select(2, height - 1);
|
||||
output_temp.copy_(grad_output_temp);
|
||||
|
||||
auto un_max_ind = max_ind.unsqueeze(2);
|
||||
auto gt_mask = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kBool));
|
||||
auto max_temp = torch::zeros({batch, channel, width},
|
||||
at::device(at::kCUDA).dtype(at::kFloat));
|
||||
for (int32_t ind = 1; ind < height; ++ind) {
|
||||
input_temp = input.select(2, height - ind - 1);
|
||||
at::gt_out(gt_mask, input_temp, max_val);
|
||||
|
||||
at::masked_select_out(max_temp, input_temp, gt_mask);
|
||||
max_val.masked_scatter_(gt_mask, max_temp);
|
||||
max_ind.masked_fill_(gt_mask, height - ind - 1);
|
||||
|
||||
grad_output_temp = grad_output.select(2, height - ind - 1).unsqueeze(2);
|
||||
output.scatter_add_(2, un_max_ind, grad_output_temp);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
|
@ -279,22 +279,6 @@ Tensor indice_maxpool_backward(Tensor features, Tensor outFeatures,
|
|||
Tensor outGrad, Tensor indicePairs,
|
||||
Tensor indiceNum);
|
||||
|
||||
Tensor bottom_pool_forward(Tensor input);
|
||||
|
||||
Tensor bottom_pool_backward(Tensor input, Tensor grad_output);
|
||||
|
||||
Tensor left_pool_forward(Tensor input);
|
||||
|
||||
Tensor left_pool_backward(Tensor input, Tensor grad_output);
|
||||
|
||||
Tensor right_pool_forward(Tensor input);
|
||||
|
||||
Tensor right_pool_backward(Tensor input, Tensor grad_output);
|
||||
|
||||
Tensor top_pool_forward(Tensor input);
|
||||
|
||||
Tensor top_pool_backward(Tensor input, Tensor grad_output);
|
||||
|
||||
void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
|
||||
const int mode_flag, const bool aligned);
|
||||
|
||||
|
@ -705,26 +689,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
py::arg("input"), py::arg("shift"), py::arg("output"));
|
||||
m.def("tin_shift_backward", &tin_shift_backward, "tin_shift backward",
|
||||
py::arg("grad_output"), py::arg("shift"), py::arg("grad_input"));
|
||||
m.def("bottom_pool_forward", &bottom_pool_forward, "Bottom Pool Forward",
|
||||
py::arg("input"), py::call_guard<py::gil_scoped_release>());
|
||||
m.def("bottom_pool_backward", &bottom_pool_backward, "Bottom Pool Backward",
|
||||
py::arg("input"), py::arg("grad_output"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
m.def("left_pool_forward", &left_pool_forward, "Left Pool Forward",
|
||||
py::arg("input"), py::call_guard<py::gil_scoped_release>());
|
||||
m.def("left_pool_backward", &left_pool_backward, "Left Pool Backward",
|
||||
py::arg("input"), py::arg("grad_output"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
m.def("right_pool_forward", &right_pool_forward, "Right Pool Forward",
|
||||
py::arg("input"), py::call_guard<py::gil_scoped_release>());
|
||||
m.def("right_pool_backward", &right_pool_backward, "Right Pool Backward",
|
||||
py::arg("input"), py::arg("grad_output"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
m.def("top_pool_forward", &top_pool_forward, "Top Pool Forward",
|
||||
py::arg("input"), py::call_guard<py::gil_scoped_release>());
|
||||
m.def("top_pool_backward", &top_pool_backward, "Top Pool Backward",
|
||||
py::arg("input"), py::arg("grad_output"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes",
|
||||
py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"),
|
||||
py::arg("mode_flag"), py::arg("aligned"));
|
||||
|
|
Loading…
Reference in New Issue