mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge jax and original weight init
This commit is contained in:
parent
acbd698c83
commit
bf2ca6bdf4
@ -289,40 +289,19 @@ class VisionTransformer(nn.Module):
|
||||
assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
|
||||
head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
if self.dist_token is not None:
|
||||
trunc_normal_(self.dist_token, std=.02)
|
||||
if weight_init.startswith('jax'):
|
||||
# leave cls token as zeros to match jax impl
|
||||
for n, m in self.named_modules():
|
||||
_init_weights_jax(m, n, head_bias=head_bias)
|
||||
_init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
|
||||
else:
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
if self.dist_token is not None:
|
||||
trunc_normal_(self.dist_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
self._init_weights(m, n, head_bias=head_bias)
|
||||
self.apply(_init_vit_weights)
|
||||
|
||||
def _init_weights(self, m, n: str = '', head_bias: float = 0., init_conv=False):
|
||||
# This impl does not exactly match the official JAX version.
|
||||
# When called w/o n, head_bias, init_conv args it will behave exactly the same
|
||||
# as my original init for compatibility with downstream use cases (ie DeiT).
|
||||
if isinstance(m, nn.Linear):
|
||||
if n.startswith('head'):
|
||||
nn.init.zeros_(m.weight)
|
||||
nn.init.constant_(m.bias, head_bias)
|
||||
elif n.startswith('pre_logits'):
|
||||
lecun_normal_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
else:
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif init_conv and isinstance(m, nn.Conv2d):
|
||||
# NOTE conv was left to pytorch default init originally
|
||||
lecun_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.zeros_(m.bias)
|
||||
nn.init.ones_(m.weight)
|
||||
def _init_weights(self, m):
|
||||
# this fn left here for compat with downstream users
|
||||
_init_vit_weights(m)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
@ -369,9 +348,12 @@ class VisionTransformer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
|
||||
# A weight init scheme closer to the official JAX impl than my original init
|
||||
# NOTE: requires module name so cannot be used via module.apply()
|
||||
def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
|
||||
""" ViT weight initialization
|
||||
* When called without n, head_bias, jax_impl args it will behave exactly the same
|
||||
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
|
||||
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
|
||||
"""
|
||||
if isinstance(m, nn.Linear):
|
||||
if n.startswith('head'):
|
||||
nn.init.zeros_(m.weight)
|
||||
@ -380,13 +362,19 @@ def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
|
||||
lecun_normal_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
else:
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
if 'mlp' in n:
|
||||
nn.init.normal_(m.bias, 0, 1e-6)
|
||||
else:
|
||||
if jax_impl:
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
if 'mlp' in n:
|
||||
nn.init.normal_(m.bias, std=1e-6)
|
||||
else:
|
||||
nn.init.zeros_(m.bias)
|
||||
else:
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
elif jax_impl and isinstance(m, nn.Conv2d):
|
||||
# NOTE conv was left to pytorch default in my original init
|
||||
lecun_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
Loading…
x
Reference in New Issue
Block a user