mirror of https://github.com/open-mmlab/mmcv.git
use torch.cummax in corner_pool for torch 1.5+ (#390)
* synchronize from mmdetection * fix parrotspull/394/head
parent
b11c56603f
commit
039e06e43a
|
@ -1,3 +1,4 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
|
@ -98,10 +99,27 @@ class CornerPool(nn.Module):
|
|||
'top': TopPoolFunction,
|
||||
}
|
||||
|
||||
cummax_dim_flip = {
|
||||
'bottom': (2, False),
|
||||
'left': (3, True),
|
||||
'right': (3, False),
|
||||
'top': (2, True),
|
||||
}
|
||||
|
||||
def __init__(self, mode):
|
||||
super(CornerPool, self).__init__()
|
||||
assert mode in self.pool_functions
|
||||
self.mode = mode
|
||||
self.corner_pool = self.pool_functions[mode]
|
||||
|
||||
def forward(self, x):
|
||||
return self.corner_pool.apply(x)
|
||||
if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
|
||||
dim, flip = self.cummax_dim_flip[self.mode]
|
||||
if flip:
|
||||
x = x.flip(dim)
|
||||
pool_tensor, _ = torch.cummax(x, dim=dim)
|
||||
if flip:
|
||||
pool_tensor = pool_tensor.flip(dim)
|
||||
return pool_tensor
|
||||
else:
|
||||
return self.corner_pool.apply(x)
|
||||
|
|
Loading…
Reference in New Issue