diff --git a/mmcv/ops/upfirdn2d.py b/mmcv/ops/upfirdn2d.py
index 1de193d8a..1d2f32141 100644
--- a/mmcv/ops/upfirdn2d.py
+++ b/mmcv/ops/upfirdn2d.py
@@ -99,6 +99,7 @@ import torch
 from torch.autograd import Function
 from torch.nn import functional as F
 
+from mmcv.utils import to_2tuple
 from ..utils import ext_loader
 
 upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
@@ -249,20 +250,37 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
     Args:
         input (Tensor): Tensor with shape of (n, c, h, w).
         kernel (Tensor): Filter kernel.
-        up (int, optional): Upsampling factor. Defaults to 1.
-        down (int, optional): Downsampling factor. Defaults to 1.
-        pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad).
-            Defaults to (0, 0).
+        up (int | tuple[int], optional): Upsampling factor. If given a
+            number, the same factor is used for both height and width.
+            Defaults to 1.
+        down (int | tuple[int], optional): Downsampling factor. If given a
+            number, the same factor is used for both height and width.
+            Defaults to 1.
+        pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or
+            (x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0).
 
     Returns:
         Tensor: Tensor after UpFIRDn.
+
+    Raises:
+        ValueError: If ``pad`` does not contain 2 or 4 elements.
     """
+    if len(pad) == 2:
+        # Repeat a 2-element pad to build the 4-element form used below.
+        pad = (pad[0], pad[1], pad[0], pad[1])
+    elif len(pad) != 4:
+        raise ValueError(
+            f'pad must contain 2 or 4 elements, but got {len(pad)}.')
+
+    # Normalize once so the CPU and extension paths see the same arguments.
+    up = to_2tuple(up)
+    down = to_2tuple(down)
+
     if input.device.type == 'cpu':
-        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0],
-                               pad[1], pad[0], pad[1])
+        out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
+                               pad[0], pad[1], pad[2], pad[3])
     else:
-        out = UpFirDn2d.apply(input, kernel, (up, up), (down, down),
-                              (pad[0], pad[1], pad[0], pad[1]))
+        out = UpFirDn2d.apply(input, kernel, up, down, pad)
 
     return out
 