Fix reversed H & W padding for swin patch merging
parent ee5b1e8217
commit 65564f7da5
@@ -435,7 +435,7 @@ class PatchMerging(nn.Module):
     def forward(self, x):
         B, H, W, C = x.shape
 
-        pad_values = (0, 0, 0, H % 2, 0, W % 2)
+        pad_values = (0, 0, 0, W % 2, 0, H % 2)
         x = nn.functional.pad(x, pad_values)
         _, H, W, _ = x.shape
@@ -439,7 +439,7 @@ class PatchMerging(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
 
-        pad_values = (0, 0, 0, H % 2, 0, W % 2)
+        pad_values = (0, 0, 0, W % 2, 0, H % 2)
         x = nn.functional.pad(x, pad_values)
         _, H, W, _ = x.shape
@@ -445,7 +445,7 @@ class PatchMerging(nn.Module):
         """
         B, H, W, C = x.shape
 
-        pad_values = (0, 0, 0, H % 2, 0, W % 2)
+        pad_values = (0, 0, 0, W % 2, 0, H % 2)
         x = nn.functional.pad(x, pad_values)
         _, H, W, _ = x.shape
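
For context, a minimal sketch (not part of this commit, tensor sizes invented) of the torch.nn.functional.pad semantics that the old tuple order got backwards: the pad tuple is consumed from the last dimension inward, so for a (B, H, W, C) tensor it reads (C_left, C_right, W_left, W_right, H_left, H_right).

    # Sketch: F.pad reads the pad tuple from the LAST dimension backwards,
    # so for (B, H, W, C) the tuple means
    # (C_left, C_right, W_left, W_right, H_left, H_right).
    import torch
    import torch.nn as nn

    x = torch.zeros(1, 5, 8, 3)  # B=1, H=5 (odd), W=8 (even); sizes made up

    wrong = nn.functional.pad(x, (0, 0, 0, 5 % 2, 0, 8 % 2))  # old, reversed order
    fixed = nn.functional.pad(x, (0, 0, 0, 8 % 2, 0, 5 % 2))  # corrected order

    print(wrong.shape)  # torch.Size([1, 5, 9, 3]) -- H stays odd, W padded by mistake
    print(fixed.shape)  # torch.Size([1, 6, 8, 3]) -- both spatial dims even, as intended

The reversed order only misbehaves when H % 2 != W % 2, which is presumably why inputs with matching parity (e.g. both dims odd) masked the bug.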