Padding helpers work if tuples/lists passed
parent
2180800646
commit
7e0caa1ba3
|
@ -3,14 +3,19 @@
|
|||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .helpers import to_2tuple
|
||||
|
||||
|
||||
# Calculate symmetric padding for a convolution
|
||||
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
|
||||
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> Union[int, List[int]]:
|
||||
if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
|
||||
kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
|
||||
return [get_padding(*a) for a in zip(kernel_size, stride, dilation)]
|
||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||
return padding
|
||||
|
||||
|
@ -25,6 +30,9 @@ def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):
|
|||
|
||||
# Can SAME padding for given args be done statically?
|
||||
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
|
||||
if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
|
||||
kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
|
||||
return all([is_static_pad(*a) for a in zip(kernel_size, stride, dilation)])
|
||||
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue