Initialize weights of reg_token for ViT

This commit is contained in:
Promisery 2024-07-13 11:11:42 +08:00
parent 648aaa4123
commit 417cf7f871

View File

@ -590,6 +590,8 @@ class VisionTransformer(nn.Module):
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
if self.reg_token is not None:
nn.init.normal_(self.reg_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
def _init_weights(self, m: nn.Module) -> None: