Fix positional embedding resampling for non-square inputs in ViT
parent
51ac8d2efb
commit
3ae3f44288
|
@ -669,9 +669,11 @@ class VisionTransformer(nn.Module):
|
|||
|
||||
if self.dynamic_img_size:
|
||||
B, H, W, C = x.shape
|
||||
prev_grid_size = self.patch_embed.grid_size
|
||||
pos_embed = resample_abs_pos_embed(
|
||||
self.pos_embed,
|
||||
(H, W),
|
||||
new_size=(H, W),
|
||||
old_size=prev_grid_size,
|
||||
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
||||
)
|
||||
x = x.view(B, -1, C)
|
||||
|
|
Loading…
Reference in New Issue