[Enhancement] Revise the interface of upfirdn2d function (#1195)

* revise the interface of upfirdn2d function

* adopt to_2tuple
pull/1210/head
Rui Xu 2021-07-20 17:12:02 +08:00 committed by GitHub
parent 5f9e6b610b
commit faf6c6cd8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 28 additions and 8 deletions

View File

@ -99,6 +99,7 @@ import torch
from torch.autograd import Function
from torch.nn import functional as F
from mmcv.utils import to_2tuple
from ..utils import ext_loader
upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
@ -249,20 +250,39 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
Args:
input (Tensor): Tensor with shape of (n, c, h, w).
kernel (Tensor): Filter kernel.
up (int, optional): Upsampling factor. Defaults to 1.
down (int, optional): Downsampling factor. Defaults to 1.
pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad).
Defaults to (0, 0).
up (int | tuple[int], optional): Upsampling factor. If given a number,
we will use this factor for the both height and width side.
Defaults to 1.
down (int | tuple[int], optional): Downsampling factor. If given a
number, we will use this factor for the both height and width side.
Defaults to 1.
pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or
(x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0).
Returns:
Tensor: Tensor after UpFIRDn.
"""
if input.device.type == 'cpu':
out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0],
pad[1], pad[0], pad[1])
if len(pad) == 2:
pad = (pad[0], pad[1], pad[0], pad[1])
up = to_2tuple(up)
down = to_2tuple(down)
out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
pad[0], pad[1], pad[2], pad[3])
else:
out = UpFirDn2d.apply(input, kernel, (up, up), (down, down),
(pad[0], pad[1], pad[0], pad[1]))
_up = to_2tuple(up)
_down = to_2tuple(down)
if len(pad) == 4:
_pad = pad
elif len(pad) == 2:
_pad = (pad[0], pad[1], pad[0], pad[1])
out = UpFirDn2d.apply(input, kernel, _up, _down, _pad)
return out