diff --git a/mmseg/models/backbones/mit.py b/mmseg/models/backbones/mit.py
index dd68decab..d2c7c2302 100644
--- a/mmseg/models/backbones/mit.py
+++ b/mmseg/models/backbones/mit.py
@@ -4,10 +4,11 @@ import warnings
 
 import torch
 import torch.nn as nn
-from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
-                      constant_init, normal_init, trunc_normal_init)
+from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
 
 from ...utils import get_root_logger
@@ -343,7 +344,7 @@ class MixVisionTransformer(BaseModule):
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__()
+        super().__init__(init_cfg=init_cfg)
 
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -365,7 +366,6 @@ class MixVisionTransformer(BaseModule):
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
         self.pretrained = pretrained
-        self.init_cfg = init_cfg
 
         # transformer encoder
         dpr = [
@@ -407,19 +407,15 @@ class MixVisionTransformer(BaseModule):
         if self.pretrained is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
                 elif isinstance(m, nn.Conv2d):
                     fan_out = m.kernel_size[0] * m.kernel_size[
                         1] * m.out_channels
                     fan_out //= m.groups
-                    normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    normal_init(
+                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
         elif isinstance(self.pretrained, str):
             logger = get_root_logger()
             checkpoint = _load_checkpoint(
diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py
index 59f4616c3..a360ab018 100644
--- a/mmseg/models/backbones/swin.py
+++ b/mmseg/models/backbones/swin.py
@@ -7,8 +7,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as cp
-from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
+from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from mmcv.utils import to_2tuple
 
@@ -73,7 +75,7 @@ class WindowMSA(BaseModule):
         self.softmax = nn.Softmax(dim=-1)
 
     def init_weights(self):
-        trunc_normal_init(self.relative_position_bias_table, std=0.02)
+        trunc_normal_(self.relative_position_bias_table, std=0.02)
 
     def forward(self, x, mask=None):
         """
@@ -665,15 +667,12 @@ class SwinTransformer(BaseModule):
                         f'{self.__class__.__name__}, '
                         f'training start from scratch')
             if self.use_abs_pos_embed:
-                trunc_normal_init(self.absolute_pos_embed, std=0.02)
+                trunc_normal_(self.absolute_pos_embed, std=0.02)
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
         else:
             assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                   f'specify `Pretrained` in ' \
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index f5afbb7f7..965652503 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -4,9 +4,10 @@ import warnings
 
 import torch
 import torch.nn as nn
-from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
-                      normal_init, trunc_normal_init)
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
+                                        trunc_normal_)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
@@ -292,23 +293,20 @@ class VisionTransformer(BaseModule):
         else:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
-            trunc_normal_init(self.pos_embed, std=.02)
-            trunc_normal_init(self.cls_token, std=.02)
+            trunc_normal_(self.pos_embed, std=.02)
+            trunc_normal_(self.cls_token, std=.02)
             for n, m in self.named_modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
+                    trunc_normal_(m.weight, std=.02)
                     if m.bias is not None:
                         if 'ffn' in n:
-                            normal_init(m.bias, std=1e-6)
+                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                         else:
-                            constant_init(m.bias, 0)
+                            nn.init.constant_(m.bias, 0)
                 elif isinstance(m, nn.Conv2d):
-                    kaiming_init(m.weight, mode='fan_in')
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    kaiming_init(m, mode='fan_in', bias=0.)
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
 
     def _pos_embeding(self, patched_img, hw_shape, pos_embed):
         """Positiong embeding method.
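
For reference, the patch imports the weight-init helpers from their dedicated module, mmcv.cnn.utils.weight_init, and switches from initializing individual tensors to the module-level helpers wherever a whole nn.Module is being set up, keeping the tensor-level trunc_normal_ only for bare parameters. The snippet below is a minimal sketch of the two call styles, not part of the patch; it assumes an mmcv version that provides mmcv.cnn.utils.weight_init with trunc_normal_.

import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)

linear = nn.Linear(8, 8)
norm = nn.LayerNorm(8)

# Module-level helpers take an nn.Module and fill both .weight and .bias
# in one call, which is what the backbones above now do.
trunc_normal_init(linear, std=.02, bias=0.)
constant_init(norm, val=1.0, bias=0.)

# Tensor-level helper for bare parameters (position embeddings, relative
# position bias tables) that have no .weight/.bias attributes.
pos_embed = nn.Parameter(torch.zeros(1, 16, 8))
trunc_normal_(pos_embed, std=0.02)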