From 9a661ef981a7fff20ffd738331874c1d2b0b6de3 Mon Sep 17 00:00:00 2001
From: yanglei
Date: Thu, 9 Jul 2020 14:19:15 +0800
Subject: [PATCH] Add ResNet_CIFAR

---
 mmcls/models/backbones/__init__.py        |  5 +-
 mmcls/models/backbones/resnet_cifar.py    | 83 +++++++++++++++++++++++
 tests/test_backbones/test_resnet_cifar.py | 66 ++++++++++++++++++
 3 files changed, 152 insertions(+), 2 deletions(-)
 create mode 100644 mmcls/models/backbones/resnet_cifar.py
 create mode 100644 tests/test_backbones/test_resnet_cifar.py

diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index 7ef1c114..40462207 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -4,6 +4,7 @@ from .mobilenet_v2 import MobileNetV2
 from .mobilenet_v3 import MobileNetv3
 from .regnet import RegNet
 from .resnet import ResNet, ResNetV1d
+from .resnet_cifar import ResNet_CIFAR
 from .resnext import ResNeXt
 from .seresnet import SEResNet
 from .seresnext import SEResNeXt
@@ -12,6 +13,6 @@ from .shufflenet_v2 import ShuffleNetV2

 __all__ = [
     'LeNet5', 'AlexNet', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
-    'ResNetV1d', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2',
-    'MobileNetV2', 'MobileNetv3'
+    'ResNetV1d', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
+    'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
 ]
diff --git a/mmcls/models/backbones/resnet_cifar.py b/mmcls/models/backbones/resnet_cifar.py
new file mode 100644
index 00000000..d0759940
--- /dev/null
+++ b/mmcls/models/backbones/resnet_cifar.py
@@ -0,0 +1,83 @@
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer

from ..builder import BACKBONES
from .resnet import ResNet


@BACKBONES.register_module()
class ResNet_CIFAR(ResNet):
    """ResNet backbone for CIFAR.

    Compared to standard ResNet, it uses `kernel_size=3` and `stride=1` in
    conv1 and does not apply max pooling after the stem, so 32x32 inputs are
    not downsampled before the first residual stage. This stem has been shown
    to work better for CIFAR-scale images in other public codebases, e.g.,
    `https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`.

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        base_channels (int): Middle channels of the first stage. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned;
            if multiple stages are specified, a tuple of tensors is returned.
            Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): This backbone uses its own stem design, so this
            argument must be False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to zero-init the last norm layer in
            residual blocks so that they initially behave as identity
            mappings. Default: True.
    """

    def __init__(self, depth, deep_stem=False, **kwargs):
        super(ResNet_CIFAR, self).__init__(
            depth, deep_stem=deep_stem, **kwargs)
        assert not self.deep_stem, 'ResNet_CIFAR does not support deep_stem'

    def _make_stem_layer(self, in_channels, base_channels):
        # CIFAR stem: a single 3x3, stride-1 conv and no max pooling, so the
        # input resolution is preserved through the stem.
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            base_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, base_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        # Return a single tensor if only one stage is requested, else a tuple.
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)
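For reference, a minimal usage sketch of the new backbone (assuming this patch and torch are installed; illustrative only, mirroring what the unit tests below verify): because the stem keeps the 32x32 resolution and only the last three stages downsample, a CIFAR image leaves the final stage at 4x4.

import torch

from mmcls.models.backbones import ResNet_CIFAR

# Default out_indices=(3, ) returns only the last stage as a single tensor.
model = ResNet_CIFAR(depth=18)
model.init_weights()
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 32, 32))

# Stem keeps 32x32; stages 2-4 each halve the resolution: 32 -> 16 -> 8 -> 4.
print(out.shape)  # torch.Size([1, 512, 4, 4]) for BasicBlock-based ResNet-18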
diff --git a/tests/test_backbones/test_resnet_cifar.py b/tests/test_backbones/test_resnet_cifar.py
new file mode 100644
index 00000000..533c2e05
--- /dev/null
+++ b/tests/test_backbones/test_resnet_cifar.py
@@ -0,0 +1,66 @@
import pytest
import torch
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmcls.models.backbones import ResNet_CIFAR


def check_norm_state(modules, train_state):
    """Check if norm layers are in the expected train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


def test_resnet_cifar():
    # deep_stem must be False
    with pytest.raises(AssertionError):
        ResNet_CIFAR(depth=18, deep_stem=True)

    # test the feature map sizes when depth is 18 (BasicBlock, expansion 1)
    model = ResNet_CIFAR(depth=18, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 32, 32)
    feat = model.conv1(imgs)
    assert feat.shape == (1, 64, 32, 32)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 64, 32, 32)
    assert feat[1].shape == (1, 128, 16, 16)
    assert feat[2].shape == (1, 256, 8, 8)
    assert feat[3].shape == (1, 512, 4, 4)

    # test the feature map sizes when depth is 50 (Bottleneck, expansion 4)
    model = ResNet_CIFAR(depth=50, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 32, 32)
    feat = model.conv1(imgs)
    assert feat.shape == (1, 64, 32, 32)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 32, 32)
    assert feat[1].shape == (1, 512, 16, 16)
    assert feat[2].shape == (1, 1024, 8, 8)
    assert feat[3].shape == (1, 2048, 4, 4)

    # test ResNet_CIFAR with the first stage frozen
    frozen_stages = 1
    model = ResNet_CIFAR(depth=50, frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    assert check_norm_state([model.norm1], False)
    for param in model.conv1.parameters():
        assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(model, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False
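To train with the new backbone end to end, a classification config along the following lines could be added. This is a sketch based on mmcls config conventions and is not part of this patch; the neck/head settings (GlobalAveragePooling, LinearClsHead with 10 CIFAR-10 classes and in_channels=512 for ResNet-18) are illustrative assumptions.

# Hypothetical configs/_base_/models/resnet18_cifar.py -- not part of this patch.
# The neck/head choices below are assumptions following mmcls conventions.
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet_CIFAR',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=10,   # CIFAR-10
        in_channels=512,  # last-stage channels of ResNet-18
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))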