[Refactor]: Unified parameter initialization (#567)

* [Refactor]: Unified parameter initialization

* fixed pretrained
This commit is contained in:
Jerry Jiarui XU 2021-06-16 21:41:29 -07:00 committed by GitHub
parent af6478dd7a
commit 9849a8dc23
19 changed files with 329 additions and 298 deletions

View File

@ -1,12 +1,12 @@
import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
@ -183,7 +183,7 @@ class InputInjection(nn.Module):
@BACKBONES.register_module()
class CGNet(nn.Module):
class CGNet(BaseModule):
"""CGNet backbone.
A Light-weight Context Guided Network for Semantic Segmentation
@ -210,6 +210,9 @@ class CGNet(nn.Module):
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
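pretrained (str, optional): Path to a pretrained checkpoint. Deprecated
in favor of ``init_cfg=dict(type='Pretrained', checkpoint=...)`` and
must not be set together with ``init_cfg``. Default: None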
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self,
@ -222,9 +225,31 @@ class CGNet(nn.Module):
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='PReLU'),
norm_eval=False,
with_cp=False):
with_cp=False,
pretrained=None,
init_cfg=None):
super(CGNet, self).__init__(init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer=['Conv2d', 'Linear']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm']),
dict(type='Constant', val=0, layer='PReLU')
]
else:
raise TypeError('pretrained must be a str or None')
super(CGNet, self).__init__()
self.in_channels = in_channels
self.num_channels = num_channels
assert isinstance(self.num_channels, tuple) and len(
@ -335,27 +360,6 @@ class CGNet(nn.Module):
return output
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
elif isinstance(m, nn.PReLU):
constant_init(m, 0)
else:
raise TypeError('pretrained must be a str or None')
def train(self, mode=True):
"""Convert the model into training mode while keeping the normalization
layer frozen."""

View File

@ -1,8 +1,7 @@
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
kaiming_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import BaseModule
from mmseg.models.decode_heads.psp_head import PPM
from mmseg.ops import resize
@ -247,7 +246,7 @@ class FeatureFusionModule(nn.Module):
@BACKBONES.register_module()
class FastSCNN(nn.Module):
class FastSCNN(BaseModule):
"""Fast-SCNN Backbone.
Args:
@ -291,6 +290,8 @@ class FastSCNN(nn.Module):
dict(type='ReLU')
align_corners (bool): align_corners argument of F.interpolate.
Default: False
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self,
@ -307,9 +308,18 @@ class FastSCNN(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False):
align_corners=False,
init_cfg=None):
super(FastSCNN, self).__init__(init_cfg)
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
super(FastSCNN, self).__init__()
if global_in_channels != higher_in_channels:
raise AssertionError('Global Input Channels must be the same \
with Higher Input Channels!')
@ -357,13 +367,6 @@ class FastSCNN(nn.Module):
act_cfg=self.act_cfg,
align_corners=self.align_corners)
def init_weights(self, pretrained=None):
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
def forward(self, x):
higher_res_features = self.learning_to_downsample(x)
lower_res_features = self.global_feature_extractor(higher_res_features)

View File

@ -1,16 +1,16 @@
import warnings
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.ops import Upsample, resize
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck
class HRModule(nn.Module):
class HRModule(BaseModule):
"""High-Resolution Module for HRNet.
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
@ -26,8 +26,11 @@ class HRModule(nn.Module):
multiscale_output=True,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True)):
super(HRModule, self).__init__()
norm_cfg=dict(type='BN', requires_grad=True),
block_init_cfg=None,
init_cfg=None):
super(HRModule, self).__init__(init_cfg)
self.block_init_cfg = block_init_cfg
self._check_branches(num_branches, num_blocks, in_channels,
num_channels)
@ -92,7 +95,8 @@ class HRModule(nn.Module):
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))
self.in_channels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
@ -102,9 +106,10 @@ class HRModule(nn.Module):
num_channels[branch_index],
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))
return nn.Sequential(*layers)
return Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
"""Build multiple branch."""
@ -114,7 +119,7 @@ class HRModule(nn.Module):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
return ModuleList(branches)
def _make_fuse_layers(self):
"""Build fuse layer."""
@ -209,7 +214,7 @@ class HRModule(nn.Module):
@BACKBONES.register_module()
class HRNet(nn.Module):
class HRNet(BaseModule):
"""HRNet backbone.
High-Resolution Representations for Labeling Pixels and Regions
@ -227,6 +232,9 @@ class HRNet(nn.Module):
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
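pretrained (str, optional): Path to a pretrained checkpoint. Deprecated
in favor of ``init_cfg=dict(type='Pretrained', checkpoint=...)`` and
must not be set together with ``init_cfg``. Default: None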
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Example:
>>> from mmseg.models import HRNet
@ -277,14 +285,36 @@ class HRNet(nn.Module):
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
with_cp=False,
zero_init_residual=False):
super(HRNet, self).__init__()
zero_init_residual=False,
pretrained=None,
init_cfg=None):
super(HRNet, self).__init__(init_cfg)
self.pretrained = pretrained
self.zero_init_residual = zero_init_residual
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
self.extra = extra
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual
# stem net
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@ -430,6 +460,16 @@ class HRNet(nn.Module):
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
layers = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))
layers.append(
block(
inplanes,
@ -438,7 +478,8 @@ class HRNet(nn.Module):
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
@ -447,9 +488,10 @@ class HRNet(nn.Module):
planes,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))
return nn.Sequential(*layers)
return Sequential(*layers)
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
"""Make each stage."""
@ -460,6 +502,16 @@ class HRNet(nn.Module):
block = self.blocks_dict[layer_config['block']]
hr_modules = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))
for i in range(num_modules):
# multi_scale_output is only used for the last module
if not multiscale_output and i == num_modules - 1:
@ -477,35 +529,10 @@ class HRNet(nn.Module):
reset_multiscale_output,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
block_init_cfg=block_init_cfg))
return nn.Sequential(*hr_modules), in_channels
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
return Sequential(*hr_modules), in_channels
def forward(self, x):
"""Forward function."""

View File

@ -1,8 +1,8 @@
import logging
import warnings
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
@ -10,7 +10,7 @@ from ..utils import InvertedResidual, make_divisible
@BACKBONES.register_module()
class MobileNetV2(nn.Module):
class MobileNetV2(BaseModule):
"""MobileNetV2 backbone.
Args:
@ -35,6 +35,9 @@ class MobileNetV2(nn.Module):
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
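pretrained (str, optional): Path to a pretrained checkpoint. Deprecated
in favor of ``init_cfg=dict(type='Pretrained', checkpoint=...)`` and
must not be set together with ``init_cfg``. Default: None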
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
# Parameters to build layers. 3 parameters are needed to construct a
@ -52,8 +55,30 @@ class MobileNetV2(nn.Module):
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'),
norm_eval=False,
with_cp=False):
super(MobileNetV2, self).__init__()
with_cp=False,
pretrained=None,
init_cfg=None):
super(MobileNetV2, self).__init__(init_cfg)
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
self.widen_factor = widen_factor
self.strides = strides
self.dilations = dilations
@ -133,19 +158,6 @@ class MobileNetV2(nn.Module):
return nn.Sequential(*layers)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
x = self.conv1(x)

View File

@ -1,10 +1,9 @@
import logging
import warnings
import mmcv
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import Conv2dAdaptivePadding
from mmcv.runner import load_checkpoint
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
@ -12,7 +11,7 @@ from ..utils import InvertedResidualV3 as InvertedResidual
@BACKBONES.register_module()
class MobileNetV3(nn.Module):
class MobileNetV3(BaseModule):
"""MobileNetV3 backbone.
This backbone is the improved implementation of `Searching for MobileNetV3
@ -35,6 +34,9 @@ class MobileNetV3(nn.Module):
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed.
Default: False.
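pretrained (str, optional): Path to a pretrained checkpoint. Deprecated
in favor of ``init_cfg=dict(type='Pretrained', checkpoint=...)`` and
must not be set together with ``init_cfg``. Default: None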
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
# Parameters to build each block:
# [kernel size, mid channels, out channels, with_se, act type, stride]
@ -75,8 +77,30 @@ class MobileNetV3(nn.Module):
frozen_stages=-1,
reduction_factor=1,
norm_eval=False,
with_cp=False):
super(MobileNetV3, self).__init__()
with_cp=False,
pretrained=None,
init_cfg=None):
super(MobileNetV3, self).__init__(init_cfg)
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
assert arch in self.arch_settings
assert isinstance(reduction_factor, int) and reduction_factor > 0
assert mmcv.is_tuple_of(out_indices, int)
@ -217,19 +241,6 @@ class MobileNetV3(nn.Module):
return layers
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
outs = []
for i, layer_name in enumerate(self.layers):

View File

@ -1,16 +1,16 @@
import warnings
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import ResLayer
class BasicBlock(nn.Module):
class BasicBlock(BaseModule):
"""Basic block for ResNet."""
expansion = 1
@ -26,8 +26,9 @@ class BasicBlock(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None):
super(BasicBlock, self).__init__()
plugins=None,
init_cfg=None):
super(BasicBlock, self).__init__(init_cfg)
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
@ -94,7 +95,7 @@ class BasicBlock(nn.Module):
return out
class Bottleneck(nn.Module):
class Bottleneck(BaseModule):
"""Bottleneck block for ResNet.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
@ -114,8 +115,9 @@ class Bottleneck(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None):
super(Bottleneck, self).__init__()
plugins=None,
init_cfg=None):
super(Bottleneck, self).__init__(init_cfg)
assert style in ['pytorch', 'caffe']
assert dcn is None or isinstance(dcn, dict)
assert plugins is None or isinstance(plugins, list)
@ -305,7 +307,7 @@ class Bottleneck(nn.Module):
@BACKBONES.register_module()
class ResNet(nn.Module):
class ResNet(BaseModule):
"""ResNet backbone.
Args:
@ -346,6 +348,9 @@ class ResNet(nn.Module):
memory while slowing down the training speed.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity.
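pretrained (str, optional): Path to a pretrained checkpoint. Deprecated
in favor of ``init_cfg=dict(type='Pretrained', checkpoint=...)`` and
must not be set together with ``init_cfg``. Default: None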
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Example:
>>> from mmseg.models import ResNet
@ -392,10 +397,46 @@ class ResNet(nn.Module):
multi_grid=None,
contract_dilation=False,
with_cp=False,
zero_init_residual=True):
zero_init_residual=True,
pretrained=None,
init_cfg=None):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
self.pretrained = pretrained
self.zero_init_residual = zero_init_residual
block_init_cfg = None
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
block = self.arch_settings[depth][0]
if self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant',
val=0,
override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant',
val=0,
override=dict(name='norm3'))
else:
raise TypeError('pretrained must be a str or None')
self.depth = depth
self.stem_channels = stem_channels
self.base_channels = base_channels
@ -421,7 +462,6 @@ class ResNet(nn.Module):
self.plugins = plugins
self.multi_grid = multi_grid
self.contract_dilation = contract_dilation
self.zero_init_residual = zero_init_residual
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = stem_channels
@ -456,7 +496,8 @@ class ResNet(nn.Module):
dcn=dcn,
plugins=stage_plugins,
multi_grid=stage_multi_grid,
contract_dilation=contract_dilation)
contract_dilation=contract_dilation,
init_cfg=block_init_cfg)
self.inplanes = planes * self.block.expansion
layer_name = f'layer{i+1}'
self.add_module(layer_name, res_layer)
@ -597,38 +638,6 @@ class ResNet(nn.Module):
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.dcn is not None:
for m in self.modules():
if isinstance(m, Bottleneck) and hasattr(
m, 'conv2_offset'):
constant_init(m.conv2_offset, 0)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
"""Forward function."""
if self.deep_stem:

View File

@ -1,11 +1,12 @@
import warnings
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
build_norm_layer, constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
build_norm_layer)
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import UpConvBlock
@ -219,7 +220,7 @@ class InterpConv(nn.Module):
@BACKBONES.register_module()
class UNet(nn.Module):
class UNet(BaseModule):
"""UNet backbone.
U-Net: Convolutional Networks for Biomedical Image Segmentation.
https://arxiv.org/pdf/1505.04597.pdf
@ -266,6 +267,9 @@ class UNet(nn.Module):
dcn (bool): Use deformable convolution in convolutional layer or not.
Default: None.
plugins (dict): plugins for convolutional layers. Default: None.
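pretrained (str, optional): Path to a pretrained checkpoint. Deprecated
in favor of ``init_cfg=dict(type='Pretrained', checkpoint=...)`` and
must not be set together with ``init_cfg``. Default: None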
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Notice:
The input image size should be divisible by the whole downsample rate
@ -291,8 +295,30 @@ class UNet(nn.Module):
upsample_cfg=dict(type='InterpConv'),
norm_eval=False,
dcn=None,
plugins=None):
super(UNet, self).__init__()
plugins=None,
pretrained=None,
init_cfg=None):
super(UNet, self).__init__(init_cfg)
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
assert len(strides) == num_stages, \
@ -408,22 +434,3 @@ class UNet(nn.Module):
f'downsample rate {whole_downsample_rate}, when num_stages is '\
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
f'is {self.downsamples}.'
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')

View File

@ -9,7 +9,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
constant_init, kaiming_init, normal_init)
from mmcv.runner import _load_checkpoint
from mmcv.runner import BaseModule, _load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
@ -203,7 +203,7 @@ class PatchEmbed(nn.Module):
@BACKBONES.register_module()
class VisionTransformer(nn.Module):
class VisionTransformer(BaseModule):
"""Vision transformer backbone.
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
@ -243,6 +243,9 @@ class VisionTransformer(nn.Module):
with_cp (bool): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self,
@ -266,8 +269,12 @@ class VisionTransformer(nn.Module):
out_shape='NCHW',
with_cls_token=True,
interpolate_mode='bicubic',
with_cp=False):
super(VisionTransformer, self).__init__()
with_cp=False,
pretrained=None,
init_cfg=None):
super(VisionTransformer, self).__init__(init_cfg)
self.pretrained = pretrained
self.img_size = img_size
self.patch_size = patch_size
self.features = self.embed_dim = embed_dim
@ -319,7 +326,8 @@ class VisionTransformer(nn.Module):
self.norm_eval = norm_eval
self.with_cp = with_cp
def init_weights(self, pretrained=None):
def init_weights(self):
pretrained = self.pretrained
if isinstance(pretrained, str):
logger = get_root_logger()
checkpoint = _load_checkpoint(pretrained, logger=logger)

View File

@ -2,8 +2,7 @@ from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.runner import auto_fp16, force_fp32
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from mmseg.core import build_pixel_sampler
from mmseg.ops import resize
@ -11,7 +10,7 @@ from ..builder import build_loss
from ..losses import accuracy
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
Args:
@ -41,6 +40,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: dict(type='Normal', std=0.01, override=dict(name='conv_seg')).
"""
def __init__(self,
@ -60,8 +60,10 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
loss_weight=1.0),
ignore_index=255,
sampler=None,
align_corners=False):
super(BaseDecodeHead, self).__init__()
align_corners=False,
init_cfg=dict(
type='Normal', std=0.01, override=dict(name='conv_seg'))):
super(BaseDecodeHead, self).__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.num_classes = num_classes
@ -130,10 +132,6 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
assert isinstance(in_index, int)
self.in_channels = in_channels
def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.conv_seg, mean=0, std=0.01)
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.

View File

@ -2,7 +2,7 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmcv.cnn import ConvModule
from mmcv.ops import point_sample
from mmseg.models.builder import HEADS
@ -69,6 +69,8 @@ class PointHead(BaseCascadeDecodeHead):
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
init_cfg=dict(
type='Normal', std=0.01, override=dict(name='fc_seg')),
**kwargs)
self.num_fcs = num_fcs
@ -101,10 +103,6 @@ class PointHead(BaseCascadeDecodeHead):
self.dropout = nn.Dropout(self.dropout_ratio)
delattr(self, 'conv_seg')
def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.fc_seg, std=0.001)
def cls_seg(self, feat):
"""Classify each pixel with fc."""
if self.dropout is not None:

View File

@ -1,12 +1,13 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from ..builder import NECKS
@NECKS.register_module()
class FPN(nn.Module):
class FPN(BaseModule):
"""Feature Pyramid Network.
This is an implementation of - Feature Pyramid Networks for Object
@ -43,6 +44,7 @@ class FPN(nn.Module):
Default: None.
upsample_cfg (dict): Config dict for interpolate layer.
Default: `dict(mode='nearest')`
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: dict(type='Xavier', layer='Conv2d', distribution='uniform').
Example:
>>> import torch
@ -73,8 +75,10 @@ class FPN(nn.Module):
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
upsample_cfg=dict(mode='nearest')):
super(FPN, self).__init__()
upsample_cfg=dict(mode='nearest'),
init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
super(FPN, self).__init__(init_cfg)
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
@ -153,12 +157,7 @@ class FPN(nn.Module):
inplace=False)
self.fpn_convs.append(extra_fpn_conv)
# default init_weights for conv(msra) and norm in ConvModule
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m, distribution='uniform')
@auto_fp16()
def forward(self, inputs):
assert len(inputs) == len(self.in_channels)

View File

@ -1,4 +1,3 @@
import logging
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
@ -7,17 +6,14 @@ import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.runner import BaseModule, auto_fp16
class BaseSegmentor(nn.Module):
class BaseSegmentor(BaseModule, metaclass=ABCMeta):
"""Base class for segmentors."""
__metaclass__ = ABCMeta
def __init__(self):
super(BaseSegmentor, self).__init__()
def __init__(self, init_cfg=None):
super(BaseSegmentor, self).__init__(init_cfg)
self.fp16_enabled = False
@property
@ -62,17 +58,6 @@ class BaseSegmentor(nn.Module):
"""Placeholder for augmentation test."""
pass
def init_weights(self, pretrained=None):
"""Initialize the weights in segmentor.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if pretrained is not None:
logger = logging.getLogger()
logger.info(f'load model from: {pretrained}')
def forward_test(self, imgs, img_metas, **kwargs):
"""
Args:

View File

@ -24,7 +24,8 @@ class CascadeEncoderDecoder(EncoderDecoder):
auxiliary_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
pretrained=None,
init_cfg=None):
self.num_stages = num_stages
super(CascadeEncoderDecoder, self).__init__(
backbone=backbone,
@ -33,7 +34,8 @@ class CascadeEncoderDecoder(EncoderDecoder):
auxiliary_head=auxiliary_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
pretrained=pretrained,
init_cfg=init_cfg)
def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``"""
@ -45,23 +47,6 @@ class CascadeEncoderDecoder(EncoderDecoder):
self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone and heads.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
self.backbone.init_weights(pretrained=pretrained)
for i in range(self.num_stages):
self.decode_head[i].init_weights()
if self.with_auxiliary_head:
if isinstance(self.auxiliary_head, nn.ModuleList):
for aux_head in self.auxiliary_head:
aux_head.init_weights()
else:
self.auxiliary_head.init_weights()
def encode_decode(self, img, img_metas):
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""

View File

@ -25,8 +25,13 @@ class EncoderDecoder(BaseSegmentor):
auxiliary_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(EncoderDecoder, self).__init__()
pretrained=None,
init_cfg=None):
super(EncoderDecoder, self).__init__(init_cfg)
if pretrained is not None:
assert backbone.get('pretrained') is None, \
'both backbone and segmentor set pretrained weight'
backbone.pretrained = pretrained
self.backbone = builder.build_backbone(backbone)
if neck is not None:
self.neck = builder.build_neck(neck)
@ -36,8 +41,6 @@ class EncoderDecoder(BaseSegmentor):
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
assert self.with_decode_head
def _init_decode_head(self, decode_head):
@ -56,24 +59,6 @@ class EncoderDecoder(BaseSegmentor):
else:
self.auxiliary_head = builder.build_head(auxiliary_head)
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone and heads.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super(EncoderDecoder, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
self.decode_head.init_weights()
if self.with_auxiliary_head:
if isinstance(self.auxiliary_head, nn.ModuleList):
for aux_head in self.auxiliary_head:
aux_head.init_weights()
else:
self.auxiliary_head.init_weights()
def extract_feat(self, img):
"""Extract features from images."""
x = self.backbone(img)

View File

@ -1,8 +1,9 @@
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import Sequential
from torch import nn as nn
class ResLayer(nn.Sequential):
class ResLayer(Sequential):
"""ResLayer to build ResNet style backbone.
Args:

View File

@ -300,8 +300,8 @@ def test_resnet_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = ResNet(50)
model.init_weights(pretrained=0)
model = ResNet(50, pretrained=0)
model.init_weights()
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
@ -314,8 +314,9 @@ def test_resnet_backbone():
assert check_norm_state(model.modules(), False)
# Test ResNet50 with torchvision pretrained weight
model = ResNet(depth=50, norm_eval=True)
model.init_weights('torchvision://resnet50')
model = ResNet(
depth=50, norm_eval=True, pretrained='torchvision://resnet50')
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)

View File

@ -734,7 +734,6 @@ def test_unet():
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
@ -754,7 +753,6 @@ def test_unet():
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
@ -774,7 +772,6 @@ def test_unet():
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
@ -794,7 +791,6 @@ def test_unet():
downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
@ -813,9 +809,9 @@ def test_unet():
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
unet.init_weights(pretrained=None)
print(unet)
dec_dilations=(1, 1, 1, 1),
pretrained=None)
unet.init_weights()
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])

View File

@ -215,6 +215,7 @@ def _test_encoder_decoder_forward(cfg_file):
from mmseg.models import build_segmentor
segmentor = build_segmentor(model)
segmentor.init_weights()
if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes

View File

@ -131,6 +131,7 @@ def main():
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
model.init_weights()
logger.info(model)