diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py
index cd0a70d3..82e401c3 100644
--- a/mmpretrain/models/backbones/vision_transformer.py
+++ b/mmpretrain/models/backbones/vision_transformer.py
@@ -305,6 +305,7 @@ class VisionTransformer(BaseBackbone):
         self.out_type = out_type
 
         # Set cls token
+        self.with_cls_token = with_cls_token
         if with_cls_token:
             self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
         elif out_type != 'cls_token':
@@ -404,6 +405,11 @@ class VisionTransformer(BaseBackbone):
                 int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
             pos_embed_shape = self.patch_embed.init_out_size
 
+            if (not self.with_cls_token and ckpt_pos_embed_shape[1]
+                    == self.pos_embed.shape[1] + 1):
+                # Remove cls token from state dict if it's not used.
+                state_dict[name] = state_dict[name][:, 1:]
+
             state_dict[name] = resize_pos_embed(state_dict[name],
                                                 ckpt_pos_embed_shape,
                                                 pos_embed_shape,
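
For context, here is a minimal standalone sketch of what the added branch is meant to do: if a checkpoint's pos_embed still carries a leading cls-token slot while the model was built with with_cls_token=False, drop that slot before interpolating the patch-grid embeddings. This is plain PyTorch, not mmpretrain's resize_pos_embed; the function name strip_and_resize_pos_embed and its arguments are illustrative only.

# Illustrative sketch only; not the mmpretrain implementation.
import torch
import torch.nn.functional as F


def strip_and_resize_pos_embed(ckpt_pos_embed, model_num_patches, dst_hw):
    """ckpt_pos_embed: (1, L, C) from a checkpoint; dst_hw: target patch grid."""
    if ckpt_pos_embed.shape[1] == model_num_patches + 1:
        # The checkpoint has one extra (cls-token) position the model lacks.
        ckpt_pos_embed = ckpt_pos_embed[:, 1:]
    src = int(ckpt_pos_embed.shape[1] ** 0.5)
    c = ckpt_pos_embed.shape[-1]
    # Reshape the flat token sequence back into a 2D grid and interpolate.
    grid = ckpt_pos_embed.reshape(1, src, src, c).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=dst_hw, mode='bicubic', align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, dst_hw[0] * dst_hw[1], c)


# A 224px/patch-16 checkpoint with a cls token: 14 * 14 + 1 = 197 positions.
ckpt_pos_embed = torch.randn(1, 197, 768)
# Model built without a cls token at the same resolution: 196 positions.
resized = strip_and_resize_pos_embed(ckpt_pos_embed, 196, (14, 14))
print(resized.shape)  # torch.Size([1, 196, 768])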