Fix reversed H & W padding for swin patch merging

fix_swin_pad
Ross Wightman 2024-09-21 16:51:02 -07:00
parent ee5b1e8217
commit 65564f7da5
3 changed files with 3 additions and 3 deletions

View File

@ -435,7 +435,7 @@ class PatchMerging(nn.Module):
def forward(self, x):
B, H, W, C = x.shape
pad_values = (0, 0, 0, H % 2, 0, W % 2)
pad_values = (0, 0, 0, W % 2, 0, H % 2)
x = nn.functional.pad(x, pad_values)
_, H, W, _ = x.shape

View File

@ -439,7 +439,7 @@ class PatchMerging(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, C = x.shape
pad_values = (0, 0, 0, H % 2, 0, W % 2)
pad_values = (0, 0, 0, W % 2, 0, H % 2)
x = nn.functional.pad(x, pad_values)
_, H, W, _ = x.shape

View File

@ -445,7 +445,7 @@ class PatchMerging(nn.Module):
"""
B, H, W, C = x.shape
pad_values = (0, 0, 0, H % 2, 0, W % 2)
pad_values = (0, 0, 0, W % 2, 0, H % 2)
x = nn.functional.pad(x, pad_values)
_, H, W, _ = x.shape