diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index 6b999624b..ecf0f328a 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -1,10 +1,11 @@
 from .mobilenet_v2 import MobileNetV2
 from .resnet import ResNet, ResNetV1d
 from .resnext import ResNeXt
+from .seresnet import SEResNet
 from .shufflenet_v1 import ShuffleNetV1
 from .shufflenet_v2 import ShuffleNetV2

 __all__ = [
-    'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNetV1d', 'ShuffleNetV1',
+    'ResNet', 'ResNeXt', 'ResNetV1d', 'SEResNet', 'ShuffleNetV1',
     'ShuffleNetV2', 'MobileNetV2'
 ]
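Usage note (outside the diff): a minimal sketch of how the newly registered backbone could be built through the registry, assuming the build_backbone helper exported by mmcls.models; the config values shown are illustrative.

    from mmcls.models import build_backbone

    # 'SEResNet' is the name registered by the decorator in seresnet.py below.
    backbone = build_backbone(dict(type='SEResNet', depth=50, se_ratio=16))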
diff --git a/mmcls/models/backbones/seresnet.py b/mmcls/models/backbones/seresnet.py
new file mode 100644
index 000000000..9420b9efe
--- /dev/null
+++ b/mmcls/models/backbones/seresnet.py
@@ -0,0 +1,115 @@
+import torch.utils.checkpoint as cp
+
+from ..builder import BACKBONES
+from ..utils.se_layer import SELayer
+from .resnet import Bottleneck, ResLayer, ResNet
+
+
+class SEBottleneck(Bottleneck):
+    """SEBottleneck block for SEResNet.
+
+    Args:
+        inplanes (int): The input channels of the SEBottleneck block.
+        planes (int): The output channel base of the SEBottleneck block.
+        se_ratio (int): Squeeze ratio in SELayer. Default: 16
+    """
+    expansion = 4
+
+    def __init__(self, inplanes, planes, se_ratio=16, **kwargs):
+        super(SEBottleneck, self).__init__(inplanes, planes, **kwargs)
+        self.se_layer = SELayer(planes * self.expansion, ratio=se_ratio)
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            identity = x
+
+            out = self.conv1(x)
+            out = self.norm1(out)
+            out = self.relu(out)
+
+            out = self.conv2(out)
+            out = self.norm2(out)
+            out = self.relu(out)
+
+            out = self.conv3(out)
+            out = self.norm3(out)
+
+            out = self.se_layer(out)  # recalibrate channel responses
+
+            if self.downsample is not None:
+                identity = self.downsample(x)
+
+            out += identity
+
+            return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+
+
+@BACKBONES.register_module()
+class SEResNet(ResNet):
+    """SEResNet backbone.
+
+    Args:
+        depth (int): Depth of SEResNet, from {50, 101, 152}.
+        in_channels (int): Number of input image channels. Normally 3.
+        base_channels (int): Number of base channels of hidden layer.
+        num_stages (int): Stages of the network, normally 4.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        se_ratio (int): Squeeze ratio in SELayer. Default: 16
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        deep_stem (bool): Replace 7x7 conv in input stem with three 3x3 convs.
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): Whether to use zero init for last norm layer
+            in resblocks to let them behave as identity.
+
+    Example:
+        >>> from mmcls.models import SEResNet
+        >>> import torch
+        >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3))
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 56, 56)
+        (1, 512, 28, 28)
+        (1, 1024, 14, 14)
+        (1, 2048, 7, 7)
+    """
+
+    arch_settings = {
+        50: (SEBottleneck, (3, 4, 6, 3)),
+        101: (SEBottleneck, (3, 4, 23, 3)),
+        152: (SEBottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self, depth, se_ratio=16, **kwargs):
+        if depth not in self.arch_settings:
+            raise KeyError(f'invalid depth {depth} for SEResNet')
+        self.se_ratio = se_ratio
+        super(SEResNet, self).__init__(depth, **kwargs)
+
+    def make_res_layer(self, **kwargs):
+        return ResLayer(se_ratio=self.se_ratio, **kwargs)
diff --git a/mmcls/models/utils/se_layer.py b/mmcls/models/utils/se_layer.py
new file mode 100644
index 000000000..93028c57c
--- /dev/null
+++ b/mmcls/models/utils/se_layer.py
@@ -0,0 +1,30 @@
+import torch.nn as nn
+
+
+class SELayer(nn.Module):
+    """Squeeze-and-Excitation Module.
+
+    Args:
+        inplanes (int): The input channels of the SELayer.
+        ratio (int): Squeeze ratio in SELayer. Default: 16
+    """
+
+    def __init__(self, inplanes, ratio=16):
+        super(SELayer, self).__init__()
+        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+        self.conv1 = nn.Conv2d(
+            inplanes, int(inplanes / ratio), kernel_size=1, stride=1)
+        self.conv2 = nn.Conv2d(
+            int(inplanes / ratio), inplanes, kernel_size=1, stride=1)
+        self.relu = nn.ReLU(inplace=True)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        out = self.global_avgpool(x)  # squeeze: (N, C, H, W) -> (N, C, 1, 1)
+
+        out = self.conv1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.sigmoid(out)
+        return x * out  # excitation: rescale each input channel
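Usage note (outside the diff): a quick sketch of what SELayer computes — global average pooling squeezes each channel to a single value, the two 1x1 convolutions with a reduction ratio produce per-channel gates in (0, 1) via the sigmoid, and the input is rescaled channel-wise, so the output shape always matches the input.

    import torch
    from mmcls.models.utils.se_layer import SELayer

    se = SELayer(64, ratio=16)      # gating branch: 64 -> 4 -> 64 channels
    x = torch.randn(2, 64, 56, 56)
    y = se(x)                       # channel-wise rescaling of x
    assert y.shape == x.shape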
diff --git a/tests/test_backbones/test_seresnet.py b/tests/test_backbones/test_seresnet.py
new file mode 100644
index 000000000..f7039f97a
--- /dev/null
+++ b/tests/test_backbones/test_seresnet.py
@@ -0,0 +1,259 @@
+import pytest
+import torch
+from torch.nn.modules import AvgPool2d
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmcls.models.backbones import SEResNet
+from mmcls.models.backbones.resnet import ResLayer
+from mmcls.models.backbones.seresnet import SEBottleneck, SELayer
+
+
+def is_block(modules):
+    """Check if is ResNet building block."""
+    if isinstance(modules, (SEBottleneck, )):
+        return True
+    return False
+
+
+def is_norm(modules):
+    """Check if is one of the norms."""
+    if isinstance(modules, (_BatchNorm, )):
+        return True
+    return False
+
+
+def all_zeros(modules):
+    """Check if the weight (and bias) is all zero."""
+    weight_zero = torch.equal(modules.weight.data,
+                              torch.zeros_like(modules.weight.data))
+    if hasattr(modules, 'bias'):
+        bias_zero = torch.equal(modules.bias.data,
+                                torch.zeros_like(modules.bias.data))
+    else:
+        bias_zero = True
+
+    return weight_zero and bias_zero
+
+
+def check_norm_state(modules, train_state):
+    """Check if norm layer is in correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
+def test_seresnet_selayer():
+    # Test SELayer forward
+    layer = SELayer(64)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+    # Test SELayer forward with different ratio
+    layer = SELayer(64, ratio=8)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+
+def test_seresnet_sebottleneck():
+
+    with pytest.raises(AssertionError):
+        # Style must be in ['pytorch', 'caffe']
+        SEBottleneck(64, 64, style='tensorflow')
+
+    # Test SEBottleneck with checkpoint forward
+    block = SEBottleneck(64, 16, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+    # Test SEBottleneck style
+    block = SEBottleneck(64, 64, stride=2, style='pytorch')
+    assert block.conv1.stride == (1, 1)
+    assert block.conv2.stride == (2, 2)
+    block = SEBottleneck(64, 64, stride=2, style='caffe')
+    assert block.conv1.stride == (2, 2)
+    assert block.conv2.stride == (1, 1)
+
+    # Test SEBottleneck forward
+    block = SEBottleneck(64, 16)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+
+def test_seresnet_res_layer():
+    # Test ResLayer of 3 SEBottleneck w/o downsample
+    layer = ResLayer(SEBottleneck, 64, 16, 3, se_ratio=16)
+    assert len(layer) == 3
+    assert layer[0].conv1.in_channels == 64
+    assert layer[0].conv1.out_channels == 16
+    for i in range(1, len(layer)):
+        assert layer[i].conv1.in_channels == 64
+        assert layer[i].conv1.out_channels == 16
+    for i in range(len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+    # Test ResLayer of 3 SEBottleneck with downsample
+    layer = ResLayer(SEBottleneck, 64, 64, 3, se_ratio=16)
+    assert layer[0].downsample[0].out_channels == 256
+    for i in range(1, len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 256, 56, 56])
+
+    # Test ResLayer of 3 SEBottleneck with stride=2
+    layer = ResLayer(SEBottleneck, 64, 64, 3, stride=2, se_ratio=8)
+    assert layer[0].downsample[0].out_channels == 256
+    assert layer[0].downsample[0].stride == (2, 2)
+    for i in range(1, len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 256, 28, 28])
+
+    # Test ResLayer of 3 SEBottleneck with stride=2 and average downsample
+    layer = ResLayer(
+        SEBottleneck, 64, 64, 3, stride=2, avg_down=True, se_ratio=8)
+    assert isinstance(layer[0].downsample[0], AvgPool2d)
+    assert layer[0].downsample[1].out_channels == 256
+    assert layer[0].downsample[1].stride == (1, 1)
+    for i in range(1, len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 256, 28, 28])
+
+
+def test_seresnet_backbone():
+    """Test SEResNet backbone."""
+    with pytest.raises(KeyError):
+        # SEResNet depth should be in [50, 101, 152]
+        SEResNet(20)
+
+    with pytest.raises(AssertionError):
+        # In SEResNet: 1 <= num_stages <= 4
+        SEResNet(50, num_stages=0)
+
+    with pytest.raises(AssertionError):
+        # In SEResNet: 1 <= num_stages <= 4
+        SEResNet(50, num_stages=5)
+
+    with pytest.raises(AssertionError):
+        # len(strides) == len(dilations) == num_stages
+        SEResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)
+
+    with pytest.raises(TypeError):
+        # pretrained must be a string path
+        model = SEResNet(50)
+        model.init_weights(pretrained=0)
+
+    with pytest.raises(AssertionError):
+        # Style must be in ['pytorch', 'caffe']
+        SEResNet(50, style='tensorflow')
+
+    # Test SEResNet50 norm_eval=True
+    model = SEResNet(50, norm_eval=True)
+    model.init_weights()
+    model.train()
+    assert check_norm_state(model.modules(), False)
+
+    # Test SEResNet50 with torchvision pretrained weight
+    model = SEResNet(depth=50, norm_eval=True)
+    model.init_weights('torchvision://resnet50')
+    model.train()
+    assert check_norm_state(model.modules(), False)
+
+    # Test SEResNet50 with first stage frozen
+    frozen_stages = 1
+    model = SEResNet(50, frozen_stages=frozen_stages)
+    model.init_weights()
+    model.train()
+    assert model.norm1.training is False
+    for layer in [model.conv1, model.norm1]:
+        for param in layer.parameters():
+            assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(model, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+    # Test SEResNet50 with BatchNorm forward
+    model = SEResNet(50, out_indices=(0, 1, 2, 3))
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    # Test SEResNet50 with layers 1, 2, 3 out forward
+    model = SEResNet(50, out_indices=(0, 1, 2))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 3
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+
+    # Test SEResNet50 with layer 3 (top feature map) out forward
+    model = SEResNet(50, out_indices=(3, ))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat.shape == torch.Size([1, 2048, 7, 7])
+
+    # Test SEResNet50 with checkpoint forward
+    model = SEResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
+    for m in model.modules():
+        if is_block(m):
+            assert m.with_cp
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    # Test SEResNet50 zero initialization of residual
+    model = SEResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=True)
+    model.init_weights()
+    for m in model.modules():
+        if isinstance(m, SEBottleneck):
+            assert all_zeros(m.norm3)
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
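Usage note (outside the diff): the with_cp path exercised above wraps each SEBottleneck forward in torch.utils.checkpoint, trading extra forward computation for lower activation memory during training; a minimal sketch of enabling it.

    import torch
    from mmcls.models.backbones import SEResNet

    model = SEResNet(depth=50, with_cp=True, out_indices=(3, ))
    model.train()
    imgs = torch.randn(2, 3, 224, 224)
    feat = model(imgs)        # a single tensor, since out_indices has one entry
    feat.mean().backward()    # gradients flow through the checkpointed blocks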