diff --git a/mmseg/models/backbones/cgnet.py b/mmseg/models/backbones/cgnet.py index 032a55d85..32bdbc4c1 100644 --- a/mmseg/models/backbones/cgnet.py +++ b/mmseg/models/backbones/cgnet.py @@ -1,12 +1,12 @@ +import warnings + import torch import torch.nn as nn import torch.utils.checkpoint as cp -from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer, - constant_init, kaiming_init) -from mmcv.runner import load_checkpoint +from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer +from mmcv.runner import BaseModule from mmcv.utils.parrots_wrapper import _BatchNorm -from mmseg.utils import get_root_logger from ..builder import BACKBONES @@ -183,7 +183,7 @@ class InputInjection(nn.Module): @BACKBONES.register_module() -class CGNet(nn.Module): +class CGNet(BaseModule): """CGNet backbone. A Light-weight Context Guided Network for Semantic Segmentation @@ -210,6 +210,9 @@ class CGNet(nn.Module): and its variants only. Default: False. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None """ def __init__(self, @@ -222,9 +225,31 @@ class CGNet(nn.Module): norm_cfg=dict(type='BN', requires_grad=True), act_cfg=dict(type='PReLU'), norm_eval=False, - with_cp=False): + with_cp=False, + pretrained=None, + init_cfg=None): + + super(CGNet, self).__init__(init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer=['Conv2d', 'Linear']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict(type='Constant', val=0, layer='PReLU') + ] + else: + raise TypeError('pretrained must be a str or None') - super(CGNet, self).__init__() self.in_channels = in_channels self.num_channels = num_channels assert isinstance(self.num_channels, tuple) and len( @@ -335,27 +360,6 @@ class CGNet(nn.Module): return output - def init_weights(self, pretrained=None): - """Initialize the weights in backbone. - - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. 
- """ - if isinstance(pretrained, str): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, (nn.Conv2d, nn.Linear)): - kaiming_init(m) - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - elif isinstance(m, nn.PReLU): - constant_init(m, 0) - else: - raise TypeError('pretrained must be a str or None') - def train(self, mode=True): """Convert the model into training mode will keeping the normalization layer freezed.""" diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index ee115ffda..e8a87037d 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -1,8 +1,7 @@ import torch import torch.nn as nn -from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init, - kaiming_init) -from torch.nn.modules.batchnorm import _BatchNorm +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmcv.runner import BaseModule from mmseg.models.decode_heads.psp_head import PPM from mmseg.ops import resize @@ -247,7 +246,7 @@ class FeatureFusionModule(nn.Module): @BACKBONES.register_module() -class FastSCNN(nn.Module): +class FastSCNN(BaseModule): """Fast-SCNN Backbone. Args: @@ -291,6 +290,8 @@ class FastSCNN(nn.Module): dict(type='ReLU') align_corners (bool): align_corners argument of F.interpolate. Default: False + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None """ def __init__(self, @@ -307,9 +308,18 @@ class FastSCNN(nn.Module): conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), - align_corners=False): + align_corners=False, + init_cfg=None): + + super(FastSCNN, self).__init__(init_cfg) + + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] - super(FastSCNN, self).__init__() if global_in_channels != higher_in_channels: raise AssertionError('Global Input Channels must be the same \ with Higher Input Channels!') @@ -357,13 +367,6 @@ class FastSCNN(nn.Module): act_cfg=self.act_cfg, align_corners=self.align_corners) - def init_weights(self, pretrained=None): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m) - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - def forward(self, x): higher_res_features = self.learning_to_downsample(x) lower_res_features = self.global_feature_extractor(higher_res_features) diff --git a/mmseg/models/backbones/hrnet.py b/mmseg/models/backbones/hrnet.py index 5010a2e76..055fc985b 100644 --- a/mmseg/models/backbones/hrnet.py +++ b/mmseg/models/backbones/hrnet.py @@ -1,16 +1,16 @@ +import warnings + import torch.nn as nn -from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, - kaiming_init) -from mmcv.runner import load_checkpoint +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.runner import BaseModule, ModuleList, Sequential from mmcv.utils.parrots_wrapper import _BatchNorm from mmseg.ops import Upsample, resize -from mmseg.utils import get_root_logger from ..builder import BACKBONES from .resnet import BasicBlock, Bottleneck -class HRModule(nn.Module): +class HRModule(BaseModule): """High-Resolution Module for HRNet. In this module, every branch has 4 BasicBlocks/Bottlenecks. 
Fusion/Exchange @@ -26,8 +26,11 @@ class HRModule(nn.Module): multiscale_output=True, with_cp=False, conv_cfg=None, - norm_cfg=dict(type='BN', requires_grad=True)): - super(HRModule, self).__init__() + norm_cfg=dict(type='BN', requires_grad=True), + block_init_cfg=None, + init_cfg=None): + super(HRModule, self).__init__(init_cfg) + self.block_init_cfg = block_init_cfg self._check_branches(num_branches, num_blocks, in_channels, num_channels) @@ -92,7 +95,8 @@ class HRModule(nn.Module): downsample=downsample, with_cp=self.with_cp, norm_cfg=self.norm_cfg, - conv_cfg=self.conv_cfg)) + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) self.in_channels[branch_index] = \ num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): @@ -102,9 +106,10 @@ class HRModule(nn.Module): num_channels[branch_index], with_cp=self.with_cp, norm_cfg=self.norm_cfg, - conv_cfg=self.conv_cfg)) + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) - return nn.Sequential(*layers) + return Sequential(*layers) def _make_branches(self, num_branches, block, num_blocks, num_channels): """Build multiple branch.""" @@ -114,7 +119,7 @@ class HRModule(nn.Module): branches.append( self._make_one_branch(i, block, num_blocks, num_channels)) - return nn.ModuleList(branches) + return ModuleList(branches) def _make_fuse_layers(self): """Build fuse layer.""" @@ -209,7 +214,7 @@ class HRModule(nn.Module): @BACKBONES.register_module() -class HRNet(nn.Module): +class HRNet(BaseModule): """HRNet backbone. High-Resolution Representations for Labeling Pixels and Regions @@ -227,6 +232,9 @@ class HRNet(nn.Module): memory while slowing down the training speed. zero_init_residual (bool): whether to use zero init for last norm layer in resblocks to let them behave as identity. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None Example: >>> from mmseg.models import HRNet @@ -277,14 +285,36 @@ class HRNet(nn.Module): norm_cfg=dict(type='BN', requires_grad=True), norm_eval=False, with_cp=False, - zero_init_residual=False): - super(HRNet, self).__init__() + zero_init_residual=False, + pretrained=None, + init_cfg=None): + super(HRNet, self).__init__(init_cfg) + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + self.extra = extra self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.norm_eval = norm_eval self.with_cp = with_cp - self.zero_init_residual = zero_init_residual # stem net self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) @@ -430,6 +460,16 @@ class HRNet(nn.Module): build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) layers = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + layers.append( block( inplanes, @@ -438,7 +478,8 @@ class HRNet(nn.Module): downsample=downsample, with_cp=self.with_cp, norm_cfg=self.norm_cfg, - conv_cfg=self.conv_cfg)) + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) inplanes = planes * block.expansion for i in range(1, blocks): layers.append( @@ 
-447,9 +488,10 @@ class HRNet(nn.Module): planes, with_cp=self.with_cp, norm_cfg=self.norm_cfg, - conv_cfg=self.conv_cfg)) + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) - return nn.Sequential(*layers) + return Sequential(*layers) def _make_stage(self, layer_config, in_channels, multiscale_output=True): """Make each stage.""" @@ -460,6 +502,16 @@ class HRNet(nn.Module): block = self.blocks_dict[layer_config['block']] hr_modules = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + for i in range(num_modules): # multi_scale_output is only used for the last module if not multiscale_output and i == num_modules - 1: @@ -477,35 +529,10 @@ class HRNet(nn.Module): reset_multiscale_output, with_cp=self.with_cp, norm_cfg=self.norm_cfg, - conv_cfg=self.conv_cfg)) + conv_cfg=self.conv_cfg, + block_init_cfg=block_init_cfg)) - return nn.Sequential(*hr_modules), in_channels - - def init_weights(self, pretrained=None): - """Initialize the weights in backbone. - - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. 
- """ - if isinstance(pretrained, str): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m) - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - - if self.zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - constant_init(m.norm3, 0) - elif isinstance(m, BasicBlock): - constant_init(m.norm2, 0) - else: - raise TypeError('pretrained must be a str or None') + return Sequential(*hr_modules), in_channels def forward(self, x): """Forward function.""" diff --git a/mmseg/models/backbones/mobilenet_v2.py b/mmseg/models/backbones/mobilenet_v2.py index 9ab628e2a..46d57fbb5 100644 --- a/mmseg/models/backbones/mobilenet_v2.py +++ b/mmseg/models/backbones/mobilenet_v2.py @@ -1,8 +1,8 @@ -import logging +import warnings import torch.nn as nn -from mmcv.cnn import ConvModule, constant_init, kaiming_init -from mmcv.runner import load_checkpoint +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule from torch.nn.modules.batchnorm import _BatchNorm from ..builder import BACKBONES @@ -10,7 +10,7 @@ from ..utils import InvertedResidual, make_divisible @BACKBONES.register_module() -class MobileNetV2(nn.Module): +class MobileNetV2(BaseModule): """MobileNetV2 backbone. Args: @@ -35,6 +35,9 @@ class MobileNetV2(nn.Module): and its variants only. Default: False. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None """ # Parameters to build layers. 
3 parameters are needed to construct a @@ -52,8 +55,30 @@ class MobileNetV2(nn.Module): norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU6'), norm_eval=False, - with_cp=False): - super(MobileNetV2, self).__init__() + with_cp=False, + pretrained=None, + init_cfg=None): + super(MobileNetV2, self).__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + self.widen_factor = widen_factor self.strides = strides self.dilations = dilations @@ -133,19 +158,6 @@ class MobileNetV2(nn.Module): return nn.Sequential(*layers) - def init_weights(self, pretrained=None): - if isinstance(pretrained, str): - logger = logging.getLogger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m) - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - else: - raise TypeError('pretrained must be a str or None') - def forward(self, x): x = self.conv1(x) diff --git a/mmseg/models/backbones/mobilenet_v3.py b/mmseg/models/backbones/mobilenet_v3.py index f2e9a0cc0..ae0b45db8 100644 --- a/mmseg/models/backbones/mobilenet_v3.py +++ b/mmseg/models/backbones/mobilenet_v3.py @@ -1,10 +1,9 @@ -import logging +import warnings import mmcv -import torch.nn as nn -from mmcv.cnn import ConvModule, constant_init, kaiming_init +from mmcv.cnn import ConvModule from mmcv.cnn.bricks import Conv2dAdaptivePadding -from mmcv.runner import 
load_checkpoint +from mmcv.runner import BaseModule from torch.nn.modules.batchnorm import _BatchNorm from ..builder import BACKBONES @@ -12,7 +11,7 @@ from ..utils import InvertedResidualV3 as InvertedResidual @BACKBONES.register_module() -class MobileNetV3(nn.Module): +class MobileNetV3(BaseModule): """MobileNetV3 backbone. This backbone is the improved implementation of `Searching for MobileNetV3 @@ -35,6 +34,9 @@ class MobileNetV3(nn.Module): with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None """ # Parameters to build each block: # [kernel size, mid channels, out channels, with_se, act type, stride] @@ -75,8 +77,30 @@ class MobileNetV3(nn.Module): frozen_stages=-1, reduction_factor=1, norm_eval=False, - with_cp=False): - super(MobileNetV3, self).__init__() + with_cp=False, + pretrained=None, + init_cfg=None): + super(MobileNetV3, self).__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + assert arch in self.arch_settings assert isinstance(reduction_factor, int) and reduction_factor > 0 assert mmcv.is_tuple_of(out_indices, int) @@ -217,19 +241,6 @@ class MobileNetV3(nn.Module): return layers - def init_weights(self, pretrained=None): - if isinstance(pretrained, str): - logger = 
logging.getLogger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m) - elif isinstance(m, nn.BatchNorm2d): - constant_init(m, 1) - else: - raise TypeError('pretrained must be a str or None') - def forward(self, x): outs = [] for i, layer_name in enumerate(self.layers): diff --git a/mmseg/models/backbones/resnet.py b/mmseg/models/backbones/resnet.py index f6c4c08d4..e52e9122d 100644 --- a/mmseg/models/backbones/resnet.py +++ b/mmseg/models/backbones/resnet.py @@ -1,16 +1,16 @@ +import warnings + import torch.nn as nn import torch.utils.checkpoint as cp -from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer, - constant_init, kaiming_init) -from mmcv.runner import load_checkpoint +from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from mmcv.runner import BaseModule from mmcv.utils.parrots_wrapper import _BatchNorm -from mmseg.utils import get_root_logger from ..builder import BACKBONES from ..utils import ResLayer -class BasicBlock(nn.Module): +class BasicBlock(BaseModule): """Basic block for ResNet.""" expansion = 1 @@ -26,8 +26,9 @@ class BasicBlock(nn.Module): conv_cfg=None, norm_cfg=dict(type='BN'), dcn=None, - plugins=None): - super(BasicBlock, self).__init__() + plugins=None, + init_cfg=None): + super(BasicBlock, self).__init__(init_cfg) assert dcn is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.' @@ -94,7 +95,7 @@ class BasicBlock(nn.Module): return out -class Bottleneck(nn.Module): +class Bottleneck(BaseModule): """Bottleneck block for ResNet. 
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is @@ -114,8 +115,9 @@ class Bottleneck(nn.Module): conv_cfg=None, norm_cfg=dict(type='BN'), dcn=None, - plugins=None): - super(Bottleneck, self).__init__() + plugins=None, + init_cfg=None): + super(Bottleneck, self).__init__(init_cfg) assert style in ['pytorch', 'caffe'] assert dcn is None or isinstance(dcn, dict) assert plugins is None or isinstance(plugins, list) @@ -305,7 +307,7 @@ class Bottleneck(nn.Module): @BACKBONES.register_module() -class ResNet(nn.Module): +class ResNet(BaseModule): """ResNet backbone. Args: @@ -346,6 +348,9 @@ class ResNet(nn.Module): memory while slowing down the training speed. zero_init_residual (bool): Whether to use zero init for last norm layer in resblocks to let them behave as identity. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None Example: >>> from mmseg.models import ResNet @@ -392,10 +397,46 @@ class ResNet(nn.Module): multi_grid=None, contract_dilation=False, with_cp=False, - zero_init_residual=True): + zero_init_residual=True, + pretrained=None, + init_cfg=None): - super(ResNet, self).__init__() + super(ResNet, self).__init__(init_cfg) if depth not in self.arch_settings: raise KeyError(f'invalid depth {depth} for resnet') + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block
is BasicBlock: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + self.depth = depth self.stem_channels = stem_channels self.base_channels = base_channels @@ -421,7 +462,6 @@ class ResNet(nn.Module): self.plugins = plugins self.multi_grid = multi_grid self.contract_dilation = contract_dilation - self.zero_init_residual = zero_init_residual self.block, stage_blocks = self.arch_settings[depth] self.stage_blocks = stage_blocks[:num_stages] self.inplanes = stem_channels @@ -456,7 +496,8 @@ class ResNet(nn.Module): dcn=dcn, plugins=stage_plugins, multi_grid=stage_multi_grid, - contract_dilation=contract_dilation) + contract_dilation=contract_dilation, + init_cfg=block_init_cfg) self.inplanes = planes * self.block.expansion layer_name = f'layer{i+1}' self.add_module(layer_name, res_layer) @@ -597,38 +638,6 @@ class ResNet(nn.Module): for param in m.parameters(): param.requires_grad = False - def init_weights(self, pretrained=None): - """Initialize the weights in backbone. - - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. 
- """ - if isinstance(pretrained, str): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m) - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - - if self.dcn is not None: - for m in self.modules(): - if isinstance(m, Bottleneck) and hasattr( - m, 'conv2_offset'): - constant_init(m.conv2_offset, 0) - - if self.zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - constant_init(m.norm3, 0) - elif isinstance(m, BasicBlock): - constant_init(m.norm2, 0) - else: - raise TypeError('pretrained must be a str or None') - def forward(self, x): """Forward function.""" if self.deep_stem: diff --git a/mmseg/models/backbones/unet.py b/mmseg/models/backbones/unet.py index 6cbda009d..a8cbe57f6 100644 --- a/mmseg/models/backbones/unet.py +++ b/mmseg/models/backbones/unet.py @@ -1,11 +1,12 @@ +import warnings + import torch.nn as nn import torch.utils.checkpoint as cp from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, - build_norm_layer, constant_init, kaiming_init) -from mmcv.runner import load_checkpoint + build_norm_layer) +from mmcv.runner import BaseModule from mmcv.utils.parrots_wrapper import _BatchNorm -from mmseg.utils import get_root_logger from ..builder import BACKBONES from ..utils import UpConvBlock @@ -219,7 +220,7 @@ class InterpConv(nn.Module): @BACKBONES.register_module() -class UNet(nn.Module): +class UNet(BaseModule): """UNet backbone. U-Net: Convolutional Networks for Biomedical Image Segmentation. https://arxiv.org/pdf/1505.04597.pdf @@ -266,6 +267,9 @@ class UNet(nn.Module): dcn (bool): Use deformable convolution in convolutional layer or not. Default: None. plugins (dict): plugins for convolutional layers. Default: None. + pretrained (str, optional): model pretrained path. 
Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None Notice: The input image size should be divisible by the whole downsample rate @@ -291,8 +295,30 @@ class UNet(nn.Module): upsample_cfg=dict(type='InterpConv'), norm_eval=False, dcn=None, - plugins=None): - super(UNet, self).__init__() + plugins=None, + pretrained=None, + init_cfg=None): + super(UNet, self).__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + assert dcn is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.' assert len(strides) == num_stages, \ @@ -408,22 +434,3 @@ class UNet(nn.Module): f'downsample rate {whole_downsample_rate}, when num_stages is '\ f'{self.num_stages}, strides is {self.strides}, and downsamples '\ f'is {self.downsamples}.' - - def init_weights(self, pretrained=None): - """Initialize the weights in backbone. - - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. 
- """ - if isinstance(pretrained, str): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m) - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - else: - raise TypeError('pretrained must be a str or None') diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 781c9c1cc..b140700a9 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -9,7 +9,7 @@ import torch.nn.functional as F import torch.utils.checkpoint as cp from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer, constant_init, kaiming_init, normal_init) -from mmcv.runner import _load_checkpoint +from mmcv.runner import BaseModule, _load_checkpoint from mmcv.utils.parrots_wrapper import _BatchNorm from mmseg.utils import get_root_logger @@ -203,7 +203,7 @@ class PatchEmbed(nn.Module): @BACKBONES.register_module() -class VisionTransformer(nn.Module): +class VisionTransformer(BaseModule): """Vision transformer backbone. A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for @@ -243,6 +243,9 @@ class VisionTransformer(nn.Module): with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None """ def __init__(self, @@ -266,8 +269,12 @@ class VisionTransformer(nn.Module): out_shape='NCHW', with_cls_token=True, interpolate_mode='bicubic', - with_cp=False): - super(VisionTransformer, self).__init__() + with_cp=False, + pretrained=None, + init_cfg=None): + super(VisionTransformer, self).__init__(init_cfg) + self.pretrained = pretrained + self.img_size = img_size self.patch_size = patch_size self.features = self.embed_dim = embed_dim @@ -319,7 +326,8 @@ class VisionTransformer(nn.Module): self.norm_eval = norm_eval self.with_cp = with_cp - def init_weights(self, pretrained=None): + def init_weights(self): + pretrained = self.pretrained if isinstance(pretrained, str): logger = get_root_logger() checkpoint = _load_checkpoint(pretrained, logger=logger) diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index 86b9b63f4..54d517f02 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -2,8 +2,7 @@ from abc import ABCMeta, abstractmethod import torch import torch.nn as nn -from mmcv.cnn import normal_init -from mmcv.runner import auto_fp16, force_fp32 +from mmcv.runner import BaseModule, auto_fp16, force_fp32 from mmseg.core import build_pixel_sampler from mmseg.ops import resize @@ -11,7 +10,7 @@ from ..builder import build_loss from ..losses import accuracy -class BaseDecodeHead(nn.Module, metaclass=ABCMeta): +class BaseDecodeHead(BaseModule, metaclass=ABCMeta): """Base class for BaseDecodeHead. Args: @@ -41,6 +40,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta): Default: None. align_corners (bool): align_corners argument of F.interpolate. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict.
""" def __init__(self, @@ -60,8 +60,10 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta): loss_weight=1.0), ignore_index=255, sampler=None, - align_corners=False): - super(BaseDecodeHead, self).__init__() + align_corners=False, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='conv_seg'))): + super(BaseDecodeHead, self).__init__(init_cfg) self._init_inputs(in_channels, in_index, input_transform) self.channels = channels self.num_classes = num_classes @@ -130,10 +132,6 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta): assert isinstance(in_index, int) self.in_channels = in_channels - def init_weights(self): - """Initialize weights of classification layer.""" - normal_init(self.conv_seg, mean=0, std=0.01) - def _transform_inputs(self, inputs): """Transform inputs for decoder. diff --git a/mmseg/models/decode_heads/point_head.py b/mmseg/models/decode_heads/point_head.py index 90a23635d..f2d9fcc5a 100644 --- a/mmseg/models/decode_heads/point_head.py +++ b/mmseg/models/decode_heads/point_head.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from mmcv.cnn import ConvModule, normal_init +from mmcv.cnn import ConvModule from mmcv.ops import point_sample from mmseg.models.builder import HEADS @@ -69,6 +69,8 @@ class PointHead(BaseCascadeDecodeHead): conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, + init_cfg=dict( + type='Normal', std=0.001, override=dict(name='fc_seg')), **kwargs) self.num_fcs = num_fcs @@ -101,10 +103,6 @@ class PointHead(BaseCascadeDecodeHead): self.dropout = nn.Dropout(self.dropout_ratio) delattr(self, 'conv_seg') - def init_weights(self): - """Initialize weights of classification layer.""" - normal_init(self.fc_seg, std=0.001) - def cls_seg(self, feat): """Classify each pixel with fc.""" if self.dropout is not None: diff --git a/mmseg/models/necks/fpn.py b/mmseg/models/necks/fpn.py index f43d1e62f..4ba128ed4 100644 --- a/mmseg/models/necks/fpn.py +++ b/mmseg/models/necks/fpn.py @@ -1,12 +1,13 @@ import torch.nn as nn
import torch.nn.functional as F -from mmcv.cnn import ConvModule, xavier_init +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule, auto_fp16 from ..builder import NECKS @NECKS.register_module() -class FPN(nn.Module): +class FPN(BaseModule): """Feature Pyramid Network. This is an implementation of - Feature Pyramid Networks for Object @@ -43,6 +44,7 @@ class FPN(nn.Module): Default: None. upsample_cfg (dict): Config dict for interpolate layer. Default: `dict(mode='nearest')` + init_cfg (dict or list[dict], optional): Initialization config dict. Example: >>> import torch @@ -73,8 +75,10 @@ class FPN(nn.Module): conv_cfg=None, norm_cfg=None, act_cfg=None, - upsample_cfg=dict(mode='nearest')): - super(FPN, self).__init__() + upsample_cfg=dict(mode='nearest'), + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super(FPN, self).__init__(init_cfg) assert isinstance(in_channels, list) self.in_channels = in_channels self.out_channels = out_channels @@ -153,12 +157,7 @@ class FPN(nn.Module): inplace=False) self.fpn_convs.append(extra_fpn_conv) - # default init_weights for conv(msra) and norm in ConvModule - def init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - xavier_init(m, distribution='uniform') - + @auto_fp16() def forward(self, inputs): assert len(inputs) == len(self.in_channels) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 7b5375753..0ace142ac 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -1,4 +1,3 @@ -import logging import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict @@ -7,17 +6,14 @@ import mmcv import numpy as np import torch import torch.distributed as dist -import torch.nn as nn -from mmcv.runner import auto_fp16 +from mmcv.runner import BaseModule, auto_fp16 -class BaseSegmentor(nn.Module): +class BaseSegmentor(BaseModule, metaclass=ABCMeta): """Base class for 
segmentors.""" - __metaclass__ = ABCMeta - - def __init__(self): - super(BaseSegmentor, self).__init__() + def __init__(self, init_cfg=None): + super(BaseSegmentor, self).__init__(init_cfg) self.fp16_enabled = False @property @@ -62,17 +58,6 @@ class BaseSegmentor(nn.Module): """Placeholder for augmentation test.""" pass - def init_weights(self, pretrained=None): - """Initialize the weights in segmentor. - - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. - """ - if pretrained is not None: - logger = logging.getLogger() - logger.info(f'load model from: {pretrained}') - def forward_test(self, imgs, img_metas, **kwargs): """ Args: diff --git a/mmseg/models/segmentors/cascade_encoder_decoder.py b/mmseg/models/segmentors/cascade_encoder_decoder.py index 220ab2bb3..fb5a9aeb7 100644 --- a/mmseg/models/segmentors/cascade_encoder_decoder.py +++ b/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -24,7 +24,8 @@ class CascadeEncoderDecoder(EncoderDecoder): auxiliary_head=None, train_cfg=None, test_cfg=None, - pretrained=None): + pretrained=None, + init_cfg=None): self.num_stages = num_stages super(CascadeEncoderDecoder, self).__init__( backbone=backbone, @@ -33,7 +34,8 @@ class CascadeEncoderDecoder(EncoderDecoder): auxiliary_head=auxiliary_head, train_cfg=train_cfg, test_cfg=test_cfg, - pretrained=pretrained) + pretrained=pretrained, + init_cfg=init_cfg) def _init_decode_head(self, decode_head): """Initialize ``decode_head``""" @@ -45,23 +47,6 @@ class CascadeEncoderDecoder(EncoderDecoder): self.align_corners = self.decode_head[-1].align_corners self.num_classes = self.decode_head[-1].num_classes - def init_weights(self, pretrained=None): - """Initialize the weights in backbone and heads. - - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. 
- """ - self.backbone.init_weights(pretrained=pretrained) - for i in range(self.num_stages): - self.decode_head[i].init_weights() - if self.with_auxiliary_head: - if isinstance(self.auxiliary_head, nn.ModuleList): - for aux_head in self.auxiliary_head: - aux_head.init_weights() - else: - self.auxiliary_head.init_weights() - def encode_decode(self, img, img_metas): """Encode images with backbone and decode into a semantic segmentation map of the same size as input.""" diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index b2d067dcb..04de3f418 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -25,8 +25,13 @@ class EncoderDecoder(BaseSegmentor): auxiliary_head=None, train_cfg=None, test_cfg=None, - pretrained=None): - super(EncoderDecoder, self).__init__() + pretrained=None, + init_cfg=None): + super(EncoderDecoder, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained self.backbone = builder.build_backbone(backbone) if neck is not None: self.neck = builder.build_neck(neck) @@ -36,8 +41,6 @@ class EncoderDecoder(BaseSegmentor): self.train_cfg = train_cfg self.test_cfg = test_cfg - self.init_weights(pretrained=pretrained) - assert self.with_decode_head def _init_decode_head(self, decode_head): @@ -56,24 +59,6 @@ class EncoderDecoder(BaseSegmentor): else: self.auxiliary_head = builder.build_head(auxiliary_head) - def init_weights(self, pretrained=None): - """Initialize the weights in backbone and heads. - - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. 
- """ - - super(EncoderDecoder, self).init_weights(pretrained) - self.backbone.init_weights(pretrained=pretrained) - self.decode_head.init_weights() - if self.with_auxiliary_head: - if isinstance(self.auxiliary_head, nn.ModuleList): - for aux_head in self.auxiliary_head: - aux_head.init_weights() - else: - self.auxiliary_head.init_weights() - def extract_feat(self, img): """Extract features from images.""" x = self.backbone(img) diff --git a/mmseg/models/utils/res_layer.py b/mmseg/models/utils/res_layer.py index 2585ab551..9c474ede6 100644 --- a/mmseg/models/utils/res_layer.py +++ b/mmseg/models/utils/res_layer.py @@ -1,8 +1,9 @@ from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.runner import Sequential from torch import nn as nn -class ResLayer(nn.Sequential): +class ResLayer(Sequential): """ResLayer to build ResNet style backbone. Args: diff --git a/tests/test_models/test_backbones/test_resnet.py b/tests/test_models/test_backbones/test_resnet.py index b95277ee4..e0947dba7 100644 --- a/tests/test_models/test_backbones/test_resnet.py +++ b/tests/test_models/test_backbones/test_resnet.py @@ -300,8 +300,8 @@ def test_resnet_backbone(): with pytest.raises(TypeError): # pretrained must be a string path - model = ResNet(50) - model.init_weights(pretrained=0) + model = ResNet(50, pretrained=0) + model.init_weights() with pytest.raises(AssertionError): # Style must be in ['pytorch', 'caffe'] @@ -314,8 +314,9 @@ def test_resnet_backbone(): assert check_norm_state(model.modules(), False) # Test ResNet50 with torchvision pretrained weight - model = ResNet(depth=50, norm_eval=True) - model.init_weights('torchvision://resnet50') + model = ResNet( + depth=50, norm_eval=True, pretrained='torchvision://resnet50') + model.init_weights() model.train() assert check_norm_state(model.modules(), False) diff --git a/tests/test_models/test_backbones/test_unet.py b/tests/test_models/test_backbones/test_unet.py index b17b22a05..defdf3921 100644 --- 
a/tests/test_models/test_backbones/test_unet.py +++ b/tests/test_models/test_backbones/test_unet.py @@ -734,7 +734,6 @@ def test_unet(): downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1)) - print(unet) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) assert x_outs[0].shape == torch.Size([2, 1024, 8, 8]) @@ -754,7 +753,6 @@ def test_unet(): downsamples=(True, True, True, False), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1)) - print(unet) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) assert x_outs[0].shape == torch.Size([2, 1024, 16, 16]) @@ -774,7 +772,6 @@ def test_unet(): downsamples=(True, True, True, False), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1)) - print(unet) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) assert x_outs[0].shape == torch.Size([2, 1024, 16, 16]) @@ -794,7 +791,6 @@ def test_unet(): downsamples=(True, True, False, False), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1)) - print(unet) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) assert x_outs[0].shape == torch.Size([2, 1024, 32, 32]) @@ -813,9 +809,9 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, False, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) - unet.init_weights(pretrained=None) - print(unet) + dec_dilations=(1, 1, 1, 1), + pretrained=None) + unet.init_weights() x = torch.randn(2, 3, 128, 128) x_outs = unet(x) assert x_outs[0].shape == torch.Size([2, 1024, 32, 32]) diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index ee8036246..ea9d70b61 100644 --- a/tests/test_models/test_forward.py +++ b/tests/test_models/test_forward.py @@ -215,6 +215,7 @@ def _test_encoder_decoder_forward(cfg_file): from mmseg.models import build_segmentor segmentor = build_segmentor(model) + segmentor.init_weights() if isinstance(segmentor.decode_head, nn.ModuleList): num_classes = segmentor.decode_head[-1].num_classes diff --git 
a/tools/train.py b/tools/train.py index 69ca7335d..2d11df37b 100644 --- a/tools/train.py +++ b/tools/train.py @@ -131,6 +131,7 @@ def main(): cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')) + model.init_weights() logger.info(model)