From 02d0f2772172087a5e7b47352aaedfea29a3a2a7 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 22 Jun 2024 12:06:46 -0700
Subject: [PATCH] cleanup davit padding

---
 timm/models/davit.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 09fa9bed..1dc74d23 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -79,8 +79,9 @@ class Stem(nn.Module):
 
     def forward(self, x: Tensor):
         B, C, H, W = x.shape
-        x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
-        x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
+        pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
+        pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
+        x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         x = self.norm(x)
         return x
@@ -113,8 +114,9 @@ class Downsample(nn.Module):
         x = self.norm(x)
         if self.even_k:
             k_h, k_w = self.conv.kernel_size
-            x = F.pad(x, (0, (k_w - W % k_w) % k_w))
-            x = F.pad(x, (0, 0, 0, (k_h - H % k_h) % k_h))
+            pad_r = (k_w - W % k_w) % k_w
+            pad_b = (k_h - H % k_h) % k_h
+            x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         return x
 
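
Note (not part of the patch): the sketch below illustrates the padding arithmetic the patch consolidates. The expression `(s - W % s) % s` is the smallest non-negative amount that rounds a dimension up to a multiple of `s`, and `F.pad` takes `(left, right, top, bottom)` for an NCHW tensor, so the right/bottom pads can be applied in a single call. The stride and input shape below are hypothetical values chosen for illustration.

```python
import torch
import torch.nn.functional as F

stride = (4, 4)                    # hypothetical stem stride
x = torch.randn(1, 3, 57, 62)      # hypothetical input with non-divisible H, W
B, C, H, W = x.shape

# Smallest right/bottom pads that make W and H multiples of the stride;
# the outer % keeps the pad at 0 when the dim is already divisible.
pad_r = (stride[1] - W % stride[1]) % stride[1]
pad_b = (stride[0] - H % stride[0]) % stride[0]

# F.pad pads the last dim first: (left, right, top, bottom) for NCHW,
# so one call replaces the two separate F.pad calls in the old code.
x = F.pad(x, (0, pad_r, 0, pad_b))
print(x.shape)  # torch.Size([1, 3, 60, 64]) -- both spatial dims now multiples of 4
```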