mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
cleanup davit padding
This commit is contained in:
parent
c715c724e7
commit
02d0f27721
@ -79,8 +79,9 @@ class Stem(nn.Module):
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
B, C, H, W = x.shape
|
||||
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
|
||||
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
|
||||
pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
|
||||
pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
|
||||
x = F.pad(x, (0, pad_r, 0, pad_b))
|
||||
x = self.conv(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
@ -113,8 +114,9 @@ class Downsample(nn.Module):
|
||||
x = self.norm(x)
|
||||
if self.even_k:
|
||||
k_h, k_w = self.conv.kernel_size
|
||||
x = F.pad(x, (0, (k_w - W % k_w) % k_w))
|
||||
x = F.pad(x, (0, 0, 0, (k_h - H % k_h) % k_h))
|
||||
pad_r = (k_w - W % k_w) % k_w
|
||||
pad_b = (k_h - H % k_h) % k_h
|
||||
x = F.pad(x, (0, pad_r , 0, pad_b))
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user