[Fix] refactor _prepare_pos_embed in ViT to fix bug in loading old checkpoint (#1679)
parent
59c077746f
commit
00030e3f7d
|
@ -394,6 +394,12 @@ class VisionTransformer(BaseBackbone):
|
|||
return
|
||||
|
||||
ckpt_pos_embed_shape = state_dict[name].shape
|
||||
if (not self.with_cls_token
|
||||
and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1):
|
||||
# Remove cls token from state dict if it's not used.
|
||||
state_dict[name] = state_dict[name][:, 1:]
|
||||
ckpt_pos_embed_shape = state_dict[name].shape
|
||||
|
||||
if self.pos_embed.shape != ckpt_pos_embed_shape:
|
||||
from mmengine.logging import MMLogger
|
||||
logger = MMLogger.get_current_instance()
|
||||
|
@ -405,11 +411,6 @@ class VisionTransformer(BaseBackbone):
|
|||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.patch_embed.init_out_size
|
||||
|
||||
if (not self.with_cls_token and ckpt_pos_embed_shape[1]
|
||||
== self.pos_embed.shape[1] + 1):
|
||||
# Remove cls token from state dict if it's not used.
|
||||
state_dict[name] = state_dict[name][:, 1:]
|
||||
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
|
|
Loading…
Reference in New Issue