[Refactor] Refactor _prepare_pos_embed in ViT (#1656)

* deal with cls_token

* Update implement

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
pull/1665/head
Yixiao Fang 2023-06-20 17:37:08 +08:00 committed by GitHub
parent d4a6dfa00a
commit 70ff2abbf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 6 additions and 0 deletions

View File

@ -305,6 +305,7 @@ class VisionTransformer(BaseBackbone):
self.out_type = out_type
# Set cls token
self.with_cls_token = with_cls_token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
elif out_type != 'cls_token':
@ -404,6 +405,11 @@ class VisionTransformer(BaseBackbone):
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
if (not self.with_cls_token and ckpt_pos_embed_shape[1]
== self.pos_embed.shape[1] + 1):
# Remove cls token from state dict if it's not used.
state_dict[name] = state_dict[name][:, 1:]
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,