diff --git a/mmcv/ops/corner_pool.py b/mmcv/ops/corner_pool.py index 2fa3b69ca..6b0d87193 100644 --- a/mmcv/ops/corner_pool.py +++ b/mmcv/ops/corner_pool.py @@ -1,3 +1,4 @@ +import torch from torch import nn from torch.autograd import Function @@ -98,10 +99,27 @@ class CornerPool(nn.Module): 'top': TopPoolFunction, } + cummax_dim_flip = { + 'bottom': (2, False), + 'left': (3, True), + 'right': (3, False), + 'top': (2, True), + } + def __init__(self, mode): super(CornerPool, self).__init__() assert mode in self.pool_functions + self.mode = mode self.corner_pool = self.pool_functions[mode] def forward(self, x): - return self.corner_pool.apply(x) + if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0': + dim, flip = self.cummax_dim_flip[self.mode] + if flip: + x = x.flip(dim) + pool_tensor, _ = torch.cummax(x, dim=dim) + if flip: + pool_tensor = pool_tensor.flip(dim) + return pool_tensor + else: + return self.corner_pool.apply(x)