[Refactor]: Unified parameter initialization (#567)
* [Refactor]: Unified parameter initialization
* fixed pretrained
This commit is contained in: parent af6478dd7a, commit 9849a8dc23
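The change replaces each module's hand-written init_weights(pretrained=...) with mmcv's BaseModule and a declarative init_cfg: pretrained checkpoints become init_cfg=dict(type='Pretrained', ...), and the old pretrained argument is kept only as a deprecation shim. A minimal usage sketch of the new calling convention, based on the test changes near the end of this diff (the depth and checkpoint URL are just the values those tests use):

    from mmseg.models import ResNet

    # Old style, removed by this commit:
    #   model = ResNet(depth=50)
    #   model.init_weights(pretrained='torchvision://resnet50')

    # New style: declare initialization at construction time,
    # then call init_weights() once with no arguments.
    model = ResNet(
        depth=50,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
    model.init_weights()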
@@ -1,12 +1,12 @@
import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.utils import get_root_logger
from ..builder import BACKBONES

@@ -183,7 +183,7 @@ class InputInjection(nn.Module):

@BACKBONES.register_module()
class CGNet(nn.Module):
class CGNet(BaseModule):
"""CGNet backbone.

A Light-weight Context Guided Network for Semantic Segmentation
@@ -210,6 +210,9 @@ class CGNet(nn.Module):
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""

def __init__(self,
@@ -222,9 +225,31 @@ class CGNet(nn.Module):
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='PReLU'),
norm_eval=False,
with_cp=False):
with_cp=False,
pretrained=None,
init_cfg=None):

super(CGNet, self).__init__(init_cfg)

assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer=['Conv2d', 'Linear']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm']),
dict(type='Constant', val=0, layer='PReLU')
]
else:
raise TypeError('pretrained must be a str or None')

super(CGNet, self).__init__()
self.in_channels = in_channels
self.num_channels = num_channels
assert isinstance(self.num_channels, tuple) and len(
@@ -335,27 +360,6 @@ class CGNet(nn.Module):

return output

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
elif isinstance(m, nn.PReLU):
constant_init(m, 0)
else:
raise TypeError('pretrained must be a str or None')

def train(self, mode=True):
"""Convert the model into training mode will keeping the normalization
layer freezed."""
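Every backbone touched by this commit repeats the same shim that maps the deprecated pretrained argument onto init_cfg and otherwise falls back to default initializers. A condensed sketch of that shared pattern (the class name is hypothetical, used only to isolate the pattern; the dicts and messages are taken from the hunks above):

    import warnings
    from mmcv.runner import BaseModule

    class ExampleBackbone(BaseModule):  # hypothetical, illustrates the shared shim

        def __init__(self, pretrained=None, init_cfg=None):
            super(ExampleBackbone, self).__init__(init_cfg)
            assert not (init_cfg and pretrained), \
                'init_cfg and pretrained cannot be setting at the same time'
            if isinstance(pretrained, str):
                # Map the deprecated argument onto the unified mechanism.
                warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                              'please use "init_cfg" instead')
                self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
            elif pretrained is None:
                if init_cfg is None:
                    # Default initializers, applied later by init_weights().
                    self.init_cfg = [
                        dict(type='Kaiming', layer='Conv2d'),
                        dict(type='Constant', val=1,
                             layer=['_BatchNorm', 'GroupNorm'])
                    ]
            else:
                raise TypeError('pretrained must be a str or None')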
@@ -1,8 +1,7 @@
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
kaiming_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import BaseModule

from mmseg.models.decode_heads.psp_head import PPM
from mmseg.ops import resize
@@ -247,7 +246,7 @@ class FeatureFusionModule(nn.Module):

@BACKBONES.register_module()
class FastSCNN(nn.Module):
class FastSCNN(BaseModule):
"""Fast-SCNN Backbone.

Args:
@@ -291,6 +290,8 @@ class FastSCNN(nn.Module):
dict(type='ReLU')
align_corners (bool): align_corners argument of F.interpolate.
Default: False
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""

def __init__(self,
@@ -307,9 +308,18 @@ class FastSCNN(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False):
align_corners=False,
init_cfg=None):

super(FastSCNN, self).__init__(init_cfg)

if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]

super(FastSCNN, self).__init__()
if global_in_channels != higher_in_channels:
raise AssertionError('Global Input Channels must be the same \
with Higher Input Channels!')
@@ -357,13 +367,6 @@ class FastSCNN(nn.Module):
act_cfg=self.act_cfg,
align_corners=self.align_corners)

def init_weights(self, pretrained=None):
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

def forward(self, x):
higher_res_features = self.learning_to_downsample(x)
lower_res_features = self.global_feature_extractor(higher_res_features)
@@ -1,16 +1,16 @@
import warnings

import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.ops import Upsample, resize
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck


class HRModule(nn.Module):
class HRModule(BaseModule):
"""High-Resolution Module for HRNet.

In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
@@ -26,8 +26,11 @@ class HRModule(nn.Module):
multiscale_output=True,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True)):
super(HRModule, self).__init__()
norm_cfg=dict(type='BN', requires_grad=True),
block_init_cfg=None,
init_cfg=None):
super(HRModule, self).__init__(init_cfg)
self.block_init_cfg = block_init_cfg
self._check_branches(num_branches, num_blocks, in_channels,
num_channels)

@@ -92,7 +95,8 @@ class HRModule(nn.Module):
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))
self.in_channels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
@@ -102,9 +106,10 @@ class HRModule(nn.Module):
num_channels[branch_index],
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))

return nn.Sequential(*layers)
return Sequential(*layers)

def _make_branches(self, num_branches, block, num_blocks, num_channels):
"""Build multiple branch."""
@@ -114,7 +119,7 @@ class HRModule(nn.Module):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))

return nn.ModuleList(branches)
return ModuleList(branches)

def _make_fuse_layers(self):
"""Build fuse layer."""
@@ -209,7 +214,7 @@ class HRModule(nn.Module):

@BACKBONES.register_module()
class HRNet(nn.Module):
class HRNet(BaseModule):
"""HRNet backbone.

High-Resolution Representations for Labeling Pixels and Regions
@@ -227,6 +232,9 @@ class HRNet(nn.Module):
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None

Example:
>>> from mmseg.models import HRNet
@@ -277,14 +285,36 @@ class HRNet(nn.Module):
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
with_cp=False,
zero_init_residual=False):
super(HRNet, self).__init__()
zero_init_residual=False,
pretrained=None,
init_cfg=None):
super(HRNet, self).__init__(init_cfg)

self.pretrained = pretrained
self.zero_init_residual = zero_init_residual
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')

self.extra = extra
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual

# stem net
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -430,6 +460,16 @@ class HRNet(nn.Module):
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

layers = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))

layers.append(
block(
inplanes,
@@ -438,7 +478,8 @@ class HRNet(nn.Module):
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
@@ -447,9 +488,10 @@ class HRNet(nn.Module):
planes,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))

return nn.Sequential(*layers)
return Sequential(*layers)

def _make_stage(self, layer_config, in_channels, multiscale_output=True):
"""Make each stage."""
@@ -460,6 +502,16 @@ class HRNet(nn.Module):
block = self.blocks_dict[layer_config['block']]

hr_modules = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))

for i in range(num_modules):
# multi_scale_output is only used for the last module
if not multiscale_output and i == num_modules - 1:
@@ -477,35 +529,10 @@ class HRNet(nn.Module):
reset_multiscale_output,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
block_init_cfg=block_init_cfg))

return nn.Sequential(*hr_modules), in_channels

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
return Sequential(*hr_modules), in_channels

def forward(self, x):
"""Forward function."""
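In HRNet (and in ResNet further down), zero_init_residual is now expressed per block through an init_cfg override on the block's last norm layer, replacing the removed post-hoc loop over modules calling constant_init. A small sketch of the helper logic the hunks above add inline (the function wrapper is illustrative; the dicts are taken from the diff):

    from mmseg.models.backbones.resnet import BasicBlock, Bottleneck

    def residual_block_init_cfg(block):
        # Zero-init the last norm layer of a residual block so that the
        # block initially behaves as an identity mapping.
        if block is BasicBlock:
            # BasicBlock registers its second norm layer as `norm2`.
            return dict(type='Constant', val=0, override=dict(name='norm2'))
        if block is Bottleneck:
            # Bottleneck registers its last norm layer as `norm3`.
            return dict(type='Constant', val=0, override=dict(name='norm3'))
        return None

The returned dict is then passed as init_cfg to every block built in _make_layer / _make_stage, as the hunks above show.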
@@ -1,8 +1,8 @@
import logging
import warnings

import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES
@@ -10,7 +10,7 @@ from ..utils import InvertedResidual, make_divisible

@BACKBONES.register_module()
class MobileNetV2(nn.Module):
class MobileNetV2(BaseModule):
"""MobileNetV2 backbone.

Args:
@@ -35,6 +35,9 @@ class MobileNetV2(nn.Module):
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""

# Parameters to build layers. 3 parameters are needed to construct a
@@ -52,8 +55,30 @@ class MobileNetV2(nn.Module):
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'),
norm_eval=False,
with_cp=False):
super(MobileNetV2, self).__init__()
with_cp=False,
pretrained=None,
init_cfg=None):
super(MobileNetV2, self).__init__(init_cfg)

self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')

self.widen_factor = widen_factor
self.strides = strides
self.dilations = dilations
@@ -133,19 +158,6 @@ class MobileNetV2(nn.Module):

return nn.Sequential(*layers)

def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')

def forward(self, x):
x = self.conv1(x)
@@ -1,10 +1,9 @@
import logging
import warnings

import mmcv
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import Conv2dAdaptivePadding
from mmcv.runner import load_checkpoint
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES
@@ -12,7 +11,7 @@ from ..utils import InvertedResidualV3 as InvertedResidual

@BACKBONES.register_module()
class MobileNetV3(nn.Module):
class MobileNetV3(BaseModule):
"""MobileNetV3 backbone.

This backbone is the improved implementation of `Searching for MobileNetV3
@@ -35,6 +34,9 @@ class MobileNetV3(nn.Module):
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed.
Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
# Parameters to build each block:
# [kernel size, mid channels, out channels, with_se, act type, stride]
@@ -75,8 +77,30 @@ class MobileNetV3(nn.Module):
frozen_stages=-1,
reduction_factor=1,
norm_eval=False,
with_cp=False):
super(MobileNetV3, self).__init__()
with_cp=False,
pretrained=None,
init_cfg=None):
super(MobileNetV3, self).__init__(init_cfg)

self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')

assert arch in self.arch_settings
assert isinstance(reduction_factor, int) and reduction_factor > 0
assert mmcv.is_tuple_of(out_indices, int)
@@ -217,19 +241,6 @@ class MobileNetV3(nn.Module):

return layers

def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')

def forward(self, x):
outs = []
for i, layer_name in enumerate(self.layers):
@@ -1,16 +1,16 @@
import warnings

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import ResLayer


class BasicBlock(nn.Module):
class BasicBlock(BaseModule):
"""Basic block for ResNet."""

expansion = 1
@@ -26,8 +26,9 @@ class BasicBlock(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None):
super(BasicBlock, self).__init__()
plugins=None,
init_cfg=None):
super(BasicBlock, self).__init__(init_cfg)
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'

@@ -94,7 +95,7 @@ class BasicBlock(nn.Module):
return out


class Bottleneck(nn.Module):
class Bottleneck(BaseModule):
"""Bottleneck block for ResNet.

If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
@@ -114,8 +115,9 @@ class Bottleneck(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None):
super(Bottleneck, self).__init__()
plugins=None,
init_cfg=None):
super(Bottleneck, self).__init__(init_cfg)
assert style in ['pytorch', 'caffe']
assert dcn is None or isinstance(dcn, dict)
assert plugins is None or isinstance(plugins, list)
@@ -305,7 +307,7 @@ class Bottleneck(nn.Module):

@BACKBONES.register_module()
class ResNet(nn.Module):
class ResNet(BaseModule):
"""ResNet backbone.

Args:
@@ -346,6 +348,9 @@ class ResNet(nn.Module):
memory while slowing down the training speed.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None

Example:
>>> from mmseg.models import ResNet
@@ -392,10 +397,46 @@ class ResNet(nn.Module):
multi_grid=None,
contract_dilation=False,
with_cp=False,
zero_init_residual=True):
zero_init_residual=True,
pretrained=None,
init_cfg=None):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')

self.pretrained = pretrained
self.zero_init_residual = zero_init_residual
block_init_cfg = None
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
block = self.arch_settings[depth][0]
if self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant',
val=0,
override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant',
val=0,
override=dict(name='norm3'))
else:
raise TypeError('pretrained must be a str or None')

self.depth = depth
self.stem_channels = stem_channels
self.base_channels = base_channels
@@ -421,7 +462,6 @@ class ResNet(nn.Module):
self.plugins = plugins
self.multi_grid = multi_grid
self.contract_dilation = contract_dilation
self.zero_init_residual = zero_init_residual
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = stem_channels
@@ -456,7 +496,8 @@ class ResNet(nn.Module):
dcn=dcn,
plugins=stage_plugins,
multi_grid=stage_multi_grid,
contract_dilation=contract_dilation)
contract_dilation=contract_dilation,
init_cfg=block_init_cfg)
self.inplanes = planes * self.block.expansion
layer_name = f'layer{i+1}'
self.add_module(layer_name, res_layer)
@@ -597,38 +638,6 @@ class ResNet(nn.Module):
for param in m.parameters():
param.requires_grad = False

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

if self.dcn is not None:
for m in self.modules():
if isinstance(m, Bottleneck) and hasattr(
m, 'conv2_offset'):
constant_init(m.conv2_offset, 0)

if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')

def forward(self, x):
"""Forward function."""
if self.deep_stem:
@@ -1,11 +1,12 @@
import warnings

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
build_norm_layer, constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
build_norm_layer)
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import UpConvBlock

@@ -219,7 +220,7 @@ class InterpConv(nn.Module):

@BACKBONES.register_module()
class UNet(nn.Module):
class UNet(BaseModule):
"""UNet backbone.
U-Net: Convolutional Networks for Biomedical Image Segmentation.
https://arxiv.org/pdf/1505.04597.pdf
@@ -266,6 +267,9 @@ class UNet(nn.Module):
dcn (bool): Use deformable convolution in convolutional layer or not.
Default: None.
plugins (dict): plugins for convolutional layers. Default: None.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None

Notice:
The input image size should be divisible by the whole downsample rate
@@ -291,8 +295,30 @@ class UNet(nn.Module):
upsample_cfg=dict(type='InterpConv'),
norm_eval=False,
dcn=None,
plugins=None):
super(UNet, self).__init__()
plugins=None,
pretrained=None,
init_cfg=None):
super(UNet, self).__init__(init_cfg)

self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')

assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
assert len(strides) == num_stages, \
@@ -408,22 +434,3 @@ class UNet(nn.Module):
f'downsample rate {whole_downsample_rate}, when num_stages is '\
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
f'is {self.downsamples}.'

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
@@ -9,7 +9,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
constant_init, kaiming_init, normal_init)
from mmcv.runner import _load_checkpoint
from mmcv.runner import BaseModule, _load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.utils import get_root_logger
@@ -203,7 +203,7 @@ class PatchEmbed(nn.Module):

@BACKBONES.register_module()
class VisionTransformer(nn.Module):
class VisionTransformer(BaseModule):
"""Vision transformer backbone.

A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
@@ -243,6 +243,9 @@ class VisionTransformer(nn.Module):
with_cp (bool): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""

def __init__(self,
@@ -266,8 +269,12 @@ class VisionTransformer(nn.Module):
out_shape='NCHW',
with_cls_token=True,
interpolate_mode='bicubic',
with_cp=False):
super(VisionTransformer, self).__init__()
with_cp=False,
pretrained=None,
init_cfg=None):
super(VisionTransformer, self).__init__(init_cfg)
self.pretrained = pretrained

self.img_size = img_size
self.patch_size = patch_size
self.features = self.embed_dim = embed_dim
@@ -319,7 +326,8 @@ class VisionTransformer(nn.Module):
self.norm_eval = norm_eval
self.with_cp = with_cp

def init_weights(self, pretrained=None):
def init_weights(self):
pretrained = self.pretrained
if isinstance(pretrained, str):
logger = get_root_logger()
checkpoint = _load_checkpoint(pretrained, logger=logger)
@@ -2,8 +2,7 @@ from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.runner import auto_fp16, force_fp32
from mmcv.runner import BaseModule, auto_fp16, force_fp32

from mmseg.core import build_pixel_sampler
from mmseg.ops import resize
@@ -11,7 +10,7 @@ from ..builder import build_loss
from ..losses import accuracy


class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.

Args:
@@ -41,6 +40,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""

def __init__(self,
@@ -60,8 +60,10 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
loss_weight=1.0),
ignore_index=255,
sampler=None,
align_corners=False):
super(BaseDecodeHead, self).__init__()
align_corners=False,
init_cfg=dict(
type='Normal', std=0.01, override=dict(name='conv_seg'))):
super(BaseDecodeHead, self).__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.num_classes = num_classes
@@ -130,10 +132,6 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
assert isinstance(in_index, int)
self.in_channels = in_channels

def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.conv_seg, mean=0, std=0.01)

def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
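For decode heads the per-layer init moves into a constructor default: init_cfg=dict(type='Normal', std=0.01, override=dict(name='conv_seg')) reproduces the removed normal_init(self.conv_seg, mean=0, std=0.01), scoping the Normal init to that one named child. A hedged sketch of a subclass overriding this default (the head class and std value are illustrative, not part of the commit):

    from mmseg.models.decode_heads.decode_head import BaseDecodeHead

    class ExampleHead(BaseDecodeHead):  # hypothetical subclass, for illustration

        def __init__(self, **kwargs):
            # Scope a Normal(0, 0.001) init to the `conv_seg` child only;
            # all other parameters keep their own defaults.
            super(ExampleHead, self).__init__(
                init_cfg=dict(
                    type='Normal', std=0.001, override=dict(name='conv_seg')),
                **kwargs)

        def forward(self, inputs):
            x = self._transform_inputs(inputs)
            return self.cls_seg(x)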
@@ -2,7 +2,7 @@

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmcv.cnn import ConvModule
from mmcv.ops import point_sample

from mmseg.models.builder import HEADS
@@ -69,6 +69,8 @@ class PointHead(BaseCascadeDecodeHead):
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
init_cfg=dict(
type='Normal', std=0.01, override=dict(name='fc_seg')),
**kwargs)

self.num_fcs = num_fcs
@@ -101,10 +103,6 @@ class PointHead(BaseCascadeDecodeHead):
self.dropout = nn.Dropout(self.dropout_ratio)
delattr(self, 'conv_seg')

def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.fc_seg, std=0.001)

def cls_seg(self, feat):
"""Classify each pixel with fc."""
if self.dropout is not None:
@@ -1,12 +1,13 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16

from ..builder import NECKS


@NECKS.register_module()
class FPN(nn.Module):
class FPN(BaseModule):
"""Feature Pyramid Network.

This is an implementation of - Feature Pyramid Networks for Object
@@ -43,6 +44,7 @@ class FPN(nn.Module):
Default: None.
upsample_cfg (dict): Config dict for interpolate layer.
Default: `dict(mode='nearest')`
init_cfg (dict or list[dict], optional): Initialization config dict.

Example:
>>> import torch
@@ -73,8 +75,10 @@ class FPN(nn.Module):
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
upsample_cfg=dict(mode='nearest')):
super(FPN, self).__init__()
upsample_cfg=dict(mode='nearest'),
init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
super(FPN, self).__init__(init_cfg)
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
@@ -153,12 +157,7 @@ class FPN(nn.Module):
inplace=False)
self.fpn_convs.append(extra_fpn_conv)

# default init_weights for conv(msra) and norm in ConvModule
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m, distribution='uniform')

@auto_fp16()
def forward(self, inputs):
assert len(inputs) == len(self.in_channels)
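FPN drops its per-module Xavier loop in favor of a declarative constructor default, so explicit initialization happens only when init_weights() is called. A usage sketch (the channel numbers are illustrative):

    from mmseg.models.necks import FPN

    # The new constructor default already requests uniform Xavier init for
    # every Conv2d, matching the removed xavier_init(..., distribution='uniform')
    # loop, so no extra arguments are needed:
    neck = FPN(in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=4)
    neck.init_weights()

    # Passing init_cfg explicitly is equivalent to the default shown in the hunk:
    neck = FPN(
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=4,
        init_cfg=dict(type='Xavier', layer='Conv2d', distribution='uniform'))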
@@ -1,4 +1,3 @@
import logging
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
@@ -7,17 +6,14 @@ import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.runner import BaseModule, auto_fp16


class BaseSegmentor(nn.Module):
class BaseSegmentor(BaseModule, metaclass=ABCMeta):
"""Base class for segmentors."""

__metaclass__ = ABCMeta

def __init__(self):
super(BaseSegmentor, self).__init__()
def __init__(self, init_cfg=None):
super(BaseSegmentor, self).__init__(init_cfg)
self.fp16_enabled = False

@property
@@ -62,17 +58,6 @@ class BaseSegmentor(nn.Module):
"""Placeholder for augmentation test."""
pass

def init_weights(self, pretrained=None):
"""Initialize the weights in segmentor.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if pretrained is not None:
logger = logging.getLogger()
logger.info(f'load model from: {pretrained}')

def forward_test(self, imgs, img_metas, **kwargs):
"""
Args:
@@ -24,7 +24,8 @@ class CascadeEncoderDecoder(EncoderDecoder):
auxiliary_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
pretrained=None,
init_cfg=None):
self.num_stages = num_stages
super(CascadeEncoderDecoder, self).__init__(
backbone=backbone,
@@ -33,7 +34,8 @@ class CascadeEncoderDecoder(EncoderDecoder):
auxiliary_head=auxiliary_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
pretrained=pretrained,
init_cfg=init_cfg)

def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``"""
@@ -45,23 +47,6 @@ class CascadeEncoderDecoder(EncoderDecoder):
self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone and heads.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
self.backbone.init_weights(pretrained=pretrained)
for i in range(self.num_stages):
self.decode_head[i].init_weights()
if self.with_auxiliary_head:
if isinstance(self.auxiliary_head, nn.ModuleList):
for aux_head in self.auxiliary_head:
aux_head.init_weights()
else:
self.auxiliary_head.init_weights()

def encode_decode(self, img, img_metas):
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""
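At the segmentor level, the top-level pretrained argument is now forwarded into the backbone config (see the EncoderDecoder hunks below), and weights are initialized once, explicitly, after the model is built, as in the tools/train.py hunk at the end. A config-style sketch under those assumptions (all config values here are illustrative, not taken from the commit):

    from mmseg.models import build_segmentor

    model_cfg = dict(
        type='EncoderDecoder',
        backbone=dict(
            type='ResNet',
            depth=50,
            # Equivalent to the deprecated top-level `pretrained=...`:
            init_cfg=dict(
                type='Pretrained', checkpoint='torchvision://resnet50')),
        decode_head=dict(
            type='FCNHead',
            in_channels=2048,
            in_index=3,
            channels=512,
            num_classes=19))

    segmentor = build_segmentor(model_cfg)
    segmentor.init_weights()  # init_cfg is applied / checkpoint loaded here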
@@ -25,8 +25,13 @@ class EncoderDecoder(BaseSegmentor):
auxiliary_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(EncoderDecoder, self).__init__()
pretrained=None,
init_cfg=None):
super(EncoderDecoder, self).__init__(init_cfg)
if pretrained is not None:
assert backbone.get('pretrained') is None, \
'both backbone and segmentor set pretrained weight'
backbone.pretrained = pretrained
self.backbone = builder.build_backbone(backbone)
if neck is not None:
self.neck = builder.build_neck(neck)
@@ -36,8 +41,6 @@ class EncoderDecoder(BaseSegmentor):
self.train_cfg = train_cfg
self.test_cfg = test_cfg

self.init_weights(pretrained=pretrained)

assert self.with_decode_head

def _init_decode_head(self, decode_head):
@@ -56,24 +59,6 @@ class EncoderDecoder(BaseSegmentor):
else:
self.auxiliary_head = builder.build_head(auxiliary_head)

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone and heads.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""

super(EncoderDecoder, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
self.decode_head.init_weights()
if self.with_auxiliary_head:
if isinstance(self.auxiliary_head, nn.ModuleList):
for aux_head in self.auxiliary_head:
aux_head.init_weights()
else:
self.auxiliary_head.init_weights()

def extract_feat(self, img):
"""Extract features from images."""
x = self.backbone(img)

@@ -1,8 +1,9 @@
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import Sequential
from torch import nn as nn


class ResLayer(nn.Sequential):
class ResLayer(Sequential):
"""ResLayer to build ResNet style backbone.

Args:
@@ -300,8 +300,8 @@ def test_resnet_backbone():

with pytest.raises(TypeError):
# pretrained must be a string path
model = ResNet(50)
model.init_weights(pretrained=0)
model = ResNet(50, pretrained=0)
model.init_weights()

with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
@@ -314,8 +314,9 @@ def test_resnet_backbone():
assert check_norm_state(model.modules(), False)

# Test ResNet50 with torchvision pretrained weight
model = ResNet(depth=50, norm_eval=True)
model.init_weights('torchvision://resnet50')
model = ResNet(
depth=50, norm_eval=True, pretrained='torchvision://resnet50')
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)
@@ -734,7 +734,6 @@ def test_unet():
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
@@ -754,7 +753,6 @@ def test_unet():
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
@@ -774,7 +772,6 @@ def test_unet():
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
@@ -794,7 +791,6 @@ def test_unet():
downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
@@ -813,9 +809,9 @@ def test_unet():
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
unet.init_weights(pretrained=None)
print(unet)
dec_dilations=(1, 1, 1, 1),
pretrained=None)
unet.init_weights()
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])

@@ -215,6 +215,7 @@ def _test_encoder_decoder_forward(cfg_file):

from mmseg.models import build_segmentor
segmentor = build_segmentor(model)
segmentor.init_weights()

if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes
@@ -131,6 +131,7 @@ def main():
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
model.init_weights()

logger.info(model)