[Fix] refactor _prepare_pos_embed in ViT to fix bug in loading old checkpoint (#1679)

Peng Lu 2023-07-03 11:36:44 +08:00 committed by GitHub
parent 59c077746f
commit 00030e3f7d
1 changed file with 6 additions and 5 deletions


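As the diff below shows, the cls-token handling used to live inside the shape-mismatch branch, after `ckpt_pos_embed_shape` had already been overwritten with the 2-tuple patch-grid size. At that point `ckpt_pos_embed_shape[1]` held a grid side length rather than a token count, so the condition could never match, the redundant cls token was never stripped from old checkpoints, and `resize_pos_embed` received a tensor with one position too many. The fix moves the check ahead of the shape comparison, so it operates on the raw checkpoint shape, and refreshes `ckpt_pos_embed_shape` after the strip so the subsequent mismatch test sees the corrected shape.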
@@ -394,6 +394,12 @@ class VisionTransformer(BaseBackbone):
             return
 
         ckpt_pos_embed_shape = state_dict[name].shape
+        if (not self.with_cls_token
+                and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1):
+            # Remove cls token from state dict if it's not used.
+            state_dict[name] = state_dict[name][:, 1:]
+            ckpt_pos_embed_shape = state_dict[name].shape
+
         if self.pos_embed.shape != ckpt_pos_embed_shape:
             from mmengine.logging import MMLogger
             logger = MMLogger.get_current_instance()
@@ -405,11 +411,6 @@ class VisionTransformer(BaseBackbone):
                 int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
             pos_embed_shape = self.patch_embed.init_out_size
 
-            if (not self.with_cls_token and ckpt_pos_embed_shape[1]
-                    == self.pos_embed.shape[1] + 1):
-                # Remove cls token from state dict if it's not used.
-                state_dict[name] = state_dict[name][:, 1:]
-
             state_dict[name] = resize_pos_embed(state_dict[name],
                                                 ckpt_pos_embed_shape,
                                                 pos_embed_shape,
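
To make the corrected control flow concrete, here is a minimal, self-contained sketch of the fixed logic. It is a sketch under stated assumptions, not the actual implementation: `prepare_pos_embed_fixed` is a hypothetical standalone stand-in for the real `VisionTransformer._prepare_pos_embed` method, the dummy shapes assume a ViT-B/16-style checkpoint (14x14 patch grid plus one cls token, embed dim 768), and the `resize_pos_embed` interpolation step is omitted.

import torch


def prepare_pos_embed_fixed(state_dict, name, model_pos_embed, with_cls_token):
    # Hypothetical stand-in for VisionTransformer._prepare_pos_embed.
    if name not in state_dict:
        return
    ckpt_pos_embed_shape = state_dict[name].shape
    # The fix: strip the cls token *before* the shape comparison, while
    # ckpt_pos_embed_shape still describes the raw checkpoint tensor.
    if (not with_cls_token
            and ckpt_pos_embed_shape[1] == model_pos_embed.shape[1] + 1):
        # Remove cls token from state dict if it's not used.
        state_dict[name] = state_dict[name][:, 1:]
        ckpt_pos_embed_shape = state_dict[name].shape
    if model_pos_embed.shape != ckpt_pos_embed_shape:
        # A remaining mismatch is a genuine grid-size difference; the real
        # implementation resolves it with resize_pos_embed (omitted here).
        pass
    return ckpt_pos_embed_shape


# Old checkpoint with cls token: 14 * 14 + 1 = 197 positions.
state_dict = {'pos_embed': torch.zeros(1, 197, 768)}
# Model built with with_cls_token=False: 196 positions.
model_pos_embed = torch.zeros(1, 196, 768)

shape = prepare_pos_embed_fixed(state_dict, 'pos_embed', model_pos_embed,
                                with_cls_token=False)
assert shape == (1, 196, 768)  # cls token stripped; shapes now match

With the pre-refactor ordering, the same inputs would skip the strip entirely: the check ran on the 2-tuple grid size (14, 14), whose second element could never equal 197, so the extra token was carried into the resize step.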