cleanup davit padding

This commit is contained in:
Ross Wightman 2024-06-22 12:06:46 -07:00
parent c715c724e7
commit 02d0f27721

View File

@ -79,8 +79,9 @@ class Stem(nn.Module):
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
x = F.pad(x, (0, pad_r, 0, pad_b))
x = self.conv(x)
x = self.norm(x)
return x
@ -113,8 +114,9 @@ class Downsample(nn.Module):
x = self.norm(x)
if self.even_k:
k_h, k_w = self.conv.kernel_size
x = F.pad(x, (0, (k_w - W % k_w) % k_w))
x = F.pad(x, (0, 0, 0, (k_h - H % k_h) % k_h))
pad_r = (k_w - W % k_w) % k_w
pad_b = (k_h - H % k_h) % k_h
x = F.pad(x, (0, pad_r , 0, pad_b))
x = self.conv(x)
return x