use torch.cummax in corner_pool for torch 1.5+ (#390)

* synchronize from mmdetection

* fix parrots
pull/394/head
tianyuandu 2020-07-07 17:28:32 +08:00 committed by GitHub
parent b11c56603f
commit 039e06e43a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 1 deletions

View File

@ -1,3 +1,4 @@
import torch
from torch import nn
from torch.autograd import Function
@ -98,10 +99,27 @@ class CornerPool(nn.Module):
'top': TopPoolFunction,
}
cummax_dim_flip = {
'bottom': (2, False),
'left': (3, True),
'right': (3, False),
'top': (2, True),
}
def __init__(self, mode):
super(CornerPool, self).__init__()
assert mode in self.pool_functions
self.mode = mode
self.corner_pool = self.pool_functions[mode]
def forward(self, x):
return self.corner_pool.apply(x)
if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
dim, flip = self.cummax_dim_flip[self.mode]
if flip:
x = x.flip(dim)
pool_tensor, _ = torch.cummax(x, dim=dim)
if flip:
pool_tensor = pool_tensor.flip(dim)
return pool_tensor
else:
return self.corner_pool.apply(x)