Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00

Commit cf5fec5047: Cleanup experimental vit weight init a bit
Parent: f42f1df26c
timm/models/layers/__init__.py

@@ -31,4 +31,4 @@ from .split_attn import SplitAttnConv2d
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
-from .weight_init import trunc_normal_
+from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
timm/models/layers/weight_init.py

@@ -2,6 +2,8 @@ import torch
 import math
 import warnings
 
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
 
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
@@ -58,3 +60,30 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
         >>> nn.init.trunc_normal_(w)
     """
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    if mode == 'fan_in':
+        denom = fan_in
+    elif mode == 'fan_out':
+        denom = fan_out
+    elif mode == 'fan_avg':
+        denom = (fan_in + fan_out) / 2
+
+    variance = scale / denom
+
+    if distribution == "truncated_normal":
+        # constant is stddev of standard normal truncated to (-2, 2)
+        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+    elif distribution == "normal":
+        tensor.normal_(std=math.sqrt(variance))
+    elif distribution == "uniform":
+        bound = math.sqrt(3 * variance)
+        tensor.uniform_(-bound, bound)
+    else:
+        raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
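
Note: for a quick sanity check of the two helpers added above, a minimal sketch (assuming they are imported via the re-exports added to timm/models/layers/__init__.py; the tensor shapes are illustrative). The .87962566103423978 divisor is the std of a standard normal truncated to (-2, 2), so dividing the requested std by it makes the post-truncation std come out at the requested value.

    import math
    import torch
    import torch.nn as nn
    from timm.models.layers import variance_scaling_, lecun_normal_

    # fan_in of a (out_features, in_features) weight is in_features = 128,
    # so fan_in scaling with scale=1.0 targets std ~= sqrt(1 / 128) ~= 0.088
    w = torch.empty(256, 128)
    variance_scaling_(w, scale=1.0, mode='fan_in', distribution='truncated_normal')
    print(w.std().item(), math.sqrt(1.0 / 128))

    # lecun_normal_ is shorthand for fan_in scaling with a truncated normal
    conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
    lecun_normal_(conv.weight)
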
timm/models/vision_transformer.py

@@ -28,7 +28,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_
+from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_, lecun_normal_
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
 from .registry import register_model
@@ -373,7 +373,7 @@ class VisionTransformer(nn.Module):
     def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                  num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
-                 act_layer=None, weight_init=''):
+                 act_layer=None, weight_init='new_nlhb'):
         """
         Args:
             img_size (int, tuple): input image size
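
Note: a rough usage sketch of the new weight_init argument as it reads from this diff; the constructor arguments other than weight_init are illustrative, not taken from the commit. The prefix selects the scheme ('jax*' or 'new*', anything else falls back to the old-style init) and an 'nlhb' substring additionally applies the negative-log head bias handled in _init_weights below.

    from timm.models.vision_transformer import VisionTransformer

    # default in this commit is weight_init='new_nlhb': lecun_normal_-based init
    # plus a classifier bias of -log(num_classes)
    model = VisionTransformer(img_size=224, patch_size=16, num_classes=1000)

    # init closer to the official JAX/Flax implementation, without the head bias
    model_jax = VisionTransformer(img_size=224, patch_size=16, num_classes=1000, weight_init='jax')

    # earlier trunc_normal_(std=.02) style init, zero head bias
    model_old = VisionTransformer(img_size=224, patch_size=16, num_classes=1000, weight_init='old')
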
@@ -433,14 +433,20 @@ class VisionTransformer(nn.Module):
         # Classifier head
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-        trunc_normal_(self.pos_embed, std=.02)
-        if weight_init != 'jax':  # leave as zeros to match JAX impl
-            trunc_normal_(self.cls_token, std=.02)
-        for n, m in self.named_modules():
-            if weight_init == 'jax':
-                _init_weights_jax(m, n)
-            else:
-                _init_weights_original(m, n)
+        self._init_weights(weight_init)
+
+    def _init_weights(self, weight_init: str):
+        trunc_normal_(self.pos_embed, std=.02)
+        if weight_init.startswith('jax'):
+            init_fn = _init_weights_jax
+            # leave cls token as zeros to match jax impl
+        else:
+            trunc_normal_(self.cls_token, std=.02)
+            init_fn = _init_weights_new if weight_init.startswith('new') else _init_weights_old
+        hb = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
+        init_fn = partial(init_fn, head_bias=hb)
+        for n, m in self.named_modules():
+            init_fn(m, n)
 
     @torch.jit.ignore
     def no_weight_decay(self):
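
Note: 'nlhb' stands for negative log head bias: whenever the weight_init string contains 'nlhb', the classifier bias is initialized to -log(num_classes) rather than zero, passed to the chosen init function via partial(init_fn, head_bias=hb). For the default 1000 classes:

    import math
    print(-math.log(1000))   # ~= -6.9078, the head_bias value used above
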
@@ -475,41 +481,42 @@ class VisionTransformer(nn.Module):
         return x
 
 
-def _init_weights_original(m: nn.Module, n: str = ''):
-    if isinstance(m, (nn.Conv2d, nn.Linear)):
+def _init_weights_old(m: nn.Module, n: str = '', head_bias: float = 0.):
+    if isinstance(m, nn.Linear):
         trunc_normal_(m.weight, std=.02)
-        if isinstance(m, nn.Linear) and m.bias is not None:
-            nn.init.constant_(m.bias, 0)
+        if m.bias is not None:
+            if 'head' in n:
+                nn.init.constant_(m.bias, head_bias)
+            else:
+                nn.init.zeros_(m.bias)
     elif isinstance(m, nn.LayerNorm):
         nn.init.zeros_(m.bias)
         nn.init.ones_(m.weight)
 
 
-def _init_weights_jax(m: nn.Module, n: str):
-    """ Weight init scheme closer to the official JAX impl than my original init"""
-    def _fan_in(tensor):
-        dimensions = tensor.dim()
-        if dimensions < 2:
-            raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
-
-        num_input_fmaps = tensor.size(1)
-        receptive_field_size = 1
-        if tensor.dim() > 2:
-            receptive_field_size = tensor[0][0].numel()
-        fan_in = num_input_fmaps * receptive_field_size
-        return fan_in
-
-    def _lecun_normal(w):
-        stddev = (1.0 / _fan_in(w)) ** 0.5 / .87962566103423978
-        trunc_normal_(w, 0, stddev)
-
+def _init_weights_new(m: nn.Module, n: str = '', head_bias: float = 0.):
+    if isinstance(m, (nn.Conv2d, nn.Linear)):
+        #trunc_normal_(m.weight, std=.02)
+        lecun_normal_(m.weight)
+        if m.bias is not None:
+            if 'head' in n:
+                nn.init.constant_(m.bias, head_bias)
+            else:
+                nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.LayerNorm):
+        nn.init.zeros_(m.bias)
+        nn.init.ones_(m.weight)
+
+
+def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
+    """ Attempt at weight init scheme closer to the official JAX impl than my original init"""
     if isinstance(m, nn.Linear):
         if 'head' in n:
             nn.init.zeros_(m.weight)
-            nn.init.zeros_(m.bias)
+            nn.init.constant_(m.bias, head_bias)
         elif 'pre_logits' in n:
-            _lecun_normal(m.weight)
+            lecun_normal_(m.weight)
             nn.init.zeros_(m.bias)
         else:
             nn.init.xavier_uniform_(m.weight)
@@ -519,7 +526,7 @@ def _init_weights_jax(m: nn.Module, n: str):
             else:
                 nn.init.zeros_(m.bias)
     elif isinstance(m, nn.Conv2d):
-        _lecun_normal(m.weight)
+        lecun_normal_(m.weight)
         if m.bias is not None:
             nn.init.zeros_(m.bias)
     elif isinstance(m, nn.LayerNorm):
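
Note: a small side-by-side sketch of how the three module-level init functions above treat a classifier head (the Linear layer here is hypothetical; all three accept the module, its name, and head_bias the way _init_weights wires them up):

    import math
    import torch.nn as nn

    head = nn.Linear(768, 1000)
    hb = -math.log(1000)

    _init_weights_old(head, n='head', head_bias=hb)   # trunc_normal_(std=.02) weight, bias = hb
    _init_weights_new(head, n='head', head_bias=hb)   # lecun_normal_ weight, bias = hb
    _init_weights_jax(head, n='head', head_bias=hb)   # zeroed weight, bias = hb
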
@@ -544,7 +551,7 @@ class DistilledVisionTransformer(VisionTransformer):
 
         trunc_normal_(self.dist_token, std=.02)
         trunc_normal_(self.pos_embed, std=.02)
-        self.head_dist.apply(_init_weights_original)
+        self.head_dist.apply(_init_weights_new)
 
     def forward_features(self, x):
         B = x.shape[0]