Fix pit regression

This commit is contained in:
Ross Wightman 2023-04-26 23:16:06 -07:00
parent 437d344e03
commit 493c730ffc

View File

@ -254,11 +254,7 @@ class PoolingVisionTransformer(nn.Module):
def forward_features(self, x):
x = self.patch_embed(x)
if x.shape[-1] != self.pos_embed.shape[-1] or x.shape[-2] != self.pos_embed.shape[-2]:
pos_embed = nn.functional.interpolate(self.pos_embed, x.shape[2:], mode='bilinear')
else:
pos_embed = self.pos_embed
x = self.pos_drop(x +pos_embed)
x = self.pos_drop(x + self.pos_embed)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x, cls_tokens = self.transformers((x, cls_tokens))
cls_tokens = self.norm(cls_tokens)