From 7e0caa1ba31100c87484cd95ccab6af3cb392589 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 19 Jul 2024 14:28:03 -0700 Subject: [PATCH] Padding helpers work if tuples/lists passed --- timm/layers/padding.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/timm/layers/padding.py b/timm/layers/padding.py index d6971526..4b85d747 100644 --- a/timm/layers/padding.py +++ b/timm/layers/padding.py @@ -3,14 +3,19 @@ Hacked together by / Copyright 2020 Ross Wightman """ import math -from typing import List, Tuple +from typing import List, Tuple, Union import torch import torch.nn.functional as F +from .helpers import to_2tuple + # Calculate symmetric padding for a convolution -def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> Union[int, List[int]]: + if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]): + kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation) + return [get_padding(*a) for a in zip(kernel_size, stride, dilation)] padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 return padding @@ -25,6 +30,9 @@ def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int): # Can SAME padding for given args be done statically? def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]): + kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation) + return all([is_static_pad(*a) for a in zip(kernel_size, stride, dilation)]) return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0