fix pos embed dynamic resampling for deit

pull/2326/head
Wojtek Jasiński 2024-11-01 23:24:13 +01:00 committed by Ross Wightman
parent 3ae3f44288
commit 3c7822c621
1 changed files with 3 additions and 1 deletions

View File

@ -75,9 +75,11 @@ class VisionTransformerDistilled(VisionTransformer):
def _pos_embed(self, x):
if self.dynamic_img_size:
B, H, W, C = x.shape
prev_grid_size = self.patch_embed.grid_size
pos_embed = resample_abs_pos_embed(
self.pos_embed,
(H, W),
new_size=(H, W),
old_size=prev_grid_size,
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
)
x = x.view(B, -1, C)