mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix wrong init usage in transformer models (#1069)
* fix wrong trunc_normal_init usage * fix mit init weights * fix vit init weights * fix mit init weights * fix typo * fix swin init weights
This commit is contained in:
parent
6a3c31ae3f
commit
b8ca9b6d99
@ -4,10 +4,11 @@ import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
|
||||
constant_init, normal_init, trunc_normal_init)
|
||||
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
||||
from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
|
||||
trunc_normal_init)
|
||||
from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
|
||||
|
||||
from ...utils import get_root_logger
|
||||
@ -343,7 +344,7 @@ class MixVisionTransformer(BaseModule):
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__()
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(pretrained, str) or pretrained is None:
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
@ -365,7 +366,6 @@ class MixVisionTransformer(BaseModule):
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < self.num_stages
|
||||
self.pretrained = pretrained
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
# transformer encoder
|
||||
dpr = [
|
||||
@ -407,19 +407,15 @@ class MixVisionTransformer(BaseModule):
|
||||
if self.pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
constant_init(m.bias, 0)
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m.bias, 0)
|
||||
constant_init(m.weight, 1.0)
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[
|
||||
1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
constant_init(m.bias, 0)
|
||||
normal_init(
|
||||
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
|
||||
elif isinstance(self.pretrained, str):
|
||||
logger = get_root_logger()
|
||||
checkpoint = _load_checkpoint(
|
||||
|
@ -7,8 +7,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, build_dropout
|
||||
from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
|
||||
trunc_normal_init)
|
||||
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
|
||||
from mmcv.utils import to_2tuple
|
||||
|
||||
@ -73,7 +75,7 @@ class WindowMSA(BaseModule):
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_init(self.relative_position_bias_table, std=0.02)
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
@ -665,15 +667,12 @@ class SwinTransformer(BaseModule):
|
||||
f'{self.__class__.__name__}, '
|
||||
f'training start from scratch')
|
||||
if self.use_abs_pos_embed:
|
||||
trunc_normal_init(self.absolute_pos_embed, std=0.02)
|
||||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
constant_init(m.bias, 0)
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m.bias, 0)
|
||||
constant_init(m.weight, 1.0)
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
else:
|
||||
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
||||
f'specify `Pretrained` in ' \
|
||||
|
@ -4,9 +4,10 @@ import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
|
||||
normal_init, trunc_normal_init)
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
@ -292,23 +293,20 @@ class VisionTransformer(BaseModule):
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
trunc_normal_init(self.pos_embed, std=.02)
|
||||
trunc_normal_init(self.cls_token, std=.02)
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m.weight, std=.02)
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
normal_init(m.bias, std=1e-6)
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
constant_init(m.bias, 0)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m.weight, mode='fan_in')
|
||||
if m.bias is not None:
|
||||
constant_init(m.bias, 0)
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m.bias, 0)
|
||||
constant_init(m.weight, 1.0)
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
|
||||
"""Positiong embeding method.
|
||||
|
Loading…
x
Reference in New Issue
Block a user