From 493c730ffc9e532541b19b31daf9238190d8b991 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 26 Apr 2023 23:16:06 -0700 Subject: [PATCH] Fix pit regression --- timm/models/pit.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/timm/models/pit.py b/timm/models/pit.py index d5e51038..4c5addd8 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -254,11 +254,7 @@ class PoolingVisionTransformer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) - if x.shape[-1] != self.pos_embed.shape[-1] or x.shape[-2] != self.pos_embed.shape[-2]: - pos_embed = nn.functional.interpolate(self.pos_embed, x.shape[2:], mode='bilinear') - else: - pos_embed = self.pos_embed - x = self.pos_drop(x +pos_embed) + x = self.pos_drop(x + self.pos_embed) cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) x, cls_tokens = self.transformers((x, cls_tokens)) cls_tokens = self.norm(cls_tokens)