Refactoring for ResNet family
parent 2a05c77f0f
commit 02e11cc1f3
@@ -1,7 +1,7 @@
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
-                      kaiming_init)
+from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
+                      constant_init, kaiming_init)
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
 from ..builder import BACKBONES
@@ -12,15 +12,17 @@ class BasicBlock(nn.Module):
     """BasicBlock for ResNet.
 
     Args:
-        inplanes (int): inplanes of block.
-        planes (int): planes of block.
+        in_channels (int): Input channels of this block.
+        out_channels (int): Output channels of this block.
+        expansion (int): The ratio of ``out_channels/mid_channels`` where
+            ``mid_channels`` is the output channels of conv1. This is a
+            reserved argument in BasicBlock and should always be 1. Default: 1.
         stride (int): stride of the block. Default: 1
         dilation (int): dilation of convolution. Default: 1
         downsample (nn.Module): downsample operation on identity branch.
             Default: None
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-            Default: None.
+        style (str): `pytorch` or `caffe`. It is unused and reserved for
+            unified API with Bottleneck.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
         conv_cfg (dict): dictionary to construct and config conv layer.
@@ -29,11 +31,10 @@ class BasicBlock(nn.Module):
             Default: dict(type='BN')
     """
 
-    expansion = 1
-
     def __init__(self,
-                 inplanes,
-                 planes,
+                 in_channels,
+                 out_channels,
+                 expansion=1,
                  stride=1,
                  dilation=1,
                  downsample=None,
@@ -42,14 +43,28 @@ class BasicBlock(nn.Module):
                  conv_cfg=None,
                  norm_cfg=dict(type='BN')):
         super(BasicBlock, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.expansion = expansion
+        assert self.expansion == 1
+        assert out_channels % expansion == 0
+        self.mid_channels = out_channels // expansion
+        self.stride = stride
+        self.dilation = dilation
+        self.style = style
+        self.with_cp = with_cp
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
 
-        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+        self.norm1_name, norm1 = build_norm_layer(
+            norm_cfg, self.mid_channels, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(
+            norm_cfg, out_channels, postfix=2)
 
         self.conv1 = build_conv_layer(
             conv_cfg,
-            inplanes,
-            planes,
+            in_channels,
+            self.mid_channels,
            3,
            stride=stride,
            padding=dilation,
@@ -57,14 +72,16 @@ class BasicBlock(nn.Module):
             bias=False)
         self.add_module(self.norm1_name, norm1)
         self.conv2 = build_conv_layer(
-            conv_cfg, planes, planes, 3, padding=1, bias=False)
+            conv_cfg,
+            self.mid_channels,
+            out_channels,
+            3,
+            padding=1,
+            bias=False)
         self.add_module(self.norm2_name, norm2)
 
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
-        self.stride = stride
-        self.dilation = dilation
-        self.with_cp = with_cp
 
     @property
     def norm1(self):
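
The BasicBlock hunks above replace the implicit ``planes`` arithmetic with an explicit contract: ``mid_channels = out_channels // expansion``, and expansion is pinned to 1, so conv1 maps ``in_channels -> mid_channels`` and conv2 maps ``mid_channels -> out_channels``. A minimal sketch of the resulting behaviour (mirrors the new unit tests at the end of this diff; assumes the refactored mmcls sources are importable):

    import torch
    from mmcls.models.backbones.resnet import BasicBlock

    block = BasicBlock(in_channels=64, out_channels=64)  # expansion=1 -> mid_channels == 64
    x = torch.randn(1, 64, 56, 56)
    assert block(x).shape == (1, 64, 56, 56)  # stride 1 keeps the spatial size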
@@ -107,15 +124,17 @@ class Bottleneck(nn.Module):
     """Bottleneck block for ResNet.
 
     Args:
-        inplanes (int): inplanes of block.
-        planes (int): planes of block.
+        in_channels (int): Input channels of this block.
+        out_channels (int): Output channels of this block.
+        expansion (int): The ratio of ``out_channels/mid_channels`` where
+            ``mid_channels`` is the input/output channels of conv2. Default: 4.
         stride (int): stride of the block. Default: 1
         dilation (int): dilation of convolution. Default: 1
         downsample (nn.Module): downsample operation on identity branch.
             Default: None
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-            Default: None.
+        style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
+            stride-two layer is the 3x3 conv layer, otherwise the stride-two
+            layer is the first 1x1 conv layer. Default: "pytorch".
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
         conv_cfg (dict): dictionary to construct and config conv layer.
@@ -124,11 +143,10 @@ class Bottleneck(nn.Module):
             Default: dict(type='BN')
     """
 
-    expansion = 4
-
     def __init__(self,
-                 inplanes,
-                 planes,
+                 in_channels,
+                 out_channels,
+                 expansion=4,
                  stride=1,
                  dilation=1,
                  downsample=None,
@@ -139,8 +157,11 @@ class Bottleneck(nn.Module):
         super(Bottleneck, self).__init__()
         assert style in ['pytorch', 'caffe']
 
-        self.inplanes = inplanes
-        self.planes = planes
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.expansion = expansion
+        assert out_channels % expansion == 0
+        self.mid_channels = out_channels // expansion
         self.stride = stride
         self.dilation = dilation
         self.style = style
@@ -155,23 +176,25 @@ class Bottleneck(nn.Module):
             self.conv1_stride = stride
             self.conv2_stride = 1
 
-        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+        self.norm1_name, norm1 = build_norm_layer(
+            norm_cfg, self.mid_channels, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(
+            norm_cfg, self.mid_channels, postfix=2)
         self.norm3_name, norm3 = build_norm_layer(
-            norm_cfg, planes * self.expansion, postfix=3)
+            norm_cfg, out_channels, postfix=3)
 
         self.conv1 = build_conv_layer(
             conv_cfg,
-            inplanes,
-            planes,
+            in_channels,
+            self.mid_channels,
             kernel_size=1,
             stride=self.conv1_stride,
             bias=False)
         self.add_module(self.norm1_name, norm1)
         self.conv2 = build_conv_layer(
             conv_cfg,
-            planes,
-            planes,
+            self.mid_channels,
+            self.mid_channels,
             kernel_size=3,
             stride=self.conv2_stride,
             padding=dilation,
@@ -181,8 +204,8 @@ class Bottleneck(nn.Module):
         self.add_module(self.norm2_name, norm2)
         self.conv3 = build_conv_layer(
             conv_cfg,
-            planes,
-            planes * self.expansion,
+            self.mid_channels,
+            out_channels,
             kernel_size=1,
             bias=False)
         self.add_module(self.norm3_name, norm3)
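
After this change the three Bottleneck convolutions read directly off the stored channel attributes: a 1x1 reduction ``in_channels -> mid_channels``, a 3x3 at ``mid_channels``, and a 1x1 expansion ``mid_channels -> out_channels``, with ``mid_channels = out_channels // expansion``. A short check (values mirror the tests below; assumes the refactored sources):

    from mmcls.models.backbones.resnet import Bottleneck

    block = Bottleneck(in_channels=64, out_channels=64, expansion=4)
    assert block.mid_channels == 16          # 64 // 4
    assert block.conv1.out_channels == 16    # 1x1 reduce
    assert block.conv2.out_channels == 16    # 3x3 transform
    assert block.conv3.out_channels == 64    # 1x1 expand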
@@ -235,15 +258,55 @@ class Bottleneck(nn.Module):
         return out
 
 
+def get_expansion(block, expansion=None):
+    """Get the expansion of a residual block.
+
+    The block expansion will be obtained by the following order:
+
+    1. If ``expansion`` is given, just return it.
+    2. If ``block`` has the attribute ``expansion``, then return
+       ``block.expansion``.
+    3. Return the default value according to the block type:
+       1 for ``BasicBlock`` and 4 for ``Bottleneck``.
+
+    Args:
+        block (class): The block class.
+        expansion (int | None): The given expansion ratio.
+
+    Returns:
+        int: The expansion of the block.
+    """
+    if isinstance(expansion, int):
+        assert expansion > 0
+    elif expansion is None:
+        if hasattr(block, 'expansion'):
+            expansion = block.expansion
+        elif issubclass(block, BasicBlock):
+            expansion = 1
+        elif issubclass(block, Bottleneck):
+            expansion = 4
+        else:
+            raise TypeError(f'expansion is not specified for {block.__name__}')
+    else:
+        raise TypeError('expansion must be an integer or None')
+
+    return expansion
+
+
 class ResLayer(nn.Sequential):
     """ResLayer to build ResNet style backbone.
 
     Args:
-        block (nn.Module): block used to build ResLayer.
-        inplanes (int): inplanes of block.
-        planes (int): planes of block.
-        num_blocks (int): number of blocks.
-        stride (int): stride of the first block. Default: 1
+        block (nn.Module): Residual block used to build ResLayer.
+        num_blocks (int): Number of blocks.
+        in_channels (int): Input channels of this block.
+        out_channels (int): Output channels of this block.
+        expansion (int, optional): The expansion for BasicBlock/Bottleneck.
+            If not specified, it will firstly be obtained via
+            ``block.expansion``. If the block has no attribute "expansion",
+            the following default values will be used: 1 for BasicBlock and
+            4 for Bottleneck. Default: None.
+        stride (int): stride of the first block. Default: 1.
         avg_down (bool): Use AvgPool instead of stride conv when
             downsampling in the bottleneck. Default: False
         conv_cfg (dict): dictionary to construct and config conv layer.
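
The resolution order documented in ``get_expansion`` can be exercised directly. Note that this diff also removes the ``expansion`` class attribute from ``Bottleneck``, so step 3 (the subclass default) is what yields 4 here; ``MyBlock`` is a made-up class for illustration:

    from mmcls.models.backbones.resnet import Bottleneck, get_expansion

    assert get_expansion(Bottleneck, 2) == 2  # 1. an explicit argument wins
    assert get_expansion(Bottleneck) == 4     # 3. default for Bottleneck subclasses

    class MyBlock:
        expansion = 8

    assert get_expansion(MyBlock) == 8        # 2. taken from the class attribute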
@@ -254,18 +317,20 @@ class ResLayer(nn.Sequential):
 
     def __init__(self,
                  block,
-                 inplanes,
-                 planes,
                  num_blocks,
+                 in_channels,
+                 out_channels,
+                 expansion=None,
                  stride=1,
                  avg_down=False,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
                  **kwargs):
         self.block = block
+        self.expansion = get_expansion(block, expansion)
 
         downsample = None
-        if stride != 1 or inplanes != planes * block.expansion:
+        if stride != 1 or in_channels != out_channels:
             downsample = []
             conv_stride = stride
             if avg_down and stride != 1:
@@ -279,31 +344,33 @@ class ResLayer(nn.Sequential):
             downsample.extend([
                 build_conv_layer(
                     conv_cfg,
-                    inplanes,
-                    planes * block.expansion,
+                    in_channels,
+                    out_channels,
                     kernel_size=1,
                     stride=conv_stride,
                     bias=False),
-                build_norm_layer(norm_cfg, planes * block.expansion)[1]
+                build_norm_layer(norm_cfg, out_channels)[1]
             ])
             downsample = nn.Sequential(*downsample)
 
         layers = []
         layers.append(
             block(
-                inplanes=inplanes,
-                planes=planes,
+                in_channels=in_channels,
+                out_channels=out_channels,
+                expansion=self.expansion,
                 stride=stride,
                 downsample=downsample,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
                 **kwargs))
-        inplanes = planes * block.expansion
+        in_channels = out_channels
         for i in range(1, num_blocks):
             layers.append(
                 block(
-                    inplanes=inplanes,
-                    planes=planes,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    expansion=self.expansion,
                     stride=1,
                     conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
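
Note the reworked ``ResLayer`` call order: ``(block, num_blocks, in_channels, out_channels, ...)``. Only the first block performs the channel/stride transition; the loop then continues with ``in_channels = out_channels``. A sketch grounded in the new tests (assumes the refactored sources):

    import torch
    from mmcls.models.backbones.resnet import Bottleneck, ResLayer

    layer = ResLayer(Bottleneck, 3, 32, 64, stride=2)
    assert layer[0].downsample is not None  # first block transitions 32 -> 64 at stride 2
    assert layer[1].downsample is None      # remaining blocks run 64 -> 64 at stride 1
    assert layer(torch.randn(1, 32, 56, 56)).shape == (1, 64, 28, 28)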
@@ -315,30 +382,41 @@ class ResNet(BaseBackbone):
 class ResNet(BaseBackbone):
     """ResNet backbone.
 
     Please refer to the `paper <https://arxiv.org/abs/1512.03385>`_ for
     details.
 
     Args:
-        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
-        in_channels (int): Number of input image channels. Normally 3.
-        base_channels (int): Number of base channels of hidden layer.
-        num_stages (int): Resnet stages, normally 4.
+        depth (int): Network depth, from {18, 34, 50, 101, 152}.
+        in_channels (int): Number of input image channels. Default: 3.
+        stem_channels (int): Output channels of the stem layer. Default: 64.
+        base_channels (int): Middle channels of the first stage. Default: 64.
+        num_stages (int): Stages of the network. Default: 4.
         strides (Sequence[int]): Strides of the first block of each stage.
+            Default: ``(1, 2, 2, 2)``.
         dilations (Sequence[int]): Dilation of each stage.
-        out_indices (Sequence[int]): Output from which stages.
+            Default: ``(1, 1, 1, 1)``.
+        out_indices (Sequence[int]): Output from which stages. If only one
+            stage is specified, a single tensor (feature map) is returned,
+            otherwise multiple stages are specified, a tuple of tensors will
+            be returned. Default: ``(3, )``.
         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
             layer is the 3x3 conv layer, otherwise the stride-two layer is
             the first 1x1 conv layer.
-        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+            Default: False.
         avg_down (bool): Use AvgPool instead of stride conv when
-            downsampling in the bottleneck.
+            downsampling in the bottleneck. Default: False.
         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-            -1 means not freezing any parameters.
-        norm_cfg (dict): Dictionary to construct and config norm layer.
+            -1 means not freezing any parameters. Default: -1.
+        conv_cfg (dict | None): The config dict for conv layers. Default: None.
+        norm_cfg (dict): The config dict for norm layers.
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only.
+            and its variants only. Default: False.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
+            memory while slowing down the training speed. Default: False.
         zero_init_residual (bool): Whether to use zero init for last norm layer
-            in resblocks to let them behave as identity.
+            in resblocks to let them behave as identity. Default: True.
 
     Example:
         >>> from mmcls.models import ResNet
@@ -366,7 +444,9 @@ class ResNet(BaseBackbone):
     def __init__(self,
                  depth,
                  in_channels=3,
+                 stem_channels=64,
                  base_channels=64,
+                 expansion=None,
                  num_stages=4,
                  strides=(1, 2, 2, 2),
                  dilations=(1, 1, 1, 1),
@@ -384,6 +464,7 @@ class ResNet(BaseBackbone):
         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for resnet')
         self.depth = depth
+        self.stem_channels = stem_channels
         self.base_channels = base_channels
         self.num_stages = num_stages
         assert num_stages >= 1 and num_stages <= 4
@@ -403,20 +484,22 @@ class ResNet(BaseBackbone):
         self.zero_init_residual = zero_init_residual
         self.block, stage_blocks = self.arch_settings[depth]
         self.stage_blocks = stage_blocks[:num_stages]
-        self.inplanes = base_channels
+        self.expansion = get_expansion(self.block, expansion)
 
-        self._make_stem_layer(in_channels, base_channels)
+        self._make_stem_layer(in_channels, stem_channels)
 
         self.res_layers = []
+        _in_channels = stem_channels
+        _out_channels = base_channels * self.expansion
         for i, num_blocks in enumerate(self.stage_blocks):
             stride = strides[i]
             dilation = dilations[i]
-            planes = base_channels * 2**i
             res_layer = self.make_res_layer(
                 block=self.block,
-                inplanes=self.inplanes,
-                planes=planes,
                 num_blocks=num_blocks,
+                in_channels=_in_channels,
+                out_channels=_out_channels,
+                expansion=self.expansion,
                 stride=stride,
                 dilation=dilation,
                 style=self.style,
@@ -424,15 +507,15 @@ class ResNet(BaseBackbone):
                 with_cp=with_cp,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg)
-            self.inplanes = planes * self.block.expansion
+            _in_channels = _out_channels
+            _out_channels *= 2
             layer_name = f'layer{i + 1}'
             self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)
 
         self._freeze_stages()
 
-        self.feat_dim = self.block.expansion * base_channels * 2**(
-            len(self.stage_blocks) - 1)
+        self.feat_dim = res_layer[-1].out_channels
 
     def make_res_layer(self, **kwargs):
         return ResLayer(**kwargs)
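
With ``_in_channels = stem_channels`` and ``_out_channels = base_channels * expansion`` doubling after every stage, the stage widths fall out mechanically, and ``feat_dim`` is now read off the last block instead of being recomputed. Worked numbers for a default ResNet-50 (plain Python, nothing to import):

    expansion = 4  # Bottleneck default
    _in_channels, _out_channels = 64, 64 * expansion
    for i in range(4):
        print(f'stage {i + 1}: {_in_channels} -> {_out_channels}')
        _in_channels, _out_channels = _out_channels, _out_channels * 2
    # stage 1: 64 -> 256, stage 2: 256 -> 512,
    # stage 3: 512 -> 1024, stage 4: 1024 -> 2048 (= feat_dim)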
@@ -441,50 +524,47 @@ class ResNet(BaseBackbone):
     def norm1(self):
         return getattr(self, self.norm1_name)
 
-    def _make_stem_layer(self, in_channels, base_channels):
+    def _make_stem_layer(self, in_channels, stem_channels):
         if self.deep_stem:
             self.stem = nn.Sequential(
-                build_conv_layer(
-                    self.conv_cfg,
+                ConvModule(
                     in_channels,
-                    base_channels // 2,
+                    stem_channels // 2,
                     kernel_size=3,
                     stride=2,
                     padding=1,
-                    bias=False),
-                build_norm_layer(self.norm_cfg, base_channels // 2)[1],
-                nn.ReLU(inplace=True),
-                build_conv_layer(
-                    self.conv_cfg,
-                    base_channels // 2,
-                    base_channels // 2,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    inplace=True),
+                ConvModule(
+                    stem_channels // 2,
+                    stem_channels // 2,
                     kernel_size=3,
                     stride=1,
                     padding=1,
-                    bias=False),
-                build_norm_layer(self.norm_cfg, base_channels // 2)[1],
-                nn.ReLU(inplace=True),
-                build_conv_layer(
-                    self.conv_cfg,
-                    base_channels // 2,
-                    base_channels,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    inplace=True),
+                ConvModule(
+                    stem_channels // 2,
+                    stem_channels,
                     kernel_size=3,
                     stride=1,
                     padding=1,
-                    bias=False),
-                build_norm_layer(self.norm_cfg, base_channels)[1],
-                nn.ReLU(inplace=True))
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    inplace=True))
         else:
             self.conv1 = build_conv_layer(
                 self.conv_cfg,
                 in_channels,
-                base_channels,
+                stem_channels,
                 kernel_size=7,
                 stride=2,
                 padding=3,
                 bias=False)
             self.norm1_name, norm1 = build_norm_layer(
-                self.norm_cfg, base_channels, postfix=1)
+                self.norm_cfg, stem_channels, postfix=1)
             self.add_module(self.norm1_name, norm1)
             self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
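
Because ``ConvModule`` bundles conv, norm and activation into one child module, the deep stem shrinks from nine children to three, which is exactly what the updated test later in this diff asserts (``len(model.stem) == 3``). A rough structural sketch, assuming mmcv's ``ConvModule`` and ``stem_channels=64`` (not a drop-in for the method above):

    import torch.nn as nn
    from mmcv.cnn import ConvModule

    stem = nn.Sequential(
        ConvModule(3, 32, 3, stride=2, padding=1, norm_cfg=dict(type='BN')),
        ConvModule(32, 32, 3, stride=1, padding=1, norm_cfg=dict(type='BN')),
        ConvModule(32, 64, 3, stride=1, padding=1, norm_cfg=dict(type='BN')))
    assert len(stem) == 3  # previously 3 x (conv + bn + relu) = 9 children

The hunks that follow apply the same in/out-channel refactoring to the ResNeXt backbone.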
@@ -1,5 +1,3 @@
-import math
-
 from mmcv.cnn import build_conv_layer, build_norm_layer
 
 from ..builder import BACKBONES
@@ -11,11 +9,12 @@ class Bottleneck(_Bottleneck):
     """Bottleneck block for ResNeXt.
 
     Args:
-        inplanes (int): inplanes of block.
-        planes (int): planes of block.
-        groups (int): group of convolution.
-        base_width (int): Base width of resnext.
-        base_channels (int): Number of base channels of hidden layer.
+        in_channels (int): Input channels of this block.
+        out_channels (int): Output channels of this block.
+        groups (int): Groups of conv2.
+        width_per_group (int): Width per group of conv2. 64x4d indicates
+            ``groups=64, width_per_group=4`` and 32x8d indicates
+            ``groups=32, width_per_group=8``.
         stride (int): stride of the block. Default: 1
         dilation (int): dilation of convolution. Default: 1
         downsample (nn.Module): downsample operation on identity branch.
@@ -31,42 +30,44 @@ class Bottleneck(_Bottleneck):
             memory while slowing down the training speed.
     """
 
-    expansion = 4
-
     def __init__(self,
-                 inplanes,
-                 planes,
-                 groups=1,
-                 base_width=4,
+                 in_channels,
+                 out_channels,
                  base_channels=64,
+                 groups=32,
+                 width_per_group=4,
                  **kwargs):
-        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+        super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs)
         self.groups = groups
+        self.width_per_group = width_per_group
 
-        if groups == 1:
-            width = self.planes
-        else:
-            width = math.floor(self.planes *
-                               (base_width / base_channels)) * groups
+        # For ResNet bottleneck, middle channels are determined by expansion
+        # and out_channels, but for ResNeXt bottleneck, it is determined by
+        # groups and width_per_group and the stage it is located in.
+        if groups != 1:
+            assert self.mid_channels % base_channels == 0
+            self.mid_channels = (
+                groups * width_per_group * self.mid_channels // base_channels)
 
         self.norm1_name, norm1 = build_norm_layer(
-            self.norm_cfg, width, postfix=1)
+            self.norm_cfg, self.mid_channels, postfix=1)
         self.norm2_name, norm2 = build_norm_layer(
-            self.norm_cfg, width, postfix=2)
+            self.norm_cfg, self.mid_channels, postfix=2)
         self.norm3_name, norm3 = build_norm_layer(
-            self.norm_cfg, self.planes * self.expansion, postfix=3)
+            self.norm_cfg, self.out_channels, postfix=3)
 
         self.conv1 = build_conv_layer(
             self.conv_cfg,
-            self.inplanes,
-            width,
+            self.in_channels,
+            self.mid_channels,
             kernel_size=1,
             stride=self.conv1_stride,
             bias=False)
         self.add_module(self.norm1_name, norm1)
         self.conv2 = build_conv_layer(
             self.conv_cfg,
-            width,
-            width,
+            self.mid_channels,
+            self.mid_channels,
             kernel_size=3,
             stride=self.conv2_stride,
             padding=self.dilation,
@@ -77,8 +78,8 @@ class Bottleneck(_Bottleneck):
         self.add_module(self.norm2_name, norm2)
         self.conv3 = build_conv_layer(
             self.conv_cfg,
-            width,
-            self.planes * self.expansion,
+            self.mid_channels,
+            self.out_channels,
             kernel_size=1,
             bias=False)
         self.add_module(self.norm3_name, norm3)
@@ -88,29 +89,43 @@ class ResNeXt(ResNet):
 class ResNeXt(ResNet):
     """ResNeXt backbone.
 
     Please refer to the `paper <https://arxiv.org/abs/1611.05431>`_ for
     details.
 
     Args:
-        groups (int): Group of resnext.
-        base_width (int): Base width of resnext.
-        depth (int): Depth of resnext, from {50, 101, 152}.
+        depth (int): Network depth, from {50, 101, 152}.
+        groups (int): Groups of conv2 in Bottleneck. Default: 32.
+        width_per_group (int): Width per group of conv2 in Bottleneck.
+            Default: 4.
         in_channels (int): Number of input image channels. Default: 3.
-        base_channels (int): Number of base channels of hidden layer.
-        num_stages (int): Resnet stages. Default: 4.
+        stem_channels (int): Output channels of the stem layer. Default: 64.
+        num_stages (int): Stages of the network. Default: 4.
         strides (Sequence[int]): Strides of the first block of each stage.
+            Default: ``(1, 2, 2, 2)``.
         dilations (Sequence[int]): Dilation of each stage.
-        out_indices (Sequence[int]): Output from which stages.
+            Default: ``(1, 1, 1, 1)``.
+        out_indices (Sequence[int]): Output from which stages. If only one
+            stage is specified, a single tensor (feature map) is returned,
+            otherwise multiple stages are specified, a tuple of tensors will
+            be returned. Default: ``(3, )``.
         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
             layer is the 3x3 conv layer, otherwise the stride-two layer is
             the first 1x1 conv layer.
-        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
-            not freezing any parameters.
-        norm_cfg (dict): dictionary to construct and config norm layer.
+        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+            Default: False.
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck. Default: False.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Default: -1.
+        conv_cfg (dict | None): The config dict for conv layers. Default: None.
+        norm_cfg (dict): The config dict for norm layers.
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only.
+            and its variants only. Default: False.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-        zero_init_residual (bool): whether to use zero init for last norm layer
-            in resblocks to let them behave as identity.
+            memory while slowing down the training speed. Default: False.
+        zero_init_residual (bool): Whether to use zero init for last norm layer
+            in resblocks to let them behave as identity. Default: True.
     """
 
     arch_settings = {
@@ -119,14 +134,14 @@ class ResNeXt(ResNet):
         152: (Bottleneck, (3, 8, 36, 3))
     }
 
-    def __init__(self, groups=1, base_width=4, **kwargs):
+    def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
         self.groups = groups
-        self.base_width = base_width
-        super(ResNeXt, self).__init__(**kwargs)
+        self.width_per_group = width_per_group
+        super(ResNeXt, self).__init__(depth, **kwargs)
 
     def make_res_layer(self, **kwargs):
         return ResLayer(
             groups=self.groups,
-            base_width=self.base_width,
+            width_per_group=self.width_per_group,
             base_channels=self.base_channels,
             **kwargs)
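
For ResNeXt, the grouped 3x3 width now comes from ``groups * width_per_group * mid_channels // base_channels`` rather than the old ``math.floor`` formula, so it still scales with the stage. A hand computation for ResNeXt-50 32x4d, stage 1 (illustrative only):

    groups, width_per_group, base_channels = 32, 4, 64
    mid_channels = 256 // 4  # plain Bottleneck: out_channels // expansion
    mid_channels = groups * width_per_group * mid_channels // base_channels
    assert mid_channels == 128  # the familiar 128-wide grouped conv of stage 1

The SE variants below receive the matching renames.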
@@ -9,15 +9,14 @@ class SEBottleneck(Bottleneck):
     """SEBottleneck block for SEResNet.
 
     Args:
-        inplanes (int): The input channels of the SEBottleneck block.
-        planes (int): The output channel base of the SEBottleneck block.
+        in_channels (int): The input channels of the SEBottleneck block.
+        out_channels (int): The output channel of the SEBottleneck block.
         se_ratio (int): Squeeze ratio in SELayer. Default: 16
     """
-    expansion = 4
 
-    def __init__(self, inplanes, planes, se_ratio=16, **kwargs):
-        super(SEBottleneck, self).__init__(inplanes, planes, **kwargs)
-        self.se_layer = SELayer(planes * self.expansion, ratio=se_ratio)
+    def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs):
+        super(SEBottleneck, self).__init__(in_channels, out_channels, **kwargs)
+        self.se_layer = SELayer(out_channels, ratio=se_ratio)
 
     def forward(self, x):
@@ -58,31 +57,41 @@ class SEResNet(ResNet):
 class SEResNet(ResNet):
     """SEResNet backbone.
 
     Please refer to the `paper <https://arxiv.org/abs/1709.01507>`_ for
     details.
 
     Args:
-        depth (int): Depth of seresnet, from {50, 101, 152}.
-        in_channels (int): Number of input image channels. Normally 3.
-        base_channels (int): Number of base channels of hidden layer.
-        num_stages (int): Resnet stages, normally 4.
+        depth (int): Network depth, from {50, 101, 152}.
+        se_ratio (int): Squeeze ratio in SELayer. Default: 16.
+        in_channels (int): Number of input image channels. Default: 3.
+        stem_channels (int): Output channels of the stem layer. Default: 64.
+        num_stages (int): Stages of the network. Default: 4.
         strides (Sequence[int]): Strides of the first block of each stage.
+            Default: ``(1, 2, 2, 2)``.
         dilations (Sequence[int]): Dilation of each stage.
-        out_indices (Sequence[int]): Output from which stages.
-        se_ratio (int): Squeeze ratio in SELayer. Default: 16
+            Default: ``(1, 1, 1, 1)``.
+        out_indices (Sequence[int]): Output from which stages. If only one
+            stage is specified, a single tensor (feature map) is returned,
+            otherwise multiple stages are specified, a tuple of tensors will
+            be returned. Default: ``(3, )``.
         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
             layer is the 3x3 conv layer, otherwise the stride-two layer is
             the first 1x1 conv layer.
-        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+            Default: False.
         avg_down (bool): Use AvgPool instead of stride conv when
-            downsampling in the bottleneck.
+            downsampling in the bottleneck. Default: False.
         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-            -1 means not freezing any parameters.
-        norm_cfg (dict): Dictionary to construct and config norm layer.
+            -1 means not freezing any parameters. Default: -1.
+        conv_cfg (dict | None): The config dict for conv layers. Default: None.
+        norm_cfg (dict): The config dict for norm layers.
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only.
+            and its variants only. Default: False.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
+            memory while slowing down the training speed. Default: False.
         zero_init_residual (bool): Whether to use zero init for last norm layer
-            in resblocks to let them behave as identity.
+            in resblocks to let them behave as identity. Default: True.
 
     Example:
         >>> from mmcls.models import SEResNet
@@ -107,7 +116,7 @@ class SEResNet(ResNet):
 
     def __init__(self, depth, se_ratio=16, **kwargs):
         if depth not in self.arch_settings:
-            raise KeyError(f'invalid depth {depth} for resnet')
+            raise KeyError(f'invalid depth {depth} for SEResNet')
         self.se_ratio = se_ratio
         super(SEResNet, self).__init__(depth, **kwargs)
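
SEResNet now takes ``depth`` explicitly and raises a model-specific error for unsupported depths. Minimal usage in the style of the docstring examples (assumes the refactored sources):

    from mmcls.models import SEResNet

    model = SEResNet(50, se_ratio=16)
    # SEResNet(18) raises KeyError: invalid depth 18 for SEResNet

SEResNeXt, refactored in the next hunks, combines this with the ResNeXt grouping changes.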
@@ -1,5 +1,3 @@
-import math
-
 from mmcv.cnn import build_conv_layer, build_norm_layer
 
 from ..builder import BACKBONES
@@ -12,14 +10,15 @@ class SEBottleneck(_SEBottleneck):
     """SEBottleneck block for SEResNeXt.
 
     Args:
-        inplanes (int): inplanes of block.
-        planes (int): planes of block.
-        groups (int): group of convolution.
-        base_width (int): Base width of resnext.
-        base_channels (int): Number of base channels of hidden layer.
+        in_channels (int): Input channels of this block.
+        out_channels (int): Output channels of this block.
+        groups (int): Groups of conv2.
+        width_per_group (int): Width per group of conv2. 64x4d indicates
+            ``groups=64, width_per_group=4`` and 32x8d indicates
+            ``groups=32, width_per_group=8``.
         stride (int): stride of the block. Default: 1
         dilation (int): dilation of convolution. Default: 1
-        dbownsample (nn.Module): downsample operation on identity branch.
+        downsample (nn.Module): downsample operation on identity branch.
             Default: None
         se_ratio (int): Squeeze ratio in SELayer. Default: 16
         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
@@ -33,35 +32,33 @@ class SEBottleneck(_SEBottleneck):
             memory while slowing down the training speed.
     """
 
-    expansion = 4
-
     def __init__(self,
-                 inplanes,
-                 planes,
-                 groups=1,
-                 base_width=4,
-                 base_channels=64,
+                 in_channels,
+                 out_channels,
+                 groups=32,
+                 width_per_group=4,
                  se_ratio=16,
                  **kwargs):
-        super(SEBottleneck, self).__init__(inplanes, planes, se_ratio,
+        super(SEBottleneck, self).__init__(in_channels, out_channels, se_ratio,
                                            **kwargs)
         self.groups = groups
+        self.width_per_group = width_per_group
 
         if groups == 1:
-            width = self.planes
+            width = self.mid_channels
         else:
-            width = math.floor(self.planes *
-                               (base_width / base_channels)) * groups
+            width = groups * width_per_group
 
         self.norm1_name, norm1 = build_norm_layer(
             self.norm_cfg, width, postfix=1)
         self.norm2_name, norm2 = build_norm_layer(
             self.norm_cfg, width, postfix=2)
         self.norm3_name, norm3 = build_norm_layer(
-            self.norm_cfg, self.planes * self.expansion, postfix=3)
+            self.norm_cfg, self.out_channels, postfix=3)
 
         self.conv1 = build_conv_layer(
             self.conv_cfg,
-            self.inplanes,
+            self.in_channels,
             width,
             kernel_size=1,
             stride=self.conv1_stride,
@@ -80,11 +77,7 @@ class SEBottleneck(_SEBottleneck):
 
         self.add_module(self.norm2_name, norm2)
         self.conv3 = build_conv_layer(
-            self.conv_cfg,
-            width,
-            self.planes * self.expansion,
-            kernel_size=1,
-            bias=False)
+            self.conv_cfg, width, self.out_channels, kernel_size=1, bias=False)
         self.add_module(self.norm3_name, norm3)
@@ -92,30 +85,44 @@ class SEResNeXt(SEResNet):
 class SEResNeXt(SEResNet):
     """SEResNeXt backbone.
 
     Please refer to the `paper <https://arxiv.org/abs/1709.01507>`_ for
     details.
 
     Args:
-        groups (int): Group of seresnext.
-        base_width (int): Base width of resnext.
-        depth (int): Depth of resnext, from {50, 101, 152}.
+        depth (int): Network depth, from {50, 101, 152}.
+        groups (int): Groups of conv2 in Bottleneck. Default: 32.
+        width_per_group (int): Width per group of conv2 in Bottleneck.
+            Default: 4.
+        se_ratio (int): Squeeze ratio in SELayer. Default: 16.
         in_channels (int): Number of input image channels. Default: 3.
-        base_channels (int): Number of base channels of hidden layer.
-        num_stages (int): Resnet stages. Default: 4.
+        stem_channels (int): Output channels of the stem layer. Default: 64.
+        num_stages (int): Stages of the network. Default: 4.
         strides (Sequence[int]): Strides of the first block of each stage.
+            Default: ``(1, 2, 2, 2)``.
         dilations (Sequence[int]): Dilation of each stage.
-        out_indices (Sequence[int]): Output from which stages.
-        se_ratio (int): Squeeze ratio in SELayer. Default: 16
+            Default: ``(1, 1, 1, 1)``.
+        out_indices (Sequence[int]): Output from which stages. If only one
+            stage is specified, a single tensor (feature map) is returned,
+            otherwise multiple stages are specified, a tuple of tensors will
+            be returned. Default: ``(3, )``.
         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
             layer is the 3x3 conv layer, otherwise the stride-two layer is
             the first 1x1 conv layer.
-        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
-            not freezing any parameters.
-        norm_cfg (dict): dictionary to construct and config norm layer.
+        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+            Default: False.
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck. Default: False.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Default: -1.
+        conv_cfg (dict | None): The config dict for conv layers. Default: None.
+        norm_cfg (dict): The config dict for norm layers.
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only.
+            and its variants only. Default: False.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-        zero_init_residual (bool): whether to use zero init for last norm layer
-            in resblocks to let them behave as identity.
+            memory while slowing down the training speed. Default: False.
+        zero_init_residual (bool): Whether to use zero init for last norm layer
+            in resblocks to let them behave as identity. Default: True.
     """
 
     arch_settings = {
@@ -124,14 +131,11 @@ class SEResNeXt(SEResNet):
         152: (SEBottleneck, (3, 8, 36, 3))
     }
 
-    def __init__(self, groups=1, base_width=4, **kwargs):
+    def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
         self.groups = groups
-        self.base_width = base_width
-        super(SEResNeXt, self).__init__(**kwargs)
+        self.width_per_group = width_per_group
+        super(SEResNeXt, self).__init__(depth, **kwargs)
 
     def make_res_layer(self, **kwargs):
         return ResLayer(
-            groups=self.groups,
-            base_width=self.base_width,
-            base_channels=self.base_channels,
-            **kwargs)
+            groups=self.groups, width_per_group=self.width_per_group, **kwargs)
@@ -5,17 +5,18 @@ class SELayer(nn.Module):
     """Squeeze-and-Excitation Module.
 
     Args:
-        inplanes (int): The input channels of the SEBottleneck block.
-        ratio (int): Squeeze ratio in SELayer. Default: 16
+        channels (int): The input (and output) channels of the SE layer.
+        ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
+            ``int(channels/ratio)``. Default: 16.
     """
 
-    def __init__(self, inplanes, ratio=16):
+    def __init__(self, channels, ratio=16):
         super(SELayer, self).__init__()
         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
         self.conv1 = nn.Conv2d(
-            inplanes, int(inplanes / ratio), kernel_size=1, stride=1)
+            channels, int(channels / ratio), kernel_size=1, stride=1)
         self.conv2 = nn.Conv2d(
-            int(inplanes / ratio), inplanes, kernel_size=1, stride=1)
+            int(channels / ratio), channels, kernel_size=1, stride=1)
         self.relu = nn.ReLU(inplace=True)
         self.sigmoid = nn.Sigmoid()
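
The renamed ``channels`` argument reflects that the SE layer's input and output widths are identical: it squeezes with a global average pool, bottlenecks to ``int(channels / ratio)``, and rescales the input through a sigmoid gate. A self-contained sketch of the same computation in plain PyTorch:

    import torch
    import torch.nn as nn

    channels, ratio = 64, 16
    x = torch.randn(2, channels, 56, 56)
    squeeze = nn.AdaptiveAvgPool2d(1)(x)                               # (2, 64, 1, 1)
    excite = nn.Conv2d(channels, channels // ratio, 1)(squeeze).relu()
    gate = nn.Conv2d(channels // ratio, channels, 1)(excite).sigmoid()
    out = x * gate                                                     # channel-wise reweighting

The remaining hunks update the isort configuration and rewrite the backbone unit tests against the new API.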
@@ -17,6 +17,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = mmcls
-known_third_party = cv2,mmcv,numpy,torch,torchvision
+known_third_party = cv2,mmcv,numpy,pytest,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
@@ -1,10 +1,12 @@
+import pytest
 import torch
 import torch.nn as nn
 from mmcv.cnn import ConvModule
 from mmcv.utils.parrots_wrapper import _BatchNorm
 from torch.nn.modules import AvgPool2d
 
 from mmcls.models.backbones import ResNet, ResNetV1d
-from mmcls.models.backbones.resnet import BasicBlock, Bottleneck, ResLayer
+from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
+                                           get_expansion)
 
 
 def is_block(modules):
@@ -14,13 +16,6 @@ def is_block(modules):
     return False
 
 
-def is_norm(modules):
-    """Check if is one of the norms."""
-    if isinstance(modules, (_BatchNorm, )):
-        return True
-    return False
-
-
 def all_zeros(modules):
     """Check if the weight(and bias) is all zero."""
     weight_zero = torch.equal(modules.weight.data,
@@ -43,12 +38,44 @@ def check_norm_state(modules, train_state):
     return True
 
 
-def test_resnet_basic_block():
-    # Test BasicBlock structure and forward
+def test_get_expansion():
+    assert get_expansion(Bottleneck, 2) == 2
+    assert get_expansion(BasicBlock) == 1
+    assert get_expansion(Bottleneck) == 4
+
+    class MyResBlock(nn.Module):
+
+        expansion = 8
+
+    assert get_expansion(MyResBlock) == 8
+
+    # expansion must be an integer or None
+    with pytest.raises(TypeError):
+        get_expansion(Bottleneck, '0')
+
+    # expansion is not specified and cannot be inferred
+    with pytest.raises(TypeError):
+
+        class SomeModule(nn.Module):
+            pass
+
+        get_expansion(SomeModule)
+
+
+def test_basic_block():
+    # expansion must be 1
+    with pytest.raises(AssertionError):
+        BasicBlock(64, 64, expansion=2)
+
+    # BasicBlock with stride 1, out_channels == in_channels
     block = BasicBlock(64, 64)
+    assert block.in_channels == 64
+    assert block.mid_channels == 64
+    assert block.out_channels == 64
     assert block.conv1.in_channels == 64
     assert block.conv1.out_channels == 64
     assert block.conv1.kernel_size == (3, 3)
+    assert block.conv1.stride == (1, 1)
     assert block.conv2.in_channels == 64
     assert block.conv2.out_channels == 64
     assert block.conv2.kernel_size == (3, 3)
@@ -56,26 +83,59 @@ def test_resnet_basic_block():
     x_out = block(x)
     assert x_out.shape == torch.Size([1, 64, 56, 56])
 
-    # Test BasicBlock with checkpoint forward
+    # BasicBlock with stride 1 and downsample
+    downsample = nn.Sequential(
+        nn.Conv2d(64, 128, kernel_size=1, bias=False), nn.BatchNorm2d(128))
+    block = BasicBlock(64, 128, downsample=downsample)
+    assert block.in_channels == 64
+    assert block.mid_channels == 128
+    assert block.out_channels == 128
+    assert block.conv1.in_channels == 64
+    assert block.conv1.out_channels == 128
+    assert block.conv1.kernel_size == (3, 3)
+    assert block.conv1.stride == (1, 1)
+    assert block.conv2.in_channels == 128
+    assert block.conv2.out_channels == 128
+    assert block.conv2.kernel_size == (3, 3)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 128, 56, 56])
+
+    # BasicBlock with stride 2 and downsample
+    downsample = nn.Sequential(
+        nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
+        nn.BatchNorm2d(128))
+    block = BasicBlock(64, 128, stride=2, downsample=downsample)
+    assert block.in_channels == 64
+    assert block.mid_channels == 128
+    assert block.out_channels == 128
+    assert block.conv1.in_channels == 64
+    assert block.conv1.out_channels == 128
+    assert block.conv1.kernel_size == (3, 3)
+    assert block.conv1.stride == (2, 2)
+    assert block.conv2.in_channels == 128
+    assert block.conv2.out_channels == 128
+    assert block.conv2.kernel_size == (3, 3)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 128, 28, 28])
+
+    # forward with checkpointing
     block = BasicBlock(64, 64, with_cp=True)
     assert block.with_cp
-    x = torch.randn(1, 64, 56, 56)
+    x = torch.randn(1, 64, 56, 56, requires_grad=True)
     x_out = block(x)
     assert x_out.shape == torch.Size([1, 64, 56, 56])
 
 
-def test_resnet_bottleneck():
+def test_bottleneck():
+    # style must be in ['pytorch', 'caffe']
     with pytest.raises(AssertionError):
-        # Style must be in ['pytorch', 'caffe']
         Bottleneck(64, 64, style='tensorflow')
 
-    # Test Bottleneck with checkpoint forward
-    block = Bottleneck(64, 16, with_cp=True)
-    assert block.with_cp
-    x = torch.randn(1, 64, 56, 56)
-    x_out = block(x)
-    assert x_out.shape == torch.Size([1, 64, 56, 56])
+    # out_channels must be divisible by expansion
+    with pytest.raises(AssertionError):
+        Bottleneck(64, 64, expansion=3)
 
     # Test Bottleneck style
     block = Bottleneck(64, 64, stride=2, style='pytorch')
@@ -85,60 +145,232 @@ def test_resnet_bottleneck():
     assert block.conv1.stride == (2, 2)
     assert block.conv2.stride == (1, 1)
 
-    # Test Bottleneck forward
-    block = Bottleneck(64, 16)
+    # Bottleneck with stride 1
+    block = Bottleneck(64, 64, style='pytorch')
+    assert block.in_channels == 64
+    assert block.mid_channels == 16
+    assert block.out_channels == 64
+    assert block.conv1.in_channels == 64
+    assert block.conv1.out_channels == 16
+    assert block.conv1.kernel_size == (1, 1)
+    assert block.conv2.in_channels == 16
+    assert block.conv2.out_channels == 16
+    assert block.conv2.kernel_size == (3, 3)
+    assert block.conv3.in_channels == 16
+    assert block.conv3.out_channels == 64
+    assert block.conv3.kernel_size == (1, 1)
     x = torch.randn(1, 64, 56, 56)
     x_out = block(x)
     assert x_out.shape == (1, 64, 56, 56)
 
+    # Bottleneck with stride 1 and downsample
+    downsample = nn.Sequential(
+        nn.Conv2d(64, 128, kernel_size=1), nn.BatchNorm2d(128))
+    block = Bottleneck(64, 128, style='pytorch', downsample=downsample)
+    assert block.in_channels == 64
+    assert block.mid_channels == 32
+    assert block.out_channels == 128
+    assert block.conv1.in_channels == 64
+    assert block.conv1.out_channels == 32
+    assert block.conv1.kernel_size == (1, 1)
+    assert block.conv2.in_channels == 32
+    assert block.conv2.out_channels == 32
+    assert block.conv2.kernel_size == (3, 3)
+    assert block.conv3.in_channels == 32
+    assert block.conv3.out_channels == 128
+    assert block.conv3.kernel_size == (1, 1)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == (1, 128, 56, 56)
+
+    # Bottleneck with stride 2 and downsample
+    downsample = nn.Sequential(
+        nn.Conv2d(64, 128, kernel_size=1, stride=2), nn.BatchNorm2d(128))
+    block = Bottleneck(
+        64, 128, stride=2, style='pytorch', downsample=downsample)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == (1, 128, 28, 28)
+
+    # Bottleneck with expansion 2
+    block = Bottleneck(64, 64, style='pytorch', expansion=2)
+    assert block.in_channels == 64
+    assert block.mid_channels == 32
+    assert block.out_channels == 64
+    assert block.conv1.in_channels == 64
+    assert block.conv1.out_channels == 32
+    assert block.conv1.kernel_size == (1, 1)
+    assert block.conv2.in_channels == 32
+    assert block.conv2.out_channels == 32
+    assert block.conv2.kernel_size == (3, 3)
+    assert block.conv3.in_channels == 32
+    assert block.conv3.out_channels == 64
+    assert block.conv3.kernel_size == (1, 1)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == (1, 64, 56, 56)
+
+    # Test Bottleneck with checkpointing
+    block = Bottleneck(64, 64, with_cp=True)
+    block.train()
+    assert block.with_cp
+    x = torch.randn(1, 64, 56, 56, requires_grad=True)
     x_out = block(x)
     assert x_out.shape == torch.Size([1, 64, 56, 56])
 
 
-def test_resnet_res_layer():
-    # Test ResLayer of 3 Bottleneck w\o downsample
-    layer = ResLayer(Bottleneck, 64, 16, 3)
+def test_basicblock_reslayer():
+    # 3 BasicBlock w/o downsample
+    layer = ResLayer(BasicBlock, 3, 32, 32)
     assert len(layer) == 3
-    assert layer[0].conv1.in_channels == 64
-    assert layer[0].conv1.out_channels == 16
-    for i in range(1, len(layer)):
-        assert layer[i].conv1.in_channels == 64
-        assert layer[i].conv1.out_channels == 16
-    for i in range(len(layer)):
+    for i in range(3):
+        assert layer[i].in_channels == 32
+        assert layer[i].out_channels == 32
         assert layer[i].downsample is None
-    x = torch.randn(1, 64, 56, 56)
+    x = torch.randn(1, 32, 56, 56)
     x_out = layer(x)
-    assert x_out.shape == torch.Size([1, 64, 56, 56])
+    assert x_out.shape == (1, 32, 56, 56)
 
-    # Test ResLayer of 3 Bottleneck with downsample
-    layer = ResLayer(Bottleneck, 64, 64, 3)
-    assert layer[0].downsample[0].out_channels == 256
-    for i in range(1, len(layer)):
+    # 3 BasicBlock w/ stride 1 and downsample
+    layer = ResLayer(BasicBlock, 3, 32, 64)
+    assert len(layer) == 3
+    assert layer[0].in_channels == 32
+    assert layer[0].out_channels == 64
+    assert layer[0].downsample is not None and len(layer[0].downsample) == 2
+    assert isinstance(layer[0].downsample[0], nn.Conv2d)
+    assert layer[0].downsample[0].stride == (1, 1)
+    for i in range(1, 3):
+        assert layer[i].in_channels == 64
+        assert layer[i].out_channels == 64
         assert layer[i].downsample is None
-    x = torch.randn(1, 64, 56, 56)
+    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
-    assert x_out.shape == torch.Size([1, 256, 56, 56])
+    assert x_out.shape == (1, 64, 56, 56)
 
-    # Test ResLayer of 3 Bottleneck with stride=2
-    layer = ResLayer(Bottleneck, 64, 64, 3, stride=2)
-    assert layer[0].downsample[0].out_channels == 256
+    # 3 BasicBlock w/ stride 2 and downsample
+    layer = ResLayer(BasicBlock, 3, 32, 64, stride=2)
+    assert len(layer) == 3
+    assert layer[0].in_channels == 32
+    assert layer[0].out_channels == 64
+    assert layer[0].stride == 2
+    assert layer[0].downsample is not None and len(layer[0].downsample) == 2
+    assert isinstance(layer[0].downsample[0], nn.Conv2d)
+    assert layer[0].downsample[0].stride == (2, 2)
-    for i in range(1, len(layer)):
+    for i in range(1, 3):
+        assert layer[i].in_channels == 64
+        assert layer[i].out_channels == 64
+        assert layer[i].stride == 1
         assert layer[i].downsample is None
-    x = torch.randn(1, 64, 56, 56)
+    x = torch.randn(1, 32, 56, 56)
     x_out = layer(x)
-    assert x_out.shape == torch.Size([1, 256, 28, 28])
+    assert x_out.shape == (1, 64, 28, 28)
 
-    # Test ResLayer of 3 Bottleneck with stride=2 and average downsample
-    layer = ResLayer(Bottleneck, 64, 64, 3, stride=2, avg_down=True)
-    assert isinstance(layer[0].downsample[0], AvgPool2d)
-    assert layer[0].downsample[1].out_channels == 256
-    assert layer[0].downsample[1].stride == (1, 1)
-    for i in range(1, len(layer)):
+    # 3 BasicBlock w/ stride 2 and downsample with avg pool
+    layer = ResLayer(BasicBlock, 3, 32, 64, stride=2, avg_down=True)
+    assert len(layer) == 3
+    assert layer[0].in_channels == 32
+    assert layer[0].out_channels == 64
+    assert layer[0].stride == 2
+    assert layer[0].downsample is not None and len(layer[0].downsample) == 3
+    assert isinstance(layer[0].downsample[0], nn.AvgPool2d)
+    assert layer[0].downsample[0].stride == 2
+    for i in range(1, 3):
+        assert layer[i].in_channels == 64
+        assert layer[i].out_channels == 64
+        assert layer[i].stride == 1
         assert layer[i].downsample is None
-    x = torch.randn(1, 64, 56, 56)
+    x = torch.randn(1, 32, 56, 56)
     x_out = layer(x)
-    assert x_out.shape == torch.Size([1, 256, 28, 28])
+    assert x_out.shape == (1, 64, 28, 28)
 
 
-def test_resnet_backbone():
+def test_bottleneck_reslayer():
+    # 3 Bottleneck w/o downsample
+    layer = ResLayer(Bottleneck, 3, 32, 32)
+    assert len(layer) == 3
+    for i in range(3):
+        assert layer[i].in_channels == 32
+        assert layer[i].out_channels == 32
+        assert layer[i].downsample is None
+    x = torch.randn(1, 32, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == (1, 32, 56, 56)
+
+    # 3 Bottleneck w/ stride 1 and downsample
+    layer = ResLayer(Bottleneck, 3, 32, 64)
+    assert len(layer) == 3
+    assert layer[0].in_channels == 32
+    assert layer[0].out_channels == 64
+    assert layer[0].stride == 1
+    assert layer[0].conv1.out_channels == 16
+    assert layer[0].downsample is not None and len(layer[0].downsample) == 2
+    assert isinstance(layer[0].downsample[0], nn.Conv2d)
+    assert layer[0].downsample[0].stride == (1, 1)
+    for i in range(1, 3):
+        assert layer[i].in_channels == 64
+        assert layer[i].out_channels == 64
+        assert layer[i].conv1.out_channels == 16
+        assert layer[i].stride == 1
+        assert layer[i].downsample is None
+    x = torch.randn(1, 32, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == (1, 64, 56, 56)
+
+    # 3 Bottleneck w/ stride 2 and downsample
+    layer = ResLayer(Bottleneck, 3, 32, 64, stride=2)
+    assert len(layer) == 3
+    assert layer[0].in_channels == 32
+    assert layer[0].out_channels == 64
+    assert layer[0].stride == 2
+    assert layer[0].conv1.out_channels == 16
+    assert layer[0].downsample is not None and len(layer[0].downsample) == 2
+    assert isinstance(layer[0].downsample[0], nn.Conv2d)
+    assert layer[0].downsample[0].stride == (2, 2)
+    for i in range(1, 3):
+        assert layer[i].in_channels == 64
+        assert layer[i].out_channels == 64
+        assert layer[i].conv1.out_channels == 16
+        assert layer[i].stride == 1
+        assert layer[i].downsample is None
+    x = torch.randn(1, 32, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == (1, 64, 28, 28)
+
+    # 3 Bottleneck w/ stride 2 and downsample with avg pool
+    layer = ResLayer(Bottleneck, 3, 32, 64, stride=2, avg_down=True)
+    assert len(layer) == 3
+    assert layer[0].in_channels == 32
+    assert layer[0].out_channels == 64
+    assert layer[0].stride == 2
+    assert layer[0].conv1.out_channels == 16
+    assert layer[0].downsample is not None and len(layer[0].downsample) == 3
+    assert isinstance(layer[0].downsample[0], nn.AvgPool2d)
+    assert layer[0].downsample[0].stride == 2
+    for i in range(1, 3):
+        assert layer[i].in_channels == 64
+        assert layer[i].out_channels == 64
+        assert layer[i].conv1.out_channels == 16
+        assert layer[i].stride == 1
+        assert layer[i].downsample is None
+    x = torch.randn(1, 32, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == (1, 64, 28, 28)
+
+    # 3 Bottleneck with custom expansion
+    layer = ResLayer(Bottleneck, 3, 32, 32, expansion=2)
+    assert len(layer) == 3
+    for i in range(3):
+        assert layer[i].in_channels == 32
+        assert layer[i].out_channels == 32
+        assert layer[i].stride == 1
+        assert layer[i].conv1.out_channels == 16
+        assert layer[i].downsample is None
+    x = torch.randn(1, 32, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == (1, 32, 56, 56)
+
+
+def test_resnet():
     """Test resnet backbone"""
     with pytest.raises(KeyError):
         # ResNet depth should be in [18, 34, 50, 101, 152]
@ -194,9 +426,113 @@ def test_resnet_backbone():
|
|||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# Test ResNet18 forward
|
||||
model = ResNet(18, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == (1, 64, 56, 56)
|
||||
assert feat[1].shape == (1, 128, 28, 28)
|
||||
assert feat[2].shape == (1, 256, 14, 14)
|
||||
assert feat[3].shape == (1, 512, 7, 7)
|
||||

    # Test ResNet50 with BatchNorm forward
    model = ResNet(50, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    assert feat[3].shape == (1, 2048, 7, 7)

    # Test ResNet50 with layers 1, 2, 3 out forward
    model = ResNet(50, out_indices=(0, 1, 2))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 3
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)

    # Test ResNet50 with layers 3 (top feature maps) out forward
    model = ResNet(50, out_indices=(3, ))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == (1, 2048, 7, 7)

    # Test ResNet50 with checkpoint forward
    model = ResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
    for m in model.modules():
        if is_block(m):
            assert m.with_cp
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    assert feat[3].shape == (1, 2048, 7, 7)
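
# Sketch of what with_cp enables in each block (illustrative): the block body
# runs under torch.utils.checkpoint, so activations are recomputed during the
# backward pass instead of stored, trading compute for memory.
import torch
import torch.utils.checkpoint as cp

def _body(x):  # stands in for a block's residual branch
    return x * 2

x = torch.randn(1, 64, 56, 56, requires_grad=True)
out = cp.checkpoint(_body, x)
out.sum().backward()  # gradients flow as if _body had run normally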

    # zero initialization of residual blocks
    model = ResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=True)
    model.init_weights()
    for m in model.modules():
        if isinstance(m, Bottleneck):
            assert all_zeros(m.norm3)
        elif isinstance(m, BasicBlock):
            assert all_zeros(m.norm2)

    # non-zero initialization of residual blocks
    model = ResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=False)
    model.init_weights()
    for m in model.modules():
        if isinstance(m, Bottleneck):
            assert not all_zeros(m.norm3)
        elif isinstance(m, BasicBlock):
            assert not all_zeros(m.norm2)
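
# Why zero_init_residual matters (illustrative sketch): zeroing the last
# norm layer's gamma makes each residual block an identity mapping at
# initialization, which tends to stabilize early training of deep ResNets.
import torch.nn as nn

norm3 = nn.BatchNorm2d(256)          # a block's final norm layer
nn.init.constant_(norm3.weight, 0)   # what zero_init_residual amounts to
nn.init.constant_(norm3.bias, 0)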


def test_resnet_v1d():
    model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    assert len(model.stem) == 3
    for i in range(3):
        assert isinstance(model.stem[i], ConvModule)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model.stem(imgs)
    assert feat.shape == (1, 64, 112, 112)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    assert feat[3].shape == (1, 2048, 7, 7)

    # Test ResNet50V1d with first stage frozen
    frozen_stages = 1
    model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
    assert len(model.stem) == 9
    assert len(model.stem) == 3
    for i in range(3):
        assert isinstance(model.stem[i], ConvModule)
    model.init_weights()
    model.train()
    check_norm_state(model.stem, False)
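
# Sketch of the V1d deep stem the assertions above exercise (illustrative
# channel plan): three 3x3 ConvModules replace the usual 7x7 stride-2 conv
# while keeping the stem's stride-2, 64-channel output.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

stem = nn.Sequential(
    ConvModule(3, 32, 3, stride=2, padding=1),
    ConvModule(32, 32, 3, stride=1, padding=1),
    ConvModule(32, 64, 3, stride=1, padding=1))
assert stem(torch.randn(1, 3, 224, 224)).shape == (1, 64, 112, 112)
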
@@ -210,99 +546,16 @@ def test_resnet_backbone():
        for param in layer.parameters():
            assert param.requires_grad is False

    # Test ResNet18 forward
    model = ResNet(18, out_indices=(0, 1, 2, 3))


def test_resnet_half_channel():
    model = ResNet(50, base_channels=32, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == torch.Size([1, 64, 56, 56])
    assert feat[1].shape == torch.Size([1, 128, 28, 28])
    assert feat[2].shape == torch.Size([1, 256, 14, 14])
    assert feat[3].shape == torch.Size([1, 512, 7, 7])

    # Test ResNet50 with BatchNorm forward
    model = ResNet(50, out_indices=(0, 1, 2, 3))
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == torch.Size([1, 256, 56, 56])
    assert feat[1].shape == torch.Size([1, 512, 28, 28])
    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
    assert feat[3].shape == torch.Size([1, 2048, 7, 7])

    # Test ResNet50 with layers 1, 2, 3 out forward
    model = ResNet(50, out_indices=(0, 1, 2))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 3
    assert feat[0].shape == torch.Size([1, 256, 56, 56])
    assert feat[1].shape == torch.Size([1, 512, 28, 28])
    assert feat[2].shape == torch.Size([1, 1024, 14, 14])

    # Test ResNet50 with layers 3 (top feature maps) out forward
    model = ResNet(50, out_indices=(3, ))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size([1, 2048, 7, 7])

    # Test ResNet50 with checkpoint forward
    model = ResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
    for m in model.modules():
        if is_block(m):
            assert m.with_cp
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == torch.Size([1, 256, 56, 56])
    assert feat[1].shape == torch.Size([1, 512, 28, 28])
    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
    assert feat[3].shape == torch.Size([1, 2048, 7, 7])

    # Test ResNet50 zero initialization of residual
    model = ResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=True)
    model.init_weights()
    for m in model.modules():
        if isinstance(m, Bottleneck):
            assert all_zeros(m.norm3)
        elif isinstance(m, BasicBlock):
            assert all_zeros(m.norm2)
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == torch.Size([1, 256, 56, 56])
    assert feat[1].shape == torch.Size([1, 512, 28, 28])
    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
    assert feat[3].shape == torch.Size([1, 2048, 7, 7])

    # Test ResNetV1d forward
    model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == torch.Size([1, 256, 56, 56])
    assert feat[1].shape == torch.Size([1, 512, 28, 28])
    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
    assert feat[0].shape == (1, 128, 56, 56)
    assert feat[1].shape == (1, 256, 28, 28)
    assert feat[2].shape == (1, 512, 14, 14)
    assert feat[3].shape == (1, 1024, 7, 7)
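
# base_channels scales every stage linearly (illustrative arithmetic): with
# base_channels=32 instead of the default 64, ResNet-50's stage widths halve
# from 256/512/1024/2048 to 128/256/512/1024, matching the asserts above.
default_widths = [256, 512, 1024, 2048]
assert [c // 2 for c in default_widths] == [128, 256, 512, 1024]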

@@ -5,42 +5,35 @@ from mmcls.models.backbones import ResNeXt
from mmcls.models.backbones.resnext import Bottleneck as BottleneckX


def is_block(modules):
    """Check if is ResNeXt building block."""
    if isinstance(modules, (BottleneckX)):
        return True
    return False


def test_resnext_bottleneck():
def test_bottleneck():
    with pytest.raises(AssertionError):
        # Style must be in ['pytorch', 'caffe']
        BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')
        BottleneckX(64, 64, groups=32, width_per_group=4, style='tensorflow')

    # Test ResNeXt Bottleneck structure
    block = BottleneckX(
        64, 64, groups=32, base_width=4, stride=2, style='pytorch')
        64, 256, groups=32, width_per_group=4, stride=2, style='pytorch')
    assert block.conv2.stride == (2, 2)
    assert block.conv2.groups == 32
    assert block.conv2.out_channels == 128
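
# Worked width arithmetic (illustrative): in this configuration the 3x3
# group conv's width comes out to groups * width_per_group = 32 * 4 = 128,
# which is what the assertion above checks.
groups, width_per_group = 32, 4
assert groups * width_per_group == 128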

    # Test ResNeXt Bottleneck forward
    block = BottleneckX(64, 16, groups=32, base_width=4)
    block = BottleneckX(64, 64, base_channels=16, groups=32, width_per_group=4)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 64, 56, 56])


def test_resnext_backbone():
def test_resnext():
    with pytest.raises(KeyError):
        # ResNeXt depth should be in [50, 101, 152]
        ResNeXt(depth=18)

    # Test ResNeXt with group 32, base_width 4
    # Test ResNeXt with group 32, width_per_group 4
    model = ResNeXt(
        depth=50, groups=32, base_width=4, out_indices=(0, 1, 2, 3))
        depth=50, groups=32, width_per_group=4, out_indices=(0, 1, 2, 3))
    for m in model.modules():
        if is_block(m):
        if isinstance(m, BottleneckX):
            assert m.conv2.groups == 32
    model.init_weights()
    model.train()

@@ -53,10 +46,10 @@ def test_resnext_backbone():
    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
    assert feat[3].shape == torch.Size([1, 2048, 7, 7])

    # Test ResNeXt with group 32, base_width 4 and layers 3 out forward
    model = ResNeXt(depth=50, groups=32, base_width=4, out_indices=(3, ))
    # Test ResNeXt with group 32, width_per_group 4 and layers 3 out forward
    model = ResNeXt(depth=50, groups=32, width_per_group=4, out_indices=(3, ))
    for m in model.modules():
        if is_block(m):
        if isinstance(m, BottleneckX):
            assert m.conv2.groups == 32
    model.init_weights()
    model.train()

@@ -8,20 +8,6 @@ from mmcls.models.backbones.resnet import ResLayer
from mmcls.models.backbones.seresnet import SEBottleneck, SELayer


def is_block(modules):
    """Check if is ResNet building block."""
    if isinstance(modules, (SEBottleneck, )):
        return True
    return False


def is_norm(modules):
    """Check if is one of the norms."""
    if isinstance(modules, (_BatchNorm, )):
        return True
    return False


def all_zeros(modules):
    """Check if the weight(and bias) is all zero."""
    weight_zero = torch.equal(modules.weight.data,

@@ -44,7 +30,7 @@ def check_norm_state(modules, train_state):
    return True


def test_serenet_selayer():
def test_selayer():
    # Test selayer forward
    layer = SELayer(64)
    x = torch.randn(1, 64, 56, 56)

@@ -58,37 +44,37 @@ def test_serenet_selayer():
    assert x_out.shape == torch.Size([1, 64, 56, 56])
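
# Sketch of the squeeze-excite pattern SELayer implements (illustrative,
# with ratio 16 as in the tests): global-average-pool to 1x1, bottleneck
# the channels, then rescale the input channel-wise with a sigmoid gate.
import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)
squeeze = nn.AdaptiveAvgPool2d(1)(x)  # (1, 64, 1, 1)
excite = nn.Sequential(
    nn.Conv2d(64, 4, 1), nn.ReLU(), nn.Conv2d(4, 64, 1), nn.Sigmoid())
out = x * excite(squeeze)             # same shape as x
assert out.shape == (1, 64, 56, 56)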


def test_seresnet_bottleneckse():
def test_bottleneck():

    with pytest.raises(AssertionError):
        # Style must be in ['pytorch', 'caffe']
        SEBottleneck(64, 64, style='tensorflow')

    # Test SEBottleneck with checkpoint forward
    block = SEBottleneck(64, 16, with_cp=True)
    block = SEBottleneck(64, 64, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 64, 56, 56])

    # Test Bottleneck style
    block = SEBottleneck(64, 64, stride=2, style='pytorch')
    block = SEBottleneck(64, 256, stride=2, style='pytorch')
    assert block.conv1.stride == (1, 1)
    assert block.conv2.stride == (2, 2)
    block = SEBottleneck(64, 64, stride=2, style='caffe')
    block = SEBottleneck(64, 256, stride=2, style='caffe')
    assert block.conv1.stride == (2, 2)
    assert block.conv2.stride == (1, 1)

    # Test Bottleneck forward
    block = SEBottleneck(64, 16)
    block = SEBottleneck(64, 64)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 64, 56, 56])


def test_seresnet_res_layer():
def test_res_layer():
    # Test ResLayer of 3 Bottleneck w/o downsample
    layer = ResLayer(SEBottleneck, 64, 16, 3, se_ratio=16)
    layer = ResLayer(SEBottleneck, 3, 64, 64, se_ratio=16)
    assert len(layer) == 3
    assert layer[0].conv1.in_channels == 64
    assert layer[0].conv1.out_channels == 16

@@ -102,7 +88,7 @@ def test_seresnet_res_layer():
    assert x_out.shape == torch.Size([1, 64, 56, 56])

    # Test ResLayer of 3 SEBottleneck with downsample
    layer = ResLayer(SEBottleneck, 64, 64, 3, se_ratio=16)
    layer = ResLayer(SEBottleneck, 3, 64, 256, se_ratio=16)
    assert layer[0].downsample[0].out_channels == 256
    for i in range(1, len(layer)):
        assert layer[i].downsample is None

@@ -111,7 +97,7 @@ def test_seresnet_res_layer():
    assert x_out.shape == torch.Size([1, 256, 56, 56])

    # Test ResLayer of 3 SEBottleneck with stride=2
    layer = ResLayer(SEBottleneck, 64, 64, 3, stride=2, se_ratio=8)
    layer = ResLayer(SEBottleneck, 3, 64, 256, stride=2, se_ratio=8)
    assert layer[0].downsample[0].out_channels == 256
    assert layer[0].downsample[0].stride == (2, 2)
    for i in range(1, len(layer)):

@@ -122,7 +108,7 @@ def test_seresnet_res_layer():

    # Test ResLayer of 3 SEBottleneck with stride=2 and average downsample
    layer = ResLayer(
        SEBottleneck, 64, 64, 3, stride=2, avg_down=True, se_ratio=8)
        SEBottleneck, 3, 64, 256, stride=2, avg_down=True, se_ratio=8)
    assert isinstance(layer[0].downsample[0], AvgPool2d)
    assert layer[0].downsample[1].out_channels == 256
    assert layer[0].downsample[1].stride == (1, 1)

@@ -133,7 +119,7 @@ def test_seresnet_res_layer():
    assert x_out.shape == torch.Size([1, 256, 28, 28])


def test_seresnet_backbone():
def test_seresnet():
    """Test resnet backbone"""
    with pytest.raises(KeyError):
        # SEResNet depth should be in [50, 101, 152]

@@ -191,9 +177,6 @@ def test_seresnet_backbone():

    # Test SEResNet50 with BatchNorm forward
    model = SEResNet(50, out_indices=(0, 1, 2, 3))
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    model.init_weights()
    model.train()

@@ -229,7 +212,7 @@ def test_seresnet_backbone():
    # Test SEResNet50 with checkpoint forward
    model = SEResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
    for m in model.modules():
        if is_block(m):
        if isinstance(m, SEBottleneck):
            assert m.with_cp
    model.init_weights()
    model.train()

@@ -5,42 +5,35 @@ from mmcls.models.backbones import SEResNeXt
from mmcls.models.backbones.seresnext import SEBottleneck as SEBottleneckX


def is_block(modules):
    """Check if is SEResNeXt building block."""
    if isinstance(modules, (SEBottleneckX)):
        return True
    return False


def test_seresnext_bottleneck():
def test_bottleneck():
    with pytest.raises(AssertionError):
        # Style must be in ['pytorch', 'caffe']
        SEBottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')
        SEBottleneckX(64, 64, groups=32, width_per_group=4, style='tensorflow')

    # Test SEResNeXt Bottleneck structure
    block = SEBottleneckX(
        64, 64, groups=32, base_width=4, stride=2, style='pytorch')
        64, 256, groups=32, width_per_group=4, stride=2, style='pytorch')
    assert block.conv2.stride == (2, 2)
    assert block.conv2.groups == 32
    assert block.conv2.out_channels == 128

    # Test SEResNeXt Bottleneck forward
    block = SEBottleneckX(64, 16, groups=32, base_width=4)
    block = SEBottleneckX(64, 64, groups=32, width_per_group=4)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 64, 56, 56])


def test_seresnext_backbone():
def test_seresnext():
    with pytest.raises(KeyError):
        # SEResNeXt depth should be in [50, 101, 152]
        SEResNeXt(depth=18)

    # Test SEResNeXt with group 32, base_width 4
    # Test SEResNeXt with group 32, width_per_group 4
    model = SEResNeXt(
        depth=50, groups=32, base_width=4, out_indices=(0, 1, 2, 3))
        depth=50, groups=32, width_per_group=4, out_indices=(0, 1, 2, 3))
    for m in model.modules():
        if is_block(m):
        if isinstance(m, SEBottleneckX):
            assert m.conv2.groups == 32
    model.init_weights()
    model.train()

@@ -53,10 +46,11 @@ def test_seresnext_backbone():
    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
    assert feat[3].shape == torch.Size([1, 2048, 7, 7])

    # Test SEResNeXt with group 32, base_width 4 and layers 3 out forward
    model = SEResNeXt(depth=50, groups=32, base_width=4, out_indices=(3, ))
    # Test SEResNeXt with group 32, width_per_group 4 and layers 3 out forward
    model = SEResNeXt(
        depth=50, groups=32, width_per_group=4, out_indices=(3, ))
    for m in model.modules():
        if is_block(m):
        if isinstance(m, SEBottleneckX):
            assert m.conv2.groups == 32
    model.init_weights()
    model.train()