mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix pit regression
This commit is contained in:
parent
437d344e03
commit
493c730ffc
@ -254,11 +254,7 @@ class PoolingVisionTransformer(nn.Module):
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if x.shape[-1] != self.pos_embed.shape[-1] or x.shape[-2] != self.pos_embed.shape[-2]:
|
||||
pos_embed = nn.functional.interpolate(self.pos_embed, x.shape[2:], mode='bilinear')
|
||||
else:
|
||||
pos_embed = self.pos_embed
|
||||
x = self.pos_drop(x +pos_embed)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x, cls_tokens = self.transformers((x, cls_tokens))
|
||||
cls_tokens = self.norm(cls_tokens)
|
||||
|
Loading…
x
Reference in New Issue
Block a user