Refactoring for ResNet family

pull/2/head
chenkai 2020-06-25 11:57:50 +08:00
parent 2a05c77f0f
commit 02e11cc1f3
10 changed files with 763 additions and 431 deletions

View File

@@ -1,7 +1,7 @@
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
constant_init, kaiming_init)
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
@@ -12,15 +12,17 @@ class BasicBlock(nn.Module):
"""BasicBlock for ResNet.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
expansion (int): The ratio of ``out_channels/mid_channels`` where
``mid_channels`` is the output channels of conv1. This is a
reserved argument in BasicBlock and should always be 1. Default: 1.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module): downsample operation on identity branch.
Default: None
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
Default: None.
style (str): `pytorch` or `caffe`. It is unused and reserved for
a unified API with Bottleneck.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
conv_cfg (dict): dictionary to construct and config conv layer.
@@ -29,11 +31,10 @@ class BasicBlock(nn.Module):
Default: dict(type='BN')
"""
expansion = 1
def __init__(self,
inplanes,
planes,
in_channels,
out_channels,
expansion=1,
stride=1,
dilation=1,
downsample=None,
@@ -42,14 +43,28 @@ class BasicBlock(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN')):
super(BasicBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.expansion = expansion
assert self.expansion == 1
assert out_channels % expansion == 0
self.mid_channels = out_channels // expansion
self.stride = stride
self.dilation = dilation
self.style = style
self.with_cp = with_cp
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.mid_channels, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, out_channels, postfix=2)
self.conv1 = build_conv_layer(
conv_cfg,
inplanes,
planes,
in_channels,
self.mid_channels,
3,
stride=stride,
padding=dilation,
@@ -57,14 +72,16 @@ class BasicBlock(nn.Module):
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
conv_cfg, planes, planes, 3, padding=1, bias=False)
conv_cfg,
self.mid_channels,
out_channels,
3,
padding=1,
bias=False)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.with_cp = with_cp
@property
def norm1(self):
@@ -107,15 +124,17 @@ class Bottleneck(nn.Module):
"""Bottleneck block for ResNet.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
expansion (int): The ratio of ``out_channels/mid_channels`` where
``mid_channels`` is the input/output channels of conv2. Default: 4.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module): downsample operation on identity branch.
Default: None
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
Default: None.
style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
stride-two layer is the 3x3 conv layer, otherwise the stride-two
layer is the first 1x1 conv layer. Default: "pytorch".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
conv_cfg (dict): dictionary to construct and config conv layer.
@@ -124,11 +143,10 @@ class Bottleneck(nn.Module):
Default: dict(type='BN')
"""
expansion = 4
def __init__(self,
inplanes,
planes,
in_channels,
out_channels,
expansion=4,
stride=1,
dilation=1,
downsample=None,
@@ -139,8 +157,11 @@ class Bottleneck(nn.Module):
super(Bottleneck, self).__init__()
assert style in ['pytorch', 'caffe']
self.inplanes = inplanes
self.planes = planes
self.in_channels = in_channels
self.out_channels = out_channels
self.expansion = expansion
assert out_channels % expansion == 0
self.mid_channels = out_channels // expansion
self.stride = stride
self.dilation = dilation
self.style = style
@@ -155,23 +176,25 @@ class Bottleneck(nn.Module):
self.conv1_stride = stride
self.conv2_stride = 1
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.mid_channels, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, self.mid_channels, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
norm_cfg, planes * self.expansion, postfix=3)
norm_cfg, out_channels, postfix=3)
self.conv1 = build_conv_layer(
conv_cfg,
inplanes,
planes,
in_channels,
self.mid_channels,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
conv_cfg,
planes,
planes,
self.mid_channels,
self.mid_channels,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
@@ -181,8 +204,8 @@ class Bottleneck(nn.Module):
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
conv_cfg,
planes,
planes * self.expansion,
self.mid_channels,
out_channels,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
@@ -235,15 +258,55 @@ class Bottleneck(nn.Module):
return out
def get_expansion(block, expansion=None):
"""Get the expansion of a residual block.
The block expansion will be obtained in the following order:
1. If ``expansion`` is given, just return it.
2. If ``block`` has the attribute ``expansion``, then return
``block.expansion``.
3. Return the default value according to the block type:
1 for ``BasicBlock`` and 4 for ``Bottleneck``.
Args:
block (class): The block class.
expansion (int | None): The given expansion ratio.
Returns:
int: The expansion of the block.
"""
if isinstance(expansion, int):
assert expansion > 0
elif expansion is None:
if hasattr(block, 'expansion'):
expansion = block.expansion
elif issubclass(block, BasicBlock):
expansion = 1
elif issubclass(block, Bottleneck):
expansion = 4
else:
raise TypeError(f'expansion is not specified for {block.__name__}')
else:
raise TypeError('expansion must be an integer or None')
return expansion
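To make the new convention concrete: channel counts are now explicit at the block boundary and ``mid_channels`` is derived from ``expansion``, instead of multiplying ``planes`` by a class attribute as before. A minimal sketch, assuming the classes defined in this file (values mirror the unit tests further down in this diff):

from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck,
                                           get_expansion)

# Old API: Bottleneck(inplanes=64, planes=16) produced 16 * 4 = 64 output
# channels. New API: output channels are explicit, mid_channels is derived.
block = Bottleneck(in_channels=64, out_channels=64)
assert block.mid_channels == 16

# get_expansion follows the resolution order documented above.
assert get_expansion(Bottleneck, 2) == 2  # 1. explicit argument wins
assert get_expansion(Bottleneck) == 4     # 3. default for the block type
assert get_expansion(BasicBlock) == 1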
class ResLayer(nn.Sequential):
"""ResLayer to build ResNet style backbone.
Args:
block (nn.Module): block used to build ResLayer.
inplanes (int): inplanes of block.
planes (int): planes of block.
num_blocks (int): number of blocks.
stride (int): stride of the first block. Default: 1
block (nn.Module): Residual block used to build ResLayer.
num_blocks (int): Number of blocks.
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
expansion (int, optional): The expansion for BasicBlock/Bottleneck.
If not specified, it will firstly be obtained via
``block.expansion``. If the block has no attribute "expansion",
the following default values will be used: 1 for BasicBlock and
4 for Bottleneck. Default: None.
stride (int): stride of the first block. Default: 1.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False
conv_cfg (dict): dictionary to construct and config conv layer.
@@ -254,18 +317,20 @@ class ResLayer(nn.Sequential):
def __init__(self,
block,
inplanes,
planes,
num_blocks,
in_channels,
out_channels,
expansion=None,
stride=1,
avg_down=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
**kwargs):
self.block = block
self.expansion = get_expansion(block, expansion)
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
if stride != 1 or in_channels != out_channels:
downsample = []
conv_stride = stride
if avg_down and stride != 1:
@@ -279,31 +344,33 @@ class ResLayer(nn.Sequential):
downsample.extend([
build_conv_layer(
conv_cfg,
inplanes,
planes * block.expansion,
in_channels,
out_channels,
kernel_size=1,
stride=conv_stride,
bias=False),
build_norm_layer(norm_cfg, planes * block.expansion)[1]
build_norm_layer(norm_cfg, out_channels)[1]
])
downsample = nn.Sequential(*downsample)
layers = []
layers.append(
block(
inplanes=inplanes,
planes=planes,
in_channels=in_channels,
out_channels=out_channels,
expansion=self.expansion,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
**kwargs))
inplanes = planes * block.expansion
in_channels = out_channels
for i in range(1, num_blocks):
layers.append(
block(
inplanes=inplanes,
planes=planes,
in_channels=in_channels,
out_channels=out_channels,
expansion=self.expansion,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
@@ -315,30 +382,41 @@ class ResLayer(nn.Sequential):
class ResNet(BaseBackbone):
"""ResNet backbone.
Please refer to the `paper <https://arxiv.org/abs/1512.03385>`_ for
details.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
base_channels (int): Number of base channels of hidden layer.
num_stages (int): Resnet stages, normally 4.
depth (int): Network depth, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Default: 3.
stem_channels (int): Output channels of the stem layer. Default: 64.
base_channels (int): Middle channels of the first stage. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
stage is specified, a single tensor (feature map) is returned;
if multiple stages are specified, a tuple of tensors will be
returned. Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck.
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
norm_cfg (dict): Dictionary to construct and config norm layer.
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity.
in resblocks to let them behave as identity. Default: True.
Example:
>>> from mmcls.models import ResNet
@@ -366,7 +444,9 @@ class ResNet(BaseBackbone):
def __init__(self,
depth,
in_channels=3,
stem_channels=64,
base_channels=64,
expansion=None,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
@@ -384,6 +464,7 @@ class ResNet(BaseBackbone):
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
self.depth = depth
self.stem_channels = stem_channels
self.base_channels = base_channels
self.num_stages = num_stages
assert num_stages >= 1 and num_stages <= 4
@@ -403,20 +484,22 @@ class ResNet(BaseBackbone):
self.zero_init_residual = zero_init_residual
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = base_channels
self.expansion = get_expansion(self.block, expansion)
self._make_stem_layer(in_channels, base_channels)
self._make_stem_layer(in_channels, stem_channels)
self.res_layers = []
_in_channels = stem_channels
_out_channels = base_channels * self.expansion
for i, num_blocks in enumerate(self.stage_blocks):
stride = strides[i]
dilation = dilations[i]
planes = base_channels * 2**i
res_layer = self.make_res_layer(
block=self.block,
inplanes=self.inplanes,
planes=planes,
num_blocks=num_blocks,
in_channels=_in_channels,
out_channels=_out_channels,
expansion=self.expansion,
stride=stride,
dilation=dilation,
style=self.style,
@@ -424,15 +507,15 @@ class ResNet(BaseBackbone):
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
self.inplanes = planes * self.block.expansion
_in_channels = _out_channels
_out_channels *= 2
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
self.feat_dim = self.block.expansion * base_channels * 2**(
len(self.stage_blocks) - 1)
self.feat_dim = res_layer[-1].out_channels
def make_res_layer(self, **kwargs):
return ResLayer(**kwargs)
@@ -441,50 +524,47 @@
def norm1(self):
return getattr(self, self.norm1_name)
def _make_stem_layer(self, in_channels, base_channels):
def _make_stem_layer(self, in_channels, stem_channels):
if self.deep_stem:
self.stem = nn.Sequential(
build_conv_layer(
self.conv_cfg,
ConvModule(
in_channels,
base_channels // 2,
stem_channels // 2,
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, base_channels // 2)[1],
nn.ReLU(inplace=True),
build_conv_layer(
self.conv_cfg,
base_channels // 2,
base_channels // 2,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=True),
ConvModule(
stem_channels // 2,
stem_channels // 2,
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, base_channels // 2)[1],
nn.ReLU(inplace=True),
build_conv_layer(
self.conv_cfg,
base_channels // 2,
base_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=True),
ConvModule(
stem_channels // 2,
stem_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, base_channels)[1],
nn.ReLU(inplace=True))
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=True))
else:
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
base_channels,
stem_channels,
kernel_size=7,
stride=2,
padding=3,
bias=False)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, base_channels, postfix=1)
self.norm_cfg, stem_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
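As a usage sketch of the refactored constructor (not part of the commit; output shapes follow the unit tests further down in this diff):

import torch

from mmcls.models import ResNet

# stem_channels sets the stem output width; base_channels sets the middle
# width of the first stage (stage i uses base_channels * 2**i).
model = ResNet(depth=50, stem_channels=64, base_channels=64,
               out_indices=(0, 1, 2, 3))
model.init_weights()
feats = model(torch.randn(1, 3, 224, 224))
assert feats[3].shape == (1, 2048, 7, 7)  # 64 * 2**3 * expansion(4)

# With a single out index, forward returns one tensor instead of a tuple.
model = ResNet(depth=50, out_indices=(3, ))
assert model(torch.randn(1, 3, 224, 224)).shape == (1, 2048, 7, 7)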

View File

@@ -1,5 +1,3 @@
import math
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
@@ -11,11 +9,12 @@ class Bottleneck(_Bottleneck):
"""Bottleneck block for ResNeXt.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
groups (int): group of convolution.
base_width (int): Base width of resnext.
base_channels (int): Number of base channels of hidden layer.
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
groups (int): Groups of conv2.
width_per_group (int): Width per group of conv2. 64x4d indicates
``groups=64, width_per_group=4`` and 32x8d indicates
``groups=32, width_per_group=8``.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module): downsample operation on identity branch.
@@ -31,42 +30,44 @@
memory while slowing down the training speed.
"""
expansion = 4
def __init__(self,
inplanes,
planes,
groups=1,
base_width=4,
in_channels,
out_channels,
base_channels=64,
groups=32,
width_per_group=4,
**kwargs):
super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs)
self.groups = groups
self.width_per_group = width_per_group
if groups == 1:
width = self.planes
else:
width = math.floor(self.planes *
(base_width / base_channels)) * groups
# For ResNet bottleneck, middle channels are determined by expansion
# and out_channels, but for ResNeXt bottleneck, they are determined by
# groups, width_per_group and the stage it is located in.
if groups != 1:
assert self.mid_channels % base_channels == 0
self.mid_channels = (
groups * width_per_group * self.mid_channels // base_channels)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width, postfix=1)
self.norm_cfg, self.mid_channels, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.norm_cfg, width, postfix=2)
self.norm_cfg, self.mid_channels, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.norm_cfg, self.out_channels, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
width,
self.in_channels,
self.mid_channels,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
self.conv_cfg,
width,
width,
self.mid_channels,
self.mid_channels,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
@@ -77,8 +78,8 @@ class Bottleneck(_Bottleneck):
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
width,
self.planes * self.expansion,
self.mid_channels,
self.out_channels,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
@@ -88,29 +89,43 @@
class ResNeXt(ResNet):
"""ResNeXt backbone.
Please refer to the `paper <https://arxiv.org/abs/1611.05431>`_ for
details.
Args:
groups (int): Group of resnext.
base_width (int): Base width of resnext.
depth (int): Depth of resnext, from {50, 101, 152}.
depth (int): Network depth, from {50, 101, 152}.
groups (int): Groups of conv2 in Bottleneck. Default: 32.
width_per_group (int): Width per group of conv2 in Bottleneck.
Default: 4.
in_channels (int): Number of input image channels. Default: 3.
base_channels (int): Number of base channels of hidden layer.
num_stages (int): Resnet stages. Default: 4.
stem_channels (int): Output channels of the stem layer. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
stage is specified, a single tensor (feature map) is returned;
if multiple stages are specified, a tuple of tensors will be
returned. Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
"""
arch_settings = {
@@ -119,14 +134,14 @@ class ResNeXt(ResNet):
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, groups=1, base_width=4, **kwargs):
def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
self.groups = groups
self.base_width = base_width
super(ResNeXt, self).__init__(**kwargs)
self.width_per_group = width_per_group
super(ResNeXt, self).__init__(depth, **kwargs)
def make_res_layer(self, **kwargs):
return ResLayer(
groups=self.groups,
base_width=self.base_width,
width_per_group=self.width_per_group,
base_channels=self.base_channels,
**kwargs)
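A worked instance of the ``mid_channels`` arithmetic described in the comment inside ``Bottleneck.__init__`` above, for the first stage of a 32x4d network (the numbers match the unit tests in this diff):

# ResNet-style middle width for out_channels=256, expansion=4.
out_channels, expansion = 256, 4
mid_channels = out_channels // expansion  # 64

# ResNeXt rescales it by groups and width_per_group.
groups, width_per_group, base_channels = 32, 4, 64
if groups != 1:
    mid_channels = groups * width_per_group * mid_channels // base_channels
assert mid_channels == 128  # conv2 gets 128 channels split into 32 groups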

View File

@@ -9,15 +9,14 @@ class SEBottleneck(Bottleneck):
"""SEBottleneck block for SEResNet.
Args:
inplanes (int): The input channels of the SEBottleneck block.
planes (int): The output channel base of the SEBottleneck block.
in_channels (int): The input channels of the SEBottleneck block.
out_channels (int): The output channels of the SEBottleneck block.
se_ratio (int): Squeeze ratio in SELayer. Default: 16
"""
expansion = 4
def __init__(self, inplanes, planes, se_ratio=16, **kwargs):
super(SEBottleneck, self).__init__(inplanes, planes, **kwargs)
self.se_layer = SELayer(planes * self.expansion, ratio=se_ratio)
def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs):
super(SEBottleneck, self).__init__(in_channels, out_channels, **kwargs)
self.se_layer = SELayer(out_channels, ratio=se_ratio)
def forward(self, x):
@@ -58,31 +57,41 @@ class SEBottleneck(Bottleneck):
class SEResNet(ResNet):
"""SEResNet backbone.
Please refer to the `paper <https://arxiv.org/abs/1709.01507>`_ for
details.
Args:
depth (int): Depth of seresnet, from {50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
base_channels (int): Number of base channels of hidden layer.
num_stages (int): Resnet stages, normally 4.
depth (int): Network depth, from {50, 101, 152}.
se_ratio (int): Squeeze ratio in SELayer. Default: 16.
in_channels (int): Number of input image channels. Default: 3.
stem_channels (int): Output channels of the stem layer. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
se_ratio (int): Squeeze ratio in SELayer. Default: 16
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
stage is specified, a single tensor (feature map) is returned;
if multiple stages are specified, a tuple of tensors will be
returned. Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck.
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
norm_cfg (dict): Dictionary to construct and config norm layer.
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity.
in resblocks to let them behave as identity. Default: True.
Example:
>>> from mmcls.models import SEResNet
@@ -107,7 +116,7 @@ class SEResNet(ResNet):
def __init__(self, depth, se_ratio=16, **kwargs):
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
raise KeyError(f'invalid depth {depth} for SEResNet')
self.se_ratio = se_ratio
super(SEResNet, self).__init__(depth, **kwargs)
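The ``forward`` body is elided from the hunks above; conceptually, the only change relative to plain ``Bottleneck`` is that ``self.se_layer`` rescales the residual branch before the identity addition. A sketch under that assumption, not the literal implementation:

def _sebottleneck_forward_sketch(self, x):
    """Sketch: SEBottleneck residual branch with SE rescaling."""
    identity = x
    out = self.relu(self.norm1(self.conv1(x)))
    out = self.relu(self.norm2(self.conv2(out)))
    out = self.norm3(self.conv3(out))
    out = self.se_layer(out)  # channel-wise rescaling before the addition
    if self.downsample is not None:
        identity = self.downsample(x)
    return self.relu(out + identity)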

View File

@@ -1,5 +1,3 @@
import math
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
@@ -12,14 +10,15 @@ class SEBottleneck(_SEBottleneck):
"""SEBottleneck block for SEResNeXt.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
groups (int): group of convolution.
base_width (int): Base width of resnext.
base_channels (int): Number of base channels of hidden layer.
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
groups (int): Groups of conv2.
width_per_group (int): Width per group of conv2. 64x4d indicates
``groups=64, width_per_group=4`` and 32x8d indicates
``groups=32, width_per_group=8``.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
dbownsample (nn.Module): downsample operation on identity branch.
downsample (nn.Module): downsample operation on identity branch.
Default: None
se_ratio (int): Squeeze ratio in SELayer. Default: 16
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
@@ -33,35 +32,33 @@ class SEBottleneck(_SEBottleneck):
memory while slowing down the training speed.
"""
expansion = 4
def __init__(self,
inplanes,
planes,
groups=1,
base_width=4,
base_channels=64,
in_channels,
out_channels,
groups=32,
width_per_group=4,
se_ratio=16,
**kwargs):
super(SEBottleneck, self).__init__(inplanes, planes, se_ratio,
super(SEBottleneck, self).__init__(in_channels, out_channels, se_ratio,
**kwargs)
self.groups = groups
self.width_per_group = width_per_group
if groups == 1:
width = self.planes
width = self.mid_channels
else:
width = math.floor(self.planes *
(base_width / base_channels)) * groups
width = groups * width_per_group
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.norm_cfg, width, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.norm_cfg, self.out_channels, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
self.in_channels,
width,
kernel_size=1,
stride=self.conv1_stride,
@@ -80,11 +77,7 @@ class SEBottleneck(_SEBottleneck):
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
width,
self.planes * self.expansion,
kernel_size=1,
bias=False)
self.conv_cfg, width, self.out_channels, kernel_size=1, bias=False)
self.add_module(self.norm3_name, norm3)
@@ -92,30 +85,44 @@
class SEResNeXt(SEResNet):
"""SEResNeXt backbone.
Please refer to the `paper <https://arxiv.org/abs/1709.01507>`_ for
details.
Args:
groups (int): Group of seresnext.
base_width (int): Base width of resnext.
depth (int): Depth of resnext, from {50, 101, 152}.
depth (int): Network depth, from {50, 101, 152}.
groups (int): Groups of conv2 in Bottleneck. Default: 32.
width_per_group (int): Width per group of conv2 in Bottleneck.
Default: 4.
se_ratio (int): Squeeze ratio in SELayer. Default: 16.
in_channels (int): Number of input image channels. Default: 3.
base_channels (int): Number of base channels of hidden layer.
num_stages (int): Resnet stages. Default: 4.
stem_channels (int): Output channels of the stem layer. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
se_ratio (int): Squeeze ratio in SELayer. Default: 16
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
stage is specified, a single tensor (feature map) is returned;
if multiple stages are specified, a tuple of tensors will be
returned. Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
"""
arch_settings = {
@@ -124,14 +131,11 @@ class SEResNeXt(SEResNet):
152: (SEBottleneck, (3, 8, 36, 3))
}
def __init__(self, groups=1, base_width=4, **kwargs):
def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
self.groups = groups
self.base_width = base_width
super(SEResNeXt, self).__init__(**kwargs)
self.width_per_group = width_per_group
super(SEResNeXt, self).__init__(depth, **kwargs)
def make_res_layer(self, **kwargs):
return ResLayer(
groups=self.groups,
base_width=self.base_width,
base_channels=self.base_channels,
**kwargs)
groups=self.groups, width_per_group=self.width_per_group, **kwargs)

View File

@@ -5,17 +5,18 @@ class SELayer(nn.Module):
"""Squeeze-and-Excitation Module.
Args:
inplanes (int): The input channels of the SEBottleneck block.
ratio (int): Squeeze ratio in SELayer. Default: 16
channels (int): The input (and output) channels of the SE layer.
ratio (int): Squeeze ratio in SELayer, the intermediate channels will
be ``int(channels/ratio)``. Default: 16.
"""
def __init__(self, inplanes, ratio=16):
def __init__(self, channels, ratio=16):
super(SELayer, self).__init__()
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
inplanes, int(inplanes / ratio), kernel_size=1, stride=1)
channels, int(channels / ratio), kernel_size=1, stride=1)
self.conv2 = nn.Conv2d(
int(inplanes / ratio), inplanes, kernel_size=1, stride=1)
int(channels / ratio), channels, kernel_size=1, stride=1)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
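The ``forward`` is not shown in this hunk; built from the layers above, the standard squeeze-excite computation would be (a sketch, assuming the usual SE ordering):

def _selayer_forward_sketch(self, x):
    """Sketch: squeeze to a channel descriptor, excite, rescale."""
    out = self.global_avgpool(x)         # (N, C, 1, 1)
    out = self.relu(self.conv1(out))     # reduce to C // ratio
    out = self.sigmoid(self.conv2(out))  # back to C, gates in (0, 1)
    return x * out                       # channel-wise rescaling of input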

View File

@@ -17,6 +17,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcls
known_third_party = cv2,mmcv,numpy,torch,torchvision
known_third_party = cv2,mmcv,numpy,pytest,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

View File

@@ -1,10 +1,12 @@
import pytest
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from torch.nn.modules import AvgPool2d
from mmcls.models.backbones import ResNet, ResNetV1d
from mmcls.models.backbones.resnet import BasicBlock, Bottleneck, ResLayer
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
get_expansion)
def is_block(modules):
@@ -14,13 +16,6 @@ def is_block(modules):
return False
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (_BatchNorm, )):
return True
return False
def all_zeros(modules):
"""Check if the weight(and bias) is all zero."""
weight_zero = torch.equal(modules.weight.data,
@@ -43,12 +38,44 @@ def check_norm_state(modules, train_state):
return True
def test_resnet_basic_block():
# Test BasicBlock structure and forward
def test_get_expansion():
assert get_expansion(Bottleneck, 2) == 2
assert get_expansion(BasicBlock) == 1
assert get_expansion(Bottleneck) == 4
class MyResBlock(nn.Module):
expansion = 8
assert get_expansion(MyResBlock) == 8
# expansion must be an integer or None
with pytest.raises(TypeError):
get_expansion(Bottleneck, '0')
# expansion is not specified and cannot be inferred
with pytest.raises(TypeError):
class SomeModule(nn.Module):
pass
get_expansion(SomeModule)
def test_basic_block():
# expansion must be 1
with pytest.raises(AssertionError):
BasicBlock(64, 64, expansion=2)
# BasicBlock with stride 1, out_channels == in_channels
block = BasicBlock(64, 64)
assert block.in_channels == 64
assert block.mid_channels == 64
assert block.out_channels == 64
assert block.conv1.in_channels == 64
assert block.conv1.out_channels == 64
assert block.conv1.kernel_size == (3, 3)
assert block.conv1.stride == (1, 1)
assert block.conv2.in_channels == 64
assert block.conv2.out_channels == 64
assert block.conv2.kernel_size == (3, 3)
@@ -56,26 +83,59 @@ def test_resnet_basic_block():
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
# Test BasicBlock with checkpoint forward
# BasicBlock with stride 1 and downsample
downsample = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=1, bias=False), nn.BatchNorm2d(128))
block = BasicBlock(64, 128, downsample=downsample)
assert block.in_channels == 64
assert block.mid_channels == 128
assert block.out_channels == 128
assert block.conv1.in_channels == 64
assert block.conv1.out_channels == 128
assert block.conv1.kernel_size == (3, 3)
assert block.conv1.stride == (1, 1)
assert block.conv2.in_channels == 128
assert block.conv2.out_channels == 128
assert block.conv2.kernel_size == (3, 3)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 128, 56, 56])
# BasicBlock with stride 2 and downsample
downsample = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
nn.BatchNorm2d(128))
block = BasicBlock(64, 128, stride=2, downsample=downsample)
assert block.in_channels == 64
assert block.mid_channels == 128
assert block.out_channels == 128
assert block.conv1.in_channels == 64
assert block.conv1.out_channels == 128
assert block.conv1.kernel_size == (3, 3)
assert block.conv1.stride == (2, 2)
assert block.conv2.in_channels == 128
assert block.conv2.out_channels == 128
assert block.conv2.kernel_size == (3, 3)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 128, 28, 28])
# forward with checkpointing
block = BasicBlock(64, 64, with_cp=True)
assert block.with_cp
x = torch.randn(1, 64, 56, 56)
x = torch.randn(1, 64, 56, 56, requires_grad=True)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_resnet_bottleneck():
def test_bottleneck():
# style must be in ['pytorch', 'caffe']
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
Bottleneck(64, 64, style='tensorflow')
# Test Bottleneck with checkpoint forward
block = Bottleneck(64, 16, with_cp=True)
assert block.with_cp
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
# out_channels must be divisible by expansion
with pytest.raises(AssertionError):
Bottleneck(64, 64, expansion=3)
# Test Bottleneck style
block = Bottleneck(64, 64, stride=2, style='pytorch')
@@ -85,60 +145,232 @@ def test_resnet_bottleneck():
assert block.conv1.stride == (2, 2)
assert block.conv2.stride == (1, 1)
# Test Bottleneck forward
block = Bottleneck(64, 16)
# Bottleneck with stride 1
block = Bottleneck(64, 64, style='pytorch')
assert block.in_channels == 64
assert block.mid_channels == 16
assert block.out_channels == 64
assert block.conv1.in_channels == 64
assert block.conv1.out_channels == 16
assert block.conv1.kernel_size == (1, 1)
assert block.conv2.in_channels == 16
assert block.conv2.out_channels == 16
assert block.conv2.kernel_size == (3, 3)
assert block.conv3.in_channels == 16
assert block.conv3.out_channels == 64
assert block.conv3.kernel_size == (1, 1)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == (1, 64, 56, 56)
# Bottleneck with stride 1 and downsample
downsample = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=1), nn.BatchNorm2d(128))
block = Bottleneck(64, 128, style='pytorch', downsample=downsample)
assert block.in_channels == 64
assert block.mid_channels == 32
assert block.out_channels == 128
assert block.conv1.in_channels == 64
assert block.conv1.out_channels == 32
assert block.conv1.kernel_size == (1, 1)
assert block.conv2.in_channels == 32
assert block.conv2.out_channels == 32
assert block.conv2.kernel_size == (3, 3)
assert block.conv3.in_channels == 32
assert block.conv3.out_channels == 128
assert block.conv3.kernel_size == (1, 1)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == (1, 128, 56, 56)
# Bottleneck with stride 2 and downsample
downsample = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=1, stride=2), nn.BatchNorm2d(128))
block = Bottleneck(
64, 128, stride=2, style='pytorch', downsample=downsample)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == (1, 128, 28, 28)
# Bottleneck with expansion 2
block = Bottleneck(64, 64, style='pytorch', expansion=2)
assert block.in_channels == 64
assert block.mid_channels == 32
assert block.out_channels == 64
assert block.conv1.in_channels == 64
assert block.conv1.out_channels == 32
assert block.conv1.kernel_size == (1, 1)
assert block.conv2.in_channels == 32
assert block.conv2.out_channels == 32
assert block.conv2.kernel_size == (3, 3)
assert block.conv3.in_channels == 32
assert block.conv3.out_channels == 64
assert block.conv3.kernel_size == (1, 1)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == (1, 64, 56, 56)
# Test Bottleneck with checkpointing
block = Bottleneck(64, 64, with_cp=True)
block.train()
assert block.with_cp
x = torch.randn(1, 64, 56, 56, requires_grad=True)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_resnet_res_layer():
# Test ResLayer of 3 Bottleneck w/o downsample
layer = ResLayer(Bottleneck, 64, 16, 3)
def test_basicblock_reslayer():
# 3 BasicBlock w/o downsample
layer = ResLayer(BasicBlock, 3, 32, 32)
assert len(layer) == 3
assert layer[0].conv1.in_channels == 64
assert layer[0].conv1.out_channels == 16
for i in range(1, len(layer)):
assert layer[i].conv1.in_channels == 64
assert layer[i].conv1.out_channels == 16
for i in range(len(layer)):
for i in range(3):
assert layer[i].in_channels == 32
assert layer[i].out_channels == 32
assert layer[i].downsample is None
x = torch.randn(1, 64, 56, 56)
x = torch.randn(1, 32, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
assert x_out.shape == (1, 32, 56, 56)
# Test ResLayer of 3 Bottleneck with downsample
layer = ResLayer(Bottleneck, 64, 64, 3)
assert layer[0].downsample[0].out_channels == 256
for i in range(1, len(layer)):
# 3 BasicBlock w/ stride 1 and downsample
layer = ResLayer(BasicBlock, 3, 32, 64)
assert len(layer) == 3
assert layer[0].in_channels == 32
assert layer[0].out_channels == 64
assert layer[0].downsample is not None and len(layer[0].downsample) == 2
assert isinstance(layer[0].downsample[0], nn.Conv2d)
assert layer[0].downsample[0].stride == (1, 1)
for i in range(1, 3):
assert layer[i].in_channels == 64
assert layer[i].out_channels == 64
assert layer[i].downsample is None
x = torch.randn(1, 64, 56, 56)
x = torch.randn(1, 32, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 256, 56, 56])
assert x_out.shape == (1, 64, 56, 56)
# Test ResLayer of 3 Bottleneck with stride=2
layer = ResLayer(Bottleneck, 64, 64, 3, stride=2)
assert layer[0].downsample[0].out_channels == 256
# 3 BasicBlock w/ stride 2 and downsample
layer = ResLayer(BasicBlock, 3, 32, 64, stride=2)
assert len(layer) == 3
assert layer[0].in_channels == 32
assert layer[0].out_channels == 64
assert layer[0].stride == 2
assert layer[0].downsample is not None and len(layer[0].downsample) == 2
assert isinstance(layer[0].downsample[0], nn.Conv2d)
assert layer[0].downsample[0].stride == (2, 2)
for i in range(1, len(layer)):
for i in range(1, 3):
assert layer[i].in_channels == 64
assert layer[i].out_channels == 64
assert layer[i].stride == 1
assert layer[i].downsample is None
x = torch.randn(1, 64, 56, 56)
x = torch.randn(1, 32, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 256, 28, 28])
assert x_out.shape == (1, 64, 28, 28)
# Test ResLayer of 3 Bottleneck with stride=2 and average downsample
layer = ResLayer(Bottleneck, 64, 64, 3, stride=2, avg_down=True)
assert isinstance(layer[0].downsample[0], AvgPool2d)
assert layer[0].downsample[1].out_channels == 256
assert layer[0].downsample[1].stride == (1, 1)
for i in range(1, len(layer)):
# 3 BasicBlock w/ stride 2 and downsample with avg pool
layer = ResLayer(BasicBlock, 3, 32, 64, stride=2, avg_down=True)
assert len(layer) == 3
assert layer[0].in_channels == 32
assert layer[0].out_channels == 64
assert layer[0].stride == 2
assert layer[0].downsample is not None and len(layer[0].downsample) == 3
assert isinstance(layer[0].downsample[0], nn.AvgPool2d)
assert layer[0].downsample[0].stride == 2
for i in range(1, 3):
assert layer[i].in_channels == 64
assert layer[i].out_channels == 64
assert layer[i].stride == 1
assert layer[i].downsample is None
x = torch.randn(1, 64, 56, 56)
x = torch.randn(1, 32, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 256, 28, 28])
assert x_out.shape == (1, 64, 28, 28)
def test_resnet_backbone():
def test_bottleneck_reslayer():
# 3 Bottleneck w/o downsample
layer = ResLayer(Bottleneck, 3, 32, 32)
assert len(layer) == 3
for i in range(3):
assert layer[i].in_channels == 32
assert layer[i].out_channels == 32
assert layer[i].downsample is None
x = torch.randn(1, 32, 56, 56)
x_out = layer(x)
assert x_out.shape == (1, 32, 56, 56)
# 3 Bottleneck w/ stride 1 and downsample
layer = ResLayer(Bottleneck, 3, 32, 64)
assert len(layer) == 3
assert layer[0].in_channels == 32
assert layer[0].out_channels == 64
assert layer[0].stride == 1
assert layer[0].conv1.out_channels == 16
assert layer[0].downsample is not None and len(layer[0].downsample) == 2
assert isinstance(layer[0].downsample[0], nn.Conv2d)
assert layer[0].downsample[0].stride == (1, 1)
for i in range(1, 3):
assert layer[i].in_channels == 64
assert layer[i].out_channels == 64
assert layer[i].conv1.out_channels == 16
assert layer[i].stride == 1
assert layer[i].downsample is None
x = torch.randn(1, 32, 56, 56)
x_out = layer(x)
assert x_out.shape == (1, 64, 56, 56)
# 3 Bottleneck w/ stride 2 and downsample
layer = ResLayer(Bottleneck, 3, 32, 64, stride=2)
assert len(layer) == 3
assert layer[0].in_channels == 32
assert layer[0].out_channels == 64
assert layer[0].stride == 2
assert layer[0].conv1.out_channels == 16
assert layer[0].downsample is not None and len(layer[0].downsample) == 2
assert isinstance(layer[0].downsample[0], nn.Conv2d)
assert layer[0].downsample[0].stride == (2, 2)
for i in range(1, 3):
assert layer[i].in_channels == 64
assert layer[i].out_channels == 64
assert layer[i].conv1.out_channels == 16
assert layer[i].stride == 1
assert layer[i].downsample is None
x = torch.randn(1, 32, 56, 56)
x_out = layer(x)
assert x_out.shape == (1, 64, 28, 28)
# 3 Bottleneck w/ stride 2 and downsample with avg pool
layer = ResLayer(Bottleneck, 3, 32, 64, stride=2, avg_down=True)
assert len(layer) == 3
assert layer[0].in_channels == 32
assert layer[0].out_channels == 64
assert layer[0].stride == 2
assert layer[0].conv1.out_channels == 16
assert layer[0].downsample is not None and len(layer[0].downsample) == 3
assert isinstance(layer[0].downsample[0], nn.AvgPool2d)
assert layer[0].downsample[0].stride == 2
for i in range(1, 3):
assert layer[i].in_channels == 64
assert layer[i].out_channels == 64
assert layer[i].conv1.out_channels == 16
assert layer[i].stride == 1
assert layer[i].downsample is None
x = torch.randn(1, 32, 56, 56)
x_out = layer(x)
assert x_out.shape == (1, 64, 28, 28)
# 3 Bottleneck with custom expansion
layer = ResLayer(Bottleneck, 3, 32, 32, expansion=2)
assert len(layer) == 3
for i in range(3):
assert layer[i].in_channels == 32
assert layer[i].out_channels == 32
assert layer[i].stride == 1
assert layer[i].conv1.out_channels == 16
assert layer[i].downsample is None
x = torch.randn(1, 32, 56, 56)
x_out = layer(x)
assert x_out.shape == (1, 32, 56, 56)
def test_resnet():
"""Test resnet backbone"""
with pytest.raises(KeyError):
# ResNet depth should be in [18, 34, 50, 101, 152]
@@ -194,9 +426,113 @@ def test_resnet_backbone():
for param in layer.parameters():
assert param.requires_grad is False
# Test ResNet18 forward
model = ResNet(18, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == (1, 64, 56, 56)
assert feat[1].shape == (1, 128, 28, 28)
assert feat[2].shape == (1, 256, 14, 14)
assert feat[3].shape == (1, 512, 7, 7)
# Test ResNet50 with BatchNorm forward
model = ResNet(50, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == (1, 256, 56, 56)
assert feat[1].shape == (1, 512, 28, 28)
assert feat[2].shape == (1, 1024, 14, 14)
assert feat[3].shape == (1, 2048, 7, 7)
# Test ResNet50 with layers 1, 2, 3 out forward
model = ResNet(50, out_indices=(0, 1, 2))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == (1, 256, 56, 56)
assert feat[1].shape == (1, 512, 28, 28)
assert feat[2].shape == (1, 1024, 14, 14)
# Test ResNet50 with layers 3 (top feature maps) out forward
model = ResNet(50, out_indices=(3, ))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == (1, 2048, 7, 7)
# Test ResNet50 with checkpoint forward
model = ResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
for m in model.modules():
if is_block(m):
assert m.with_cp
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == (1, 256, 56, 56)
assert feat[1].shape == (1, 512, 28, 28)
assert feat[2].shape == (1, 1024, 14, 14)
assert feat[3].shape == (1, 2048, 7, 7)
# zero initialization of residual blocks
model = ResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=True)
model.init_weights()
for m in model.modules():
if isinstance(m, Bottleneck):
assert all_zeros(m.norm3)
elif isinstance(m, BasicBlock):
assert all_zeros(m.norm2)
# non-zero initialization of residual blocks
model = ResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=False)
model.init_weights()
for m in model.modules():
if isinstance(m, Bottleneck):
assert not all_zeros(m.norm3)
elif isinstance(m, BasicBlock):
assert not all_zeros(m.norm2)
def test_resnet_v1d():
model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()
assert len(model.stem) == 3
for i in range(3):
assert isinstance(model.stem[i], ConvModule)
imgs = torch.randn(1, 3, 224, 224)
feat = model.stem(imgs)
assert feat.shape == (1, 64, 112, 112)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == (1, 256, 56, 56)
assert feat[1].shape == (1, 512, 28, 28)
assert feat[2].shape == (1, 1024, 14, 14)
assert feat[3].shape == (1, 2048, 7, 7)
# Test ResNet50V1d with first stage frozen
frozen_stages = 1
model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
assert len(model.stem) == 9
assert len(model.stem) == 3
for i in range(3):
assert isinstance(model.stem[i], ConvModule)
model.init_weights()
model.train()
check_norm_state(model.stem, False)
@@ -210,99 +546,16 @@ def test_resnet_backbone():
for param in layer.parameters():
assert param.requires_grad is False
# Test ResNet18 forward
model = ResNet(18, out_indices=(0, 1, 2, 3))
def test_resnet_half_channel():
model = ResNet(50, base_channels=32, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 64, 56, 56])
assert feat[1].shape == torch.Size([1, 128, 28, 28])
assert feat[2].shape == torch.Size([1, 256, 14, 14])
assert feat[3].shape == torch.Size([1, 512, 7, 7])
# Test ResNet50 with BatchNorm forward
model = ResNet(50, out_indices=(0, 1, 2, 3))
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
# Test ResNet50 with layers 1, 2, 3 out forward
model = ResNet(50, out_indices=(0, 1, 2))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
# Test ResNet50 with layers 3 (top feature maps) out forward
model = ResNet(50, out_indices=(3, ))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == torch.Size([1, 2048, 7, 7])
# Test ResNet50 with checkpoint forward
model = ResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
for m in model.modules():
if is_block(m):
assert m.with_cp
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
# Test ResNet50 zero initialization of residual
model = ResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=True)
model.init_weights()
for m in model.modules():
if isinstance(m, Bottleneck):
assert all_zeros(m.norm3)
elif isinstance(m, BasicBlock):
assert all_zeros(m.norm2)
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
# Test ResNetV1d forward
model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
assert feat[0].shape == (1, 128, 56, 56)
assert feat[1].shape == (1, 256, 28, 28)
assert feat[2].shape == (1, 512, 14, 14)
assert feat[3].shape == (1, 1024, 7, 7)

View File

@@ -5,42 +5,35 @@ from mmcls.models.backbones import ResNeXt
from mmcls.models.backbones.resnext import Bottleneck as BottleneckX
def is_block(modules):
"""Check if is ResNeXt building block."""
if isinstance(modules, (BottleneckX)):
return True
return False
def test_resnext_bottleneck():
def test_bottleneck():
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')
BottleneckX(64, 64, groups=32, width_per_group=4, style='tensorflow')
# Test ResNeXt Bottleneck structure
block = BottleneckX(
64, 64, groups=32, base_width=4, stride=2, style='pytorch')
64, 256, groups=32, width_per_group=4, stride=2, style='pytorch')
assert block.conv2.stride == (2, 2)
assert block.conv2.groups == 32
assert block.conv2.out_channels == 128
# Test ResNeXt Bottleneck forward
block = BottleneckX(64, 16, groups=32, base_width=4)
block = BottleneckX(64, 64, base_channels=16, groups=32, width_per_group=4)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_resnext_backbone():
def test_resnext():
with pytest.raises(KeyError):
# ResNeXt depth should be in [50, 101, 152]
ResNeXt(depth=18)
# Test ResNeXt with group 32, base_width 4
# Test ResNeXt with group 32, width_per_group 4
model = ResNeXt(
depth=50, groups=32, base_width=4, out_indices=(0, 1, 2, 3))
depth=50, groups=32, width_per_group=4, out_indices=(0, 1, 2, 3))
for m in model.modules():
if is_block(m):
if isinstance(m, BottleneckX):
assert m.conv2.groups == 32
model.init_weights()
model.train()
@@ -53,10 +46,10 @@ def test_resnext_backbone():
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
# Test ResNeXt with group 32, base_width 4 and layers 3 out forward
model = ResNeXt(depth=50, groups=32, base_width=4, out_indices=(3, ))
# Test ResNeXt with group 32, width_per_group 4 and layers 3 out forward
model = ResNeXt(depth=50, groups=32, width_per_group=4, out_indices=(3, ))
for m in model.modules():
if is_block(m):
if isinstance(m, BottleneckX):
assert m.conv2.groups == 32
model.init_weights()
model.train()

View File

@@ -8,20 +8,6 @@ from mmcls.models.backbones.resnet import ResLayer
from mmcls.models.backbones.seresnet import SEBottleneck, SELayer
def is_block(modules):
"""Check if is ResNet building block."""
if isinstance(modules, (SEBottleneck, )):
return True
return False
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (_BatchNorm, )):
return True
return False
def all_zeros(modules):
"""Check if the weight(and bias) is all zero."""
weight_zero = torch.equal(modules.weight.data,
@@ -44,7 +30,7 @@ def check_norm_state(modules, train_state):
return True
def test_serenet_selayer():
def test_selayer():
# Test selayer forward
layer = SELayer(64)
x = torch.randn(1, 64, 56, 56)
@@ -58,37 +44,37 @@ def test_serenet_selayer():
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_seresnet_bottleneckse():
def test_bottleneck():
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
SEBottleneck(64, 64, style='tensorflow')
# Test SEBottleneck with checkpoint forward
block = SEBottleneck(64, 16, with_cp=True)
block = SEBottleneck(64, 64, with_cp=True)
assert block.with_cp
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
# Test Bottleneck style
block = SEBottleneck(64, 64, stride=2, style='pytorch')
block = SEBottleneck(64, 256, stride=2, style='pytorch')
assert block.conv1.stride == (1, 1)
assert block.conv2.stride == (2, 2)
block = SEBottleneck(64, 64, stride=2, style='caffe')
block = SEBottleneck(64, 256, stride=2, style='caffe')
assert block.conv1.stride == (2, 2)
assert block.conv2.stride == (1, 1)
# Test Bottleneck forward
block = SEBottleneck(64, 16)
block = SEBottleneck(64, 64)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_seresnet_res_layer():
def test_res_layer():
# Test ResLayer of 3 Bottleneck w/o downsample
layer = ResLayer(SEBottleneck, 64, 16, 3, se_ratio=16)
layer = ResLayer(SEBottleneck, 3, 64, 64, se_ratio=16)
assert len(layer) == 3
assert layer[0].conv1.in_channels == 64
assert layer[0].conv1.out_channels == 16
@@ -102,7 +88,7 @@ def test_seresnet_res_layer():
assert x_out.shape == torch.Size([1, 64, 56, 56])
# Test ResLayer of 3 SEBottleneck with downsample
layer = ResLayer(SEBottleneck, 64, 64, 3, se_ratio=16)
layer = ResLayer(SEBottleneck, 3, 64, 256, se_ratio=16)
assert layer[0].downsample[0].out_channels == 256
for i in range(1, len(layer)):
assert layer[i].downsample is None
@@ -111,7 +97,7 @@ def test_seresnet_res_layer():
assert x_out.shape == torch.Size([1, 256, 56, 56])
# Test ResLayer of 3 SEBottleneck with stride=2
layer = ResLayer(SEBottleneck, 64, 64, 3, stride=2, se_ratio=8)
layer = ResLayer(SEBottleneck, 3, 64, 256, stride=2, se_ratio=8)
assert layer[0].downsample[0].out_channels == 256
assert layer[0].downsample[0].stride == (2, 2)
for i in range(1, len(layer)):
@@ -122,7 +108,7 @@ def test_seresnet_res_layer():
# Test ResLayer of 3 SEBottleneck with stride=2 and average downsample
layer = ResLayer(
SEBottleneck, 64, 64, 3, stride=2, avg_down=True, se_ratio=8)
SEBottleneck, 3, 64, 256, stride=2, avg_down=True, se_ratio=8)
assert isinstance(layer[0].downsample[0], AvgPool2d)
assert layer[0].downsample[1].out_channels == 256
assert layer[0].downsample[1].stride == (1, 1)
@@ -133,7 +119,7 @@ def test_seresnet_res_layer():
assert x_out.shape == torch.Size([1, 256, 28, 28])
def test_seresnet_backbone():
def test_seresnet():
"""Test resnet backbone"""
with pytest.raises(KeyError):
# SEResNet depth should be in [50, 101, 152]
@@ -191,9 +177,6 @@ def test_seresnet_backbone():
# Test SEResNet50 with BatchNorm forward
model = SEResNet(50, out_indices=(0, 1, 2, 3))
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.init_weights()
model.train()
@@ -229,7 +212,7 @@ def test_seresnet_backbone():
# Test SEResNet50 with checkpoint forward
model = SEResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
for m in model.modules():
if is_block(m):
if isinstance(m, SEBottleneck):
assert m.with_cp
model.init_weights()
model.train()

View File

@@ -5,42 +5,35 @@ from mmcls.models.backbones import SEResNeXt
from mmcls.models.backbones.seresnext import SEBottleneck as SEBottleneckX
def is_block(modules):
"""Check if is SEResNeXt building block."""
if isinstance(modules, (SEBottleneckX)):
return True
return False
def test_seresnext_bottleneck():
def test_bottleneck():
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
SEBottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')
SEBottleneckX(64, 64, groups=32, width_per_group=4, style='tensorflow')
# Test SEResNeXt Bottleneck structure
block = SEBottleneckX(
64, 64, groups=32, base_width=4, stride=2, style='pytorch')
64, 256, groups=32, width_per_group=4, stride=2, style='pytorch')
assert block.conv2.stride == (2, 2)
assert block.conv2.groups == 32
assert block.conv2.out_channels == 128
# Test SEResNeXt Bottleneck forward
block = SEBottleneckX(64, 16, groups=32, base_width=4)
block = SEBottleneckX(64, 64, groups=32, width_per_group=4)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_seresnext_backbone():
def test_seresnext():
with pytest.raises(KeyError):
# SEResNeXt depth should be in [50, 101, 152]
SEResNeXt(depth=18)
# Test SEResNeXt with group 32, base_width 4
# Test SEResNeXt with group 32, width_per_group 4
model = SEResNeXt(
depth=50, groups=32, base_width=4, out_indices=(0, 1, 2, 3))
depth=50, groups=32, width_per_group=4, out_indices=(0, 1, 2, 3))
for m in model.modules():
if is_block(m):
if isinstance(m, SEBottleneckX):
assert m.conv2.groups == 32
model.init_weights()
model.train()
@@ -53,10 +46,11 @@ def test_seresnext_backbone():
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
# Test SEResNeXt with group 32, base_width 4 and layers 3 out forward
model = SEResNeXt(depth=50, groups=32, base_width=4, out_indices=(3, ))
# Test SEResNeXt with group 32, width_per_group 4 and layers 3 out forward
model = SEResNeXt(
depth=50, groups=32, width_per_group=4, out_indices=(3, ))
for m in model.modules():
if is_block(m):
if isinstance(m, SEBottleneckX):
assert m.conv2.groups == 32
model.init_weights()
model.train()