From 00030e3f7d1266cb721abbebc6fb108e3876ca94 Mon Sep 17 00:00:00 2001 From: Peng Lu Date: Mon, 3 Jul 2023 11:36:44 +0800 Subject: [PATCH] [Fix] refactor _prepare_pos_embed in ViT to fix bug in loading old checkpoint (#1679) --- mmpretrain/models/backbones/vision_transformer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py index 33e5baf7..2f10d43f 100644 --- a/mmpretrain/models/backbones/vision_transformer.py +++ b/mmpretrain/models/backbones/vision_transformer.py @@ -394,6 +394,12 @@ class VisionTransformer(BaseBackbone): return ckpt_pos_embed_shape = state_dict[name].shape + if (not self.with_cls_token + and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1): + # Remove cls token from state dict if it's not used. + state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: from mmengine.logging import MMLogger logger = MMLogger.get_current_instance() @@ -405,11 +411,6 @@ class VisionTransformer(BaseBackbone): int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) pos_embed_shape = self.patch_embed.init_out_size - if (not self.with_cls_token and ckpt_pos_embed_shape[1] - == self.pos_embed.shape[1] + 1): - # Remove cls token from state dict if it's not used. - state_dict[name] = state_dict[name][:, 1:] - state_dict[name] = resize_pos_embed(state_dict[name], ckpt_pos_embed_shape, pos_embed_shape,