From 2aaff0e4f2725e94659239ba5535267965b405dd Mon Sep 17 00:00:00 2001
From: lixiaojie
Date: Wed, 3 Jun 2020 15:51:17 +0800
Subject: [PATCH] fix & add test

---
 mmcls/models/backbones/__init__.py      |  5 ++
 mmcls/models/backbones/base_backbone.py | 27 ++++++++++
 mmcls/models/backbones/mobilenet_v2.py  | 69 ++++++++++++-------------
 mmcls/models/backbones/weight_init.py   | 66 +++++++++++++++++++++++
 tests/test_backbone.py                  | 25 +++++++++
 5 files changed, 157 insertions(+), 35 deletions(-)
 create mode 100644 mmcls/models/backbones/base_backbone.py
 create mode 100644 mmcls/models/backbones/weight_init.py
 create mode 100644 tests/test_backbone.py

diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index e69de29bb..f66558e04 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -0,0 +1,5 @@
+from .mobilenet_v2 import MobileNetv2
+
+__all__ = [
+    'MobileNetv2',
+]
\ No newline at end of file
diff --git a/mmcls/models/backbones/base_backbone.py b/mmcls/models/backbones/base_backbone.py
new file mode 100644
index 000000000..703e0f59f
--- /dev/null
+++ b/mmcls/models/backbones/base_backbone.py
@@ -0,0 +1,27 @@
+import logging
+import torch.nn as nn
+
+from abc import ABCMeta, abstractmethod
+from mmcv.runner import load_checkpoint
+
+
+class BaseBackbone(nn.Module, metaclass=ABCMeta):
+
+    def __init__(self):
+        super(BaseBackbone, self).__init__()
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            pass
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    @abstractmethod
+    def forward(self, x):
+        pass
+
+    def train(self, mode=True):
+        super(BaseBackbone, self).train(mode)
diff --git a/mmcls/models/backbones/mobilenet_v2.py b/mmcls/models/backbones/mobilenet_v2.py
index 82e100ae0..c2c45d0cc 100644
--- a/mmcls/models/backbones/mobilenet_v2.py
+++ b/mmcls/models/backbones/mobilenet_v2.py
@@ -3,7 +3,7 @@ import logging
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 
-from ..runner import load_checkpoint
+# from ..runner import load_checkpoint
 from .base_backbone import BaseBackbone
 from .weight_init import constant_init, kaiming_init
 
@@ -20,11 +20,11 @@ def conv3x3(in_planes, out_planes, stride=1, dilation=1):
         bias=False)
 
 
-def conv_1x1_bn(inp, oup, act=nn.ReLU6):
+def conv_1x1_bn(inp, oup, activation=nn.ReLU6):
     return nn.Sequential(
         nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
         nn.BatchNorm2d(oup),
-        act(inplace=True)
+        activation(inplace=True)
     )
 
 
@@ -38,11 +38,6 @@ class ConvBNReLU(nn.Sequential):
                  activation=nn.ReLU6):
         padding = (kernel_size - 1) // 2
 
-        try:
-            self.activation = activation(inplace=True)
-        except RuntimeWarning('inplace is not allowed to use'):
-            self.activation = activation()
-
         super(ConvBNReLU, self).__init__(
             nn.Conv2d(in_planes,
                       out_planes,
@@ -52,7 +47,7 @@ class ConvBNReLU(nn.Sequential):
                       groups=groups,
                       bias=False),
             nn.BatchNorm2d(out_planes),
-            self.activation
+            activation(inplace=True)
         )
 
 
@@ -122,20 +117,21 @@ def make_inverted_res_layer(block,
                             num_blocks,
                             stride=1,
                             expand_ratio=6,
-                            activation_type=nn.ReLU6,
+                            activation=nn.ReLU6,
                             with_cp=False):
     layers = []
     for i in range(num_blocks):
         if i == 0:
             layers.append(block(inplanes, planes, stride,
                                 expand_ratio=expand_ratio,
-                                activation=activation_type,
+                                activation=activation,
                                 with_cp=with_cp))
         else:
             layers.append(block(inplanes, planes, 1,
                                 expand_ratio=expand_ratio,
-                                activation=activation_type,
+                                activation=activation,
                                 with_cp=with_cp))
+        inplanes = planes
     return nn.Sequential(*layers)
 
 
@@ -165,23 +161,20 @@ class MobileNetv2(BaseBackbone):
                  with_cp=False):
         super(MobileNetv2, self).__init__()
         block = InvertedResidual
-
-        inverted_residual_setting = {
-            # lager_index: [expand_ratio, out_channel, n, stide]
-            0: [1, 16, 1, 1],
-            1: [6, 24, 2, 2],
-            2: [6, 32, 3, 2],
-            3: [6, 64, 4, 2],
-            4: [6, 96, 3, 1],
-            5: [6, 160, 3, 2],
-            6: [6, 320, 1, 1]
-        }
+        # expand_ratio, out_channel, n, stride
+        inverted_residual_setting = [
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+            [6, 96, 3, 1],
+            [6, 160, 3, 2],
+            [6, 320, 1, 1]
+        ]
         self.widen_factor = widen_factor
-        self.activation_type = activation
-        try:
-            self.activation = activation(inplace=True)
-        except RuntimeWarning('inplace is not allowed to use'):
-            self.activation = activation()
+        if isinstance(activation, str):
+            activation = eval(activation)
+        self.activation = activation(inplace=True)
 
         self.out_indices = out_indices
         self.frozen_stages = frozen_stages
@@ -191,11 +184,13 @@ class MobileNetv2(BaseBackbone):
 
         self.inplanes = 32
         self.inplanes = _make_divisible(self.inplanes * widen_factor, 8)
-        self.conv1 = conv3x3(3, self.inplanes, stride=2)
+        self.conv1 = conv3x3(3, self.inplanes, stride=2)
+        self.bn1 = nn.BatchNorm2d(self.inplanes)
 
         self.inverted_res_layers = []
-        for i, later_cfg in enumerate(inverted_residual_setting):
-            t, c, n, s = later_cfg
+
+        for i, layer_cfg in enumerate(inverted_residual_setting):
+            t, c, n, s = layer_cfg
             planes = _make_divisible(c * widen_factor, 8)
             inverted_res_layer = make_inverted_res_layer(
                 block,
@@ -204,7 +199,7 @@ class MobileNetv2(BaseBackbone):
                 num_blocks=n,
                 stride=s,
                 expand_ratio=t,
-                activation_type=self.activation_type,
+                activation=activation,
                 with_cp=self.with_cp)
             self.inplanes = planes
             layer_name = 'layer{}'.format(i + 1)
@@ -214,7 +209,9 @@ class MobileNetv2(BaseBackbone):
         self.out_channel = 1280
         self.out_channel = int(self.out_channel * widen_factor) \
             if widen_factor > 1.0 else self.out_channel
-        self.conv1_bn = conv_1x1_bn(self.inplanes, self.out_channel)
+
+        self.conv_last = nn.Conv2d(self.inplanes, self.out_channel, 1, 1, 0, bias=False)
+        self.bn_last = nn.BatchNorm2d(self.out_channel)
 
         self.feat_dim = self.out_channel
@@ -233,7 +230,6 @@ class MobileNetv2(BaseBackbone):
 
     def forward(self, x):
         x = self.conv1(x)
-        x = self.bn1(x)
         x = self.activation(x)
 
         outs = []
@@ -243,7 +239,10 @@ class MobileNetv2(BaseBackbone):
             if i in self.out_indices:
                 outs.append(x)
 
-        x = self.conv1_bn(x)
+        x = self.conv_last(x)
+        x = self.bn_last(x)
+        x = self.activation(x)
+
         outs.append(x)
 
         if len(outs) == 1:
diff --git a/mmcls/models/backbones/weight_init.py b/mmcls/models/backbones/weight_init.py
new file mode 100644
index 000000000..e06e6ccaa
--- /dev/null
+++ b/mmcls/models/backbones/weight_init.py
@@ -0,0 +1,66 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import numpy as np
+import torch.nn as nn
+
+
+def constant_init(module, val, bias=0):
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.constant_(module.weight, val)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module, gain=1, bias=0, distribution='normal'):
+    assert distribution in ['uniform', 'normal']
+    if distribution == 'uniform':
+        nn.init.xavier_uniform_(module.weight, gain=gain)
+    else:
+        nn.init.xavier_normal_(module.weight, gain=gain)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):
+    nn.init.normal_(module.weight, mean, std)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def uniform_init(module, a=0, b=1, bias=0):
+    nn.init.uniform_(module.weight, a, b)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(module,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 bias=0,
+                 distribution='normal'):
+    assert distribution in ['uniform', 'normal']
+    if distribution == 'uniform':
+        nn.init.kaiming_uniform_(
+            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+    else:
+        nn.init.kaiming_normal_(
+            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def caffe2_xavier_init(module, bias=0):
+    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+    # Acknowledgment to FAIR's internal code
+    kaiming_init(
+        module,
+        a=1,
+        mode='fan_in',
+        nonlinearity='leaky_relu',
+        distribution='uniform')
+
+
+def bias_init_with_prob(prior_prob):
+    """Initialize conv/fc bias value according to a given probability."""
+    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+    return bias_init
diff --git a/tests/test_backbone.py b/tests/test_backbone.py
new file mode 100644
index 000000000..db7bce950
--- /dev/null
+++ b/tests/test_backbone.py
@@ -0,0 +1,25 @@
+import pytest
+import torch
+import torch.nn as nn
+from torch.nn.modules import AvgPool2d, GroupNorm
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmcls.models.backbones import MobileNetv2
+
+
+def test_mobilenetv2_backbone():
+    # Test MobileNetv2 with widen_factor 1.0, activation nn.ReLU6
+    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 8
+    assert feat[0].shape == torch.Size([1, 16, 112, 112])
+    assert feat[1].shape == torch.Size([1, 24, 56, 56])
+    assert feat[2].shape == torch.Size([1, 32, 28, 28])
+    assert feat[3].shape == torch.Size([1, 64, 14, 14])
+    assert feat[4].shape == torch.Size([1, 96, 14, 14])
+    assert feat[5].shape == torch.Size([1, 160, 7, 7])
+    assert feat[6].shape == torch.Size([1, 320, 7, 7])
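
For reviewers who want to exercise the patched backbone outside of pytest, here is a minimal sketch. It assumes the patch is applied and mmcls is on the import path; the constructor call mirrors the new test, while the dummy input and the print loop are illustrative only.

import torch
import torch.nn as nn

from mmcls.models.backbones import MobileNetv2

# Same configuration as test_mobilenetv2_backbone: stock MobileNetV2,
# i.e. widen_factor=1.0 with ReLU6 activations.
model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
model.init_weights()  # no pretrained checkpoint given
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))

# With the defaults used by the test, forward returns eight tensors:
# one per inverted-residual stage (16/24/32/64/96/160/320 channels,
# as asserted above) plus the final 1280-channel conv_last/bn_last output.
for i, feat in enumerate(feats):
    print(i, tuple(feat.shape))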