From 76f91510b3fa5c4b6818e05ecd31ab5c542428bc Mon Sep 17 00:00:00 2001
From: yanglei
Date: Sun, 14 Jun 2020 12:08:37 +0800
Subject: [PATCH] Add ResNeXt

---
 mmcls/models/backbones/__init__.py   |   3 +-
 mmcls/models/backbones/resnext.py    | 132 +++++++++++++++++++++++++++
 tests/test_backbones/test_resnext.py |  66 ++++++++++++++
 3 files changed, 200 insertions(+), 1 deletion(-)
 create mode 100644 mmcls/models/backbones/resnext.py
 create mode 100644 tests/test_backbones/test_resnext.py

diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index 5a34fb2a..25c35547 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -1,3 +1,4 @@
 from .resnet import ResNet, ResNetV1d
+from .resnext import ResNeXt
 
-__all__ = ['ResNet', 'ResNetV1d']
+__all__ = ['ResNet', 'ResNeXt', 'ResNetV1d']
diff --git a/mmcls/models/backbones/resnext.py b/mmcls/models/backbones/resnext.py
new file mode 100644
index 00000000..521e4be8
--- /dev/null
+++ b/mmcls/models/backbones/resnext.py
@@ -0,0 +1,132 @@
+import math
+
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResLayer, ResNet
+
+
+class Bottleneck(_Bottleneck):
+    """Bottleneck block for ResNeXt (1x1 conv, grouped 3x3 conv, 1x1 conv).
+
+    Args:
+        inplanes (int): Number of input channels of the block.
+        planes (int): Base planes; conv3 outputs ``planes * expansion``.
+        groups (int): Number of groups of the 3x3 conv (conv2). Default: 1
+        base_width (int): Base width of each group. Default: 4
+        base_channels (int): Base channels of the hidden layer. Default: 64
+        stride (int): Stride of the block. Default: 1
+        dilation (int): Dilation of convolution. Default: 1
+        downsample (nn.Module): Downsample operation on identity branch.
+            Default: None
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        conv_cfg (dict): Dictionary to construct and config conv layer.
+            Default: None
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: dict(type='BN')
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+    """
+
+    expansion = 4
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 groups=1,
+                 base_width=4,
+                 base_channels=64,
+                 **kwargs):
+        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+        if groups == 1:  # no grouping: the hidden width is just planes
+            width = self.planes
+        else:  # floored per-group width, multiplied by the group count
+            width = math.floor(self.planes *
+                               (base_width / base_channels)) * groups
+
+        self.norm1_name, norm1 = build_norm_layer(
+            self.norm_cfg, width, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(
+            self.norm_cfg, width, postfix=2)
+        self.norm3_name, norm3 = build_norm_layer(
+            self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
+            self.inplanes,
+            width,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
+        self.add_module(self.norm1_name, norm1)
+        self.conv2 = build_conv_layer(
+            self.conv_cfg,
+            width,
+            width,
+            kernel_size=3,
+            stride=self.conv2_stride,
+            padding=self.dilation,
+            dilation=self.dilation,
+            groups=groups,
+            bias=False)
+
+        self.add_module(self.norm2_name, norm2)
+        self.conv3 = build_conv_layer(
+            self.conv_cfg,
+            width,
+            self.planes * self.expansion,
+            kernel_size=1,
+            bias=False)
+        self.add_module(self.norm3_name, norm3)
+
+
+@BACKBONES.register_module()
+class ResNeXt(ResNet):
+    """ResNeXt backbone: a ResNet built from grouped Bottleneck blocks.
+
+    Args:
+        groups (int): Groups of the 3x3 conv in every block. Default: 1
+        base_width (int): Base width of each group. Default: 4
+        depth (int): Depth of resnext, from {50, 101, 152}.
+        in_channels (int): Number of input image channels. Default: 3.
+        base_channels (int): Number of base channels of hidden layer.
+        num_stages (int): Resnet stages. Default: 4.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): Whether to use zero init for last norm layer
+            in resblocks to let them behave as identity.
+    """
+
+    arch_settings = {
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self, groups=1, base_width=4, **kwargs):
+        self.groups = groups
+        self.base_width = base_width
+        super(ResNeXt, self).__init__(**kwargs)
+
+    def make_res_layer(self, **kwargs):
+        return ResLayer(
+            groups=self.groups,
+            base_width=self.base_width,
+            base_channels=self.base_channels,
+            **kwargs)
diff --git a/tests/test_backbones/test_resnext.py b/tests/test_backbones/test_resnext.py
new file mode 100644
index 00000000..65f09bce
--- /dev/null
+++ b/tests/test_backbones/test_resnext.py
@@ -0,0 +1,66 @@
+import pytest
+import torch
+
+from mmcls.models.backbones import ResNeXt
+from mmcls.models.backbones.resnext import Bottleneck as BottleneckX
+
+
+def is_block(modules):
+    """Check whether a module is a ResNeXt Bottleneck block."""
+    if isinstance(modules, (BottleneckX)):
+        return True
+    return False
+
+
+def test_resnext_bottleneck():
+    with pytest.raises(AssertionError):
+        # Style must be in ['pytorch', 'caffe']
+        BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')
+
+    # Test ResNeXt Bottleneck structure
+    block = BottleneckX(
+        64, 64, groups=32, base_width=4, stride=2, style='pytorch')
+    assert block.conv2.stride == (2, 2)
+    assert block.conv2.groups == 32
+    assert block.conv2.out_channels == 128  # floor(64 * 4 / 64) * 32
+
+    # Test ResNeXt Bottleneck forward
+    block = BottleneckX(64, 16, groups=32, base_width=4)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])  # 16 * expansion
+
+
+def test_resnext_backbone():
+    with pytest.raises(KeyError):
+        # ResNeXt depth should be in [50, 101, 152]
+        ResNeXt(depth=18)
+
+    # Test ResNeXt with groups 32, base_width 4, outputs of all four stages
+    model = ResNeXt(
+        depth=50, groups=32, base_width=4, out_indices=(0, 1, 2, 3))
+    for m in model.modules():
+        if is_block(m):
+            assert m.conv2.groups == 32
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    # Test ResNeXt with groups 32, base_width 4, output of stage index 3 only
+    model = ResNeXt(depth=50, groups=32, base_width=4, out_indices=(3, ))
+    for m in model.modules():
+        if is_block(m):
+            assert m.conv2.groups == 32
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat.shape == torch.Size([1, 2048, 7, 7])