Padding helpers work if tuples/lists passed

fix_mnv4_query_strides
Ross Wightman 2024-07-19 14:28:03 -07:00
parent 2180800646
commit 7e0caa1ba3
1 changed files with 10 additions and 2 deletions

View File

@ -3,14 +3,19 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from typing import List, Tuple
from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from .helpers import to_2tuple
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> Union[int, List[int]]:
if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
return [get_padding(*a) for a in zip(kernel_size, stride, dilation)]
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
@ -25,6 +30,9 @@ def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):
# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
return all([is_static_pad(*a) for a in zip(kernel_size, stride, dilation)])
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0