[Refactor] Refactor _prepare_pos_embed in ViT (#1656)
* deal with cls_token
* Update implement
---------
Co-authored-by: mzr1996 <mzr1996@163.com>
pull/1665/head
parent
d4a6dfa00a
commit
70ff2abbf7
|
@@ -305,6 +305,7 @@ class VisionTransformer(BaseBackbone):
|
|||
self.out_type = out_type
|
||||
|
||||
# Set cls token
|
||||
self.with_cls_token = with_cls_token
|
||||
if with_cls_token:
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
elif out_type != 'cls_token':
|
||||
|
@@ -404,6 +405,11 @@ class VisionTransformer(BaseBackbone):
|
|||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.patch_embed.init_out_size
|
||||
|
||||
if (not self.with_cls_token and ckpt_pos_embed_shape[1]
|
||||
== self.pos_embed.shape[1] + 1):
|
||||
# Remove cls token from state dict if it's not used.
|
||||
state_dict[name] = state_dict[name][:, 1:]
|
||||
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
|
|
Loading…
Reference in New Issue