mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Revise the interface of upfirdn2d function (#1195)
* revise the interface of upfirdn2d function * adopt to_2tuplepull/1210/head
parent
5f9e6b610b
commit
faf6c6cd8e
|
@ -99,6 +99,7 @@ import torch
|
|||
from torch.autograd import Function
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmcv.utils import to_2tuple
|
||||
from ..utils import ext_loader
|
||||
|
||||
upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
|
||||
|
@ -249,20 +250,39 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
|||
Args:
|
||||
input (Tensor): Tensor with shape of (n, c, h, w).
|
||||
kernel (Tensor): Filter kernel.
|
||||
up (int, optional): Upsampling factor. Defaults to 1.
|
||||
down (int, optional): Downsampling factor. Defaults to 1.
|
||||
pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad).
|
||||
Defaults to (0, 0).
|
||||
up (int | tuple[int], optional): Upsampling factor. If given a number,
|
||||
we will use this factor for the both height and width side.
|
||||
Defaults to 1.
|
||||
down (int | tuple[int], optional): Downsampling factor. If given a
|
||||
number, we will use this factor for the both height and width side.
|
||||
Defaults to 1.
|
||||
pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or
|
||||
(x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0).
|
||||
|
||||
Returns:
|
||||
Tensor: Tensor after UpFIRDn.
|
||||
"""
|
||||
if input.device.type == 'cpu':
|
||||
out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0],
|
||||
pad[1], pad[0], pad[1])
|
||||
if len(pad) == 2:
|
||||
pad = (pad[0], pad[1], pad[0], pad[1])
|
||||
|
||||
up = to_2tuple(up)
|
||||
|
||||
down = to_2tuple(down)
|
||||
|
||||
out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
|
||||
pad[0], pad[1], pad[2], pad[3])
|
||||
else:
|
||||
out = UpFirDn2d.apply(input, kernel, (up, up), (down, down),
|
||||
(pad[0], pad[1], pad[0], pad[1]))
|
||||
_up = to_2tuple(up)
|
||||
|
||||
_down = to_2tuple(down)
|
||||
|
||||
if len(pad) == 4:
|
||||
_pad = pad
|
||||
elif len(pad) == 2:
|
||||
_pad = (pad[0], pad[1], pad[0], pad[1])
|
||||
|
||||
out = UpFirDn2d.apply(input, kernel, _up, _down, _pad)
|
||||
|
||||
return out
|
||||
|
||||
|
|
Loading…
Reference in New Issue