mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix pos embed dynamic resampling for deit
This commit is contained in:
parent
3ae3f44288
commit
3c7822c621
@ -75,9 +75,11 @@ class VisionTransformerDistilled(VisionTransformer):
|
|||||||
def _pos_embed(self, x):
|
def _pos_embed(self, x):
|
||||||
if self.dynamic_img_size:
|
if self.dynamic_img_size:
|
||||||
B, H, W, C = x.shape
|
B, H, W, C = x.shape
|
||||||
|
prev_grid_size = self.patch_embed.grid_size
|
||||||
pos_embed = resample_abs_pos_embed(
|
pos_embed = resample_abs_pos_embed(
|
||||||
self.pos_embed,
|
self.pos_embed,
|
||||||
(H, W),
|
new_size=(H, W),
|
||||||
|
old_size=prev_grid_size,
|
||||||
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
||||||
)
|
)
|
||||||
x = x.view(B, -1, C)
|
x = x.view(B, -1, C)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user