Mirror of https://github.com/open-mmlab/mmsegmentation.git, synced 2025-06-03 22:03:48 +08:00
[Fix] Fix wrong init usage in transformer models (#1069)
* fix wrong trunc_normal_init usage
* fix mit init weights
* fix vit init weights
* fix mit init weights
* fix typo
* fix swin init weights
parent 6a3c31ae3f
commit b8ca9b6d99
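The common thread of the diffs below is the distinction between mmcv's module-level init helpers (`trunc_normal_init`, `constant_init`, `normal_init`, `kaiming_init`), which take an `nn.Module` and set its weight and bias in one call, and the tensor-level helpers (`trunc_normal_`, plain `torch.nn.init.*`), which take a bare tensor or parameter. The old code passed tensors such as `m.weight` into the module-level helpers. The following minimal sketch illustrates that distinction; the toy `Linear`/`LayerNorm` modules and the `pos_embed` parameter are made up for illustration, and it assumes the `mmcv.cnn.utils.weight_init` API of this era.

import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)

linear = nn.Linear(4, 4)                         # toy modules, illustration only
norm = nn.LayerNorm(4)
pos_embed = nn.Parameter(torch.zeros(1, 16, 4))  # a bare parameter, not a module

# Module-level helpers: take an nn.Module and set weight *and* bias in one call.
trunc_normal_init(linear, std=.02, bias=0.)
constant_init(norm, val=1.0, bias=0.)

# Tensor-level helper: takes a raw tensor/parameter, so it is the right tool
# for standalone parameters such as a positional embedding table.
trunc_normal_(pos_embed, std=.02)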
mmseg/models/backbones/mit.py
@@ -4,10 +4,11 @@ import warnings
 
 import torch
 import torch.nn as nn
-from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
-                      constant_init, normal_init, trunc_normal_init)
+from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
 
 from ...utils import get_root_logger
@@ -343,7 +344,7 @@ class MixVisionTransformer(BaseModule):
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__()
+        super().__init__(init_cfg=init_cfg)
 
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -365,7 +366,6 @@ class MixVisionTransformer(BaseModule):
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
         self.pretrained = pretrained
-        self.init_cfg = init_cfg
 
         # transformer encoder
         dpr = [
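Taken together, the two hunks above move MixVisionTransformer to the standard BaseModule pattern: `init_cfg` is handed to `super().__init__()` instead of being stored manually. A minimal sketch of that pattern follows; the class name is illustrative, not the real backbone.

from mmcv.runner import BaseModule


class ToyBackbone(BaseModule):  # illustrative name only

    def __init__(self, init_cfg=None):
        # BaseModule stores init_cfg and uses it to drive init_weights(),
        # so no separate `self.init_cfg = init_cfg` assignment is needed.
        super().__init__(init_cfg=init_cfg)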
@@ -407,19 +407,15 @@ class MixVisionTransformer(BaseModule):
         if self.pretrained is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
                 elif isinstance(m, nn.Conv2d):
                     fan_out = m.kernel_size[0] * m.kernel_size[
                         1] * m.out_channels
                     fan_out //= m.groups
-                    normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    normal_init(
+                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
         elif isinstance(self.pretrained, str):
             logger = get_root_logger()
             checkpoint = _load_checkpoint(
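The `nn.Conv2d` branch keeps the He-style standard deviation sqrt(2 / fan_out); the helper now receives the module itself, so the bias is zeroed in the same call and the explicit bias branches can be dropped. A rough standalone check of that branch, using a hypothetical convolution layer chosen only for illustration:

import math

import torch.nn as nn
from mmcv.cnn.utils.weight_init import normal_init

conv = nn.Conv2d(16, 32, kernel_size=3, groups=2)  # hypothetical layer
fan_out = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
fan_out //= conv.groups
normal_init(conv, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
print(conv.weight.std(), math.sqrt(2.0 / fan_out))  # should be close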
mmseg/models/backbones/swin.py
@@ -7,8 +7,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as cp
-from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
+from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from mmcv.utils import to_2tuple
 
@@ -73,7 +75,7 @@ class WindowMSA(BaseModule):
         self.softmax = nn.Softmax(dim=-1)
 
     def init_weights(self):
-        trunc_normal_init(self.relative_position_bias_table, std=0.02)
+        trunc_normal_(self.relative_position_bias_table, std=0.02)
 
     def forward(self, x, mask=None):
         """
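`relative_position_bias_table` is a bare `nn.Parameter`, not a module with `.weight` and `.bias`, so the tensor-level `trunc_normal_` is the appropriate call here. A minimal sketch, with assumed Swin-style shapes chosen only for illustration:

import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import trunc_normal_

window_size = (7, 7)   # assumed values
num_heads = 3
relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                num_heads))
trunc_normal_(relative_position_bias_table, std=0.02)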
@@ -665,15 +667,12 @@ class SwinTransformer(BaseModule):
                         f'{self.__class__.__name__}, '
                         f'training start from scratch')
             if self.use_abs_pos_embed:
-                trunc_normal_init(self.absolute_pos_embed, std=0.02)
+                trunc_normal_(self.absolute_pos_embed, std=0.02)
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
         else:
             assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                   f'specify `Pretrained` in ' \
mmseg/models/backbones/vit.py
@@ -4,9 +4,10 @@ import warnings
 
 import torch
 import torch.nn as nn
-from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
-                      normal_init, trunc_normal_init)
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
+                                        trunc_normal_)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
@@ -292,23 +293,20 @@ class VisionTransformer(BaseModule):
         else:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
-            trunc_normal_init(self.pos_embed, std=.02)
-            trunc_normal_init(self.cls_token, std=.02)
+            trunc_normal_(self.pos_embed, std=.02)
+            trunc_normal_(self.cls_token, std=.02)
             for n, m in self.named_modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
+                    trunc_normal_(m.weight, std=.02)
                     if m.bias is not None:
                         if 'ffn' in n:
-                            normal_init(m.bias, std=1e-6)
+                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                         else:
-                            constant_init(m.bias, 0)
+                            nn.init.constant_(m.bias, 0)
                 elif isinstance(m, nn.Conv2d):
-                    kaiming_init(m.weight, mode='fan_in')
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    kaiming_init(m, mode='fan_in', bias=0.)
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
 
     def _pos_embeding(self, patched_img, hw_shape, pos_embed):
         """Positiong embeding method.
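The ViT branch now mixes three call styles: tensor-level `trunc_normal_` for `pos_embed`, `cls_token` and Linear weights, plain `torch.nn.init` for Linear biases (whose value depends on whether the layer sits inside an FFN), and module-level `kaiming_init`/`constant_init` for Conv2d and norm layers. A condensed sketch of the same loop over a toy module; names and shapes are made up, and the norm branch is simplified to LayerNorm only.

import torch.nn as nn
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
                                        trunc_normal_)

model = nn.ModuleDict({
    'ffn': nn.Linear(8, 8),      # bias -> small normal noise
    'attn': nn.Linear(8, 8),     # bias -> zeros
    'proj': nn.Conv2d(3, 8, 4),  # module-level kaiming init
    'norm': nn.LayerNorm(8),     # module-level constant init
})

for n, m in model.named_modules():
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            if 'ffn' in n:
                nn.init.normal_(m.bias, mean=0., std=1e-6)
            else:
                nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        kaiming_init(m, mode='fan_in', bias=0.)
    elif isinstance(m, nn.LayerNorm):
        constant_init(m, val=1.0, bias=0.)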