Merge branch 'dev/resnext' into 'master'

Add ResNeXt

See merge request open-mmlab/mmclassification!14
pull/2/head
chenkai 2020-06-14 12:08:38 +08:00
commit 5c95cc4c5a
3 changed files with 200 additions and 1 deletions

View File

@ -1,3 +1,4 @@
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
__all__ = ['ResNet', 'ResNetV1d']
__all__ = ['ResNet', 'ResNeXt', 'ResNetV1d']

View File

@ -0,0 +1,132 @@
import math
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResLayer, ResNet
class Bottleneck(_Bottleneck):
"""Bottleneck block for ResNeXt.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
groups (int): group of convolution.
base_width (int): Base width of resnext.
base_channels (int): Number of base channels of hidden layer.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module): downsample operation on identity branch.
Default: None
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
expansion = 4
def __init__(self,
inplanes,
planes,
groups=1,
base_width=4,
base_channels=64,
**kwargs):
super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
if groups == 1:
width = self.planes
else:
width = math.floor(self.planes *
(base_width / base_channels)) * groups
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.norm_cfg, width, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
width,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
self.conv_cfg,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
width,
self.planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
@BACKBONES.register_module()
class ResNeXt(ResNet):
"""ResNeXt backbone.
Args:
groups (int): Group of resnext.
base_width (int): Base width of resnext.
depth (int): Depth of resnext, from {50, 101, 152}.
in_channels (int): Number of input image channels. Default: 3.
base_channels (int): Number of base channels of hidden layer.
num_stages (int): Resnet stages. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, groups=1, base_width=4, **kwargs):
self.groups = groups
self.base_width = base_width
super(ResNeXt, self).__init__(**kwargs)
def make_res_layer(self, **kwargs):
return ResLayer(
groups=self.groups,
base_width=self.base_width,
base_channels=self.base_channels,
**kwargs)

View File

@ -0,0 +1,66 @@
import pytest
import torch
from mmcls.models.backbones import ResNeXt
from mmcls.models.backbones.resnext import Bottleneck as BottleneckX
def is_block(modules):
"""Check if is ResNeXt building block."""
if isinstance(modules, (BottleneckX)):
return True
return False
def test_resnext_bottleneck():
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')
# Test ResNeXt Bottleneck structure
block = BottleneckX(
64, 64, groups=32, base_width=4, stride=2, style='pytorch')
assert block.conv2.stride == (2, 2)
assert block.conv2.groups == 32
assert block.conv2.out_channels == 128
# Test ResNeXt Bottleneck forward
block = BottleneckX(64, 16, groups=32, base_width=4)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_resnext_backbone():
with pytest.raises(KeyError):
# ResNeXt depth should be in [50, 101, 152]
ResNeXt(depth=18)
# Test ResNeXt with group 32, base_width 4
model = ResNeXt(
depth=50, groups=32, base_width=4, out_indices=(0, 1, 2, 3))
for m in model.modules():
if is_block(m):
assert m.conv2.groups == 32
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
# Test ResNeXt with group 32, base_width 4 and layers 3 out forward
model = ResNeXt(depth=50, groups=32, base_width=4, out_indices=(3, ))
for m in model.modules():
if is_block(m):
assert m.conv2.groups == 32
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == torch.Size([1, 2048, 7, 7])