[Fix] Fix wrong init usage in transformer models (#1069)

* fix wrong trunc_normal_init usage

* fix mit init weights

* fix vit init weights

* fix mit init weights

* fix typo

* fix swin init weights
Junjun2016 2021-12-06 19:59:33 +08:00 committed by GitHub
parent 6a3c31ae3f
commit b8ca9b6d99
3 changed files with 25 additions and 32 deletions
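
Note on the fix: the helpers in mmcv.cnn.utils.weight_init (trunc_normal_init, constant_init, normal_init, kaiming_init) expect an nn.Module and initialize its .weight and .bias together, so passing m.weight or m.bias to them, as the old code did, is a misuse; bare parameter tensors are instead handled by the tensor-level trunc_normal_ or torch.nn.init. A minimal standalone sketch of the module-level style the hunks below switch to (not part of the patch; it assumes the mmcv weight_init signatures used here):

import torch.nn as nn
from mmcv.cnn.utils.weight_init import constant_init, trunc_normal_init

linear = nn.Linear(8, 8)
# One call fills linear.weight with a truncated normal (std=0.02)
# and sets linear.bias to 0, so no separate bias branch is needed.
trunc_normal_init(linear, std=.02, bias=0.)

norm = nn.LayerNorm(8)
# Likewise for norm layers: weight -> 1.0, bias -> 0 in a single call.
constant_init(norm, val=1.0, bias=0.)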

View File

@@ -4,10 +4,11 @@ import warnings
 import torch
 import torch.nn as nn
-from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
-                      constant_init, normal_init, trunc_normal_init)
+from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
 from ...utils import get_root_logger
@@ -343,7 +344,7 @@ class MixVisionTransformer(BaseModule):
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__()
+        super().__init__(init_cfg=init_cfg)
 
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -365,7 +366,6 @@ class MixVisionTransformer(BaseModule):
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
         self.pretrained = pretrained
-        self.init_cfg = init_cfg
 
         # transformer encoder
         dpr = [
@@ -407,19 +407,15 @@ class MixVisionTransformer(BaseModule):
         if self.pretrained is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
                 elif isinstance(m, nn.Conv2d):
                     fan_out = m.kernel_size[0] * m.kernel_size[
                         1] * m.out_channels
                     fan_out //= m.groups
-                    normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    normal_init(
+                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
         elif isinstance(self.pretrained, str):
             logger = get_root_logger()
             checkpoint = _load_checkpoint(
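
Note: in the MixVisionTransformer hunk above, the explicit bias branches disappear because the module-level helpers zero the bias themselves, while the Conv2d branch keeps the fan-out-scaled normal init. An illustrative plain-PyTorch equivalent of the corrected branches (not code from the patch; the helper name init_mit_module is made up, and it assumes a PyTorch version that provides nn.init.trunc_normal_):

import math
import torch.nn as nn

def init_mit_module(m: nn.Module) -> None:
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        # He-style fan-out scaling: variance 2 / fan_out.
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
        nn.init.normal_(m.weight, mean=0., std=math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            nn.init.zeros_(m.bias)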

View File

@@ -7,8 +7,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as cp
-from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
+from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from mmcv.utils import to_2tuple
@@ -73,7 +75,7 @@ class WindowMSA(BaseModule):
         self.softmax = nn.Softmax(dim=-1)
 
     def init_weights(self):
-        trunc_normal_init(self.relative_position_bias_table, std=0.02)
+        trunc_normal_(self.relative_position_bias_table, std=0.02)
 
     def forward(self, x, mask=None):
         """
@@ -665,15 +667,12 @@ class SwinTransformer(BaseModule):
                         f'{self.__class__.__name__}, '
                         f'training start from scratch')
             if self.use_abs_pos_embed:
-                trunc_normal_init(self.absolute_pos_embed, std=0.02)
+                trunc_normal_(self.absolute_pos_embed, std=0.02)
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
         else:
             assert 'checkpoint' in self.init_cfg, f'Only support ' \
                 f'specify `Pretrained` in ' \
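
Note: the Swin hunks also cover the other half of the mix-up. relative_position_bias_table and absolute_pos_embed are bare nn.Parameter tensors with no .weight or .bias attribute, so the module-level helpers cannot apply to them; they now go through the tensor-level trunc_normal_. A short sketch of that case (hypothetical shape, not code from the patch):

import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import trunc_normal_

# A bare parameter such as a relative position bias table has no
# .weight/.bias, so the tensor-level initializer writes into it directly.
rel_pos_bias_table = nn.Parameter(torch.zeros(169, 3))
trunc_normal_(rel_pos_bias_table, std=0.02)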

View File

@@ -4,9 +4,10 @@ import warnings
 import torch
 import torch.nn as nn
-from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
-                      normal_init, trunc_normal_init)
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
+                                        trunc_normal_)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
@@ -292,23 +293,20 @@ class VisionTransformer(BaseModule):
         else:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
-            trunc_normal_init(self.pos_embed, std=.02)
-            trunc_normal_init(self.cls_token, std=.02)
+            trunc_normal_(self.pos_embed, std=.02)
+            trunc_normal_(self.cls_token, std=.02)
             for n, m in self.named_modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
+                    trunc_normal_(m.weight, std=.02)
                     if m.bias is not None:
                         if 'ffn' in n:
-                            normal_init(m.bias, std=1e-6)
+                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                         else:
-                            constant_init(m.bias, 0)
+                            nn.init.constant_(m.bias, 0)
                 elif isinstance(m, nn.Conv2d):
-                    kaiming_init(m.weight, mode='fan_in')
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    kaiming_init(m, mode='fan_in', bias=0.)
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
 
     def _pos_embeding(self, patched_img, hw_shape, pos_embed):
         """Positiong embeding method.