From 1fbb537958eb36ddd77567ee6837ec9610706715 Mon Sep 17 00:00:00 2001 From: Jerry Jiarui XU Date: Fri, 4 Sep 2020 15:35:52 +0800 Subject: [PATCH] [Feature] Support MobileNetV2 backbone (#86) * [Feature] Support MobileNetV2 backbone * Fixed import * Fixed test * Fixed test * Fixed dilate * upload model * update table * update table * update bibtex * update MMCV requirement --- configs/mobilenet_v2/README.md | 32 +++ ...eplabv3_m-v2-d8_512x1024_80k_cityscapes.py | 12 + .../deeplabv3_m-v2-d8_512x512_160k_ade20k.py | 12 + ...bv3plus_m-v2-d8_512x1024_80k_cityscapes.py | 12 + ...eplabv3plus_m-v2-d8_512x512_160k_ade20k.py | 12 + .../fcn_m-v2-d8_512x1024_80k_cityscapes.py | 12 + .../fcn_m-v2-d8_512x512_160k_ade20k.py | 12 + .../pspnet_m-v2-d8_512x1024_80k_cityscapes.py | 12 + .../pspnet_m-v2-d8_512x512_160k_ade20k.py | 12 + mmseg/__init__.py | 4 +- mmseg/models/backbones/__init__.py | 3 +- mmseg/models/backbones/mobilenet_v2.py | 270 ++++++++++++++++++ mmseg/models/utils/__init__.py | 3 +- mmseg/models/utils/make_divisible.py | 24 ++ tests/test_models/test_forward.py | 5 + 15 files changed, 433 insertions(+), 4 deletions(-) create mode 100644 configs/mobilenet_v2/README.md create mode 100644 configs/mobilenet_v2/deeplabv3_m-v2-d8_512x1024_80k_cityscapes.py create mode 100644 configs/mobilenet_v2/deeplabv3_m-v2-d8_512x512_160k_ade20k.py create mode 100644 configs/mobilenet_v2/deeplabv3plus_m-v2-d8_512x1024_80k_cityscapes.py create mode 100644 configs/mobilenet_v2/deeplabv3plus_m-v2-d8_512x512_160k_ade20k.py create mode 100644 configs/mobilenet_v2/fcn_m-v2-d8_512x1024_80k_cityscapes.py create mode 100644 configs/mobilenet_v2/fcn_m-v2-d8_512x512_160k_ade20k.py create mode 100644 configs/mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py create mode 100644 configs/mobilenet_v2/pspnet_m-v2-d8_512x512_160k_ade20k.py create mode 100644 mmseg/models/backbones/mobilenet_v2.py create mode 100644 mmseg/models/utils/make_divisible.py diff --git a/configs/mobilenet_v2/README.md b/configs/mobilenet_v2/README.md new file mode 100644 index 000000000..0dcbb4d38 --- /dev/null +++ b/configs/mobilenet_v2/README.md @@ -0,0 +1,32 @@ +# MobileNetV2: Inverted Residuals and Linear Bottlenecks + +## Introduction + +``` +@inproceedings{sandler2018mobilenetv2, + title={Mobilenetv2: Inverted residuals and linear bottlenecks}, + author={Sandler, Mark and Howard, Andrew and Zhu, Menglong and Zhmoginov, Andrey and Chen, Liang-Chieh}, + booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, + pages={4510--4520}, + year={2018} +} +``` + + +## Results and models + +### Cityscapes +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download | +|------------|----------|-----------|--------:|---------:|----------------|------:|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| FCN | M-V2-D8 | 512x1024 | 80000 | 3.4 | 14.2 | 61.54 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/fcn_m-v2-d8_512x1024_80k_cityscapes/fcn_m-v2-d8_512x1024_80k_cityscapes_20200825_124817-d24c28c1.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/fcn_m-v2-d8_512x1024_80k_cityscapes/fcn_m-v2-d8_512x1024_80k_cityscapes-20200825_124817.log.json) | +| PSPNet | M-V2-D8 | 512x1024 | 80000 | 3.6 | 11.2 | 70.23 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes/pspnet_m-v2-d8_512x1024_80k_cityscapes_20200825_124817-19e81d51.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes/pspnet_m-v2-d8_512x1024_80k_cityscapes-20200825_124817.log.json) | +| DeepLabV3 | M-V2-D8 | 512x1024 | 80000 | 3.9 | 8.4 | 73.84 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/deeplabv3_m-v2-d8_512x1024_80k_cityscapes/deeplabv3_m-v2-d8_512x1024_80k_cityscapes_20200825_124836-bef03590.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/deeplabv3_m-v2-d8_512x1024_80k_cityscapes/deeplabv3_m-v2-d8_512x1024_80k_cityscapes-20200825_124836.log.json) | +| DeepLabV3+ | M-V2-D8 | 512x1024 | 80000 | 5.1 | 8.4 | 75.20 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/deeplabv3plus_m-v2-d8_512x1024_80k_cityscapes/deeplabv3plus_m-v2-d8_512x1024_80k_cityscapes_20200825_124836-d256dd4b.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/deeplabv3plus_m-v2-d8_512x1024_80k_cityscapes/deeplabv3plus_m-v2-d8_512x1024_80k_cityscapes-20200825_124836.log.json) | + +### ADE20k +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download | +|------------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| FCN | M-V2-D8 | 512x512 | 160000 | 6.5 | 64.4 | 19.71 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/fcn_m-v2-d8_512x512_160k_ade20k/fcn_m-v2-d8_512x512_160k_ade20k_20200825_214953-c40e1095.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/fcn_m-v2-d8_512x512_160k_ade20k/fcn_m-v2-d8_512x512_160k_ade20k-20200825_214953.log.json) | +| PSPNet | M-V2-D8 | 512x512 | 160000 | 6.5 | 57.7 | 29.68 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/pspnet_m-v2-d8_512x512_160k_ade20k/pspnet_m-v2-d8_512x512_160k_ade20k_20200825_214953-f5942f7a.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/pspnet_m-v2-d8_512x512_160k_ade20k/pspnet_m-v2-d8_512x512_160k_ade20k-20200825_214953.log.json) | +| DeepLabV3 | M-V2-D8 | 512x512 | 160000 | 6.8 | 39.9 | 34.08 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/deeplabv3_m-v2-d8_512x512_160k_ade20k/deeplabv3_m-v2-d8_512x512_160k_ade20k_20200825_223255-63986343.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/deeplabv3_m-v2-d8_512x512_160k_ade20k/deeplabv3_m-v2-d8_512x512_160k_ade20k-20200825_223255.log.json) | +| DeepLabV3+ | M-V2-D8 | 512x512 | 160000 | 8.2 | 43.1 | 34.02 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/deeplabv3plus_m-v2-d8_512x512_160k_ade20k/deeplabv3plus_m-v2-d8_512x512_160k_ade20k_20200825_223255-465a01d4.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/mobilenet_v2/deeplabv3plus_m-v2-d8_512x512_160k_ade20k/deeplabv3plus_m-v2-d8_512x512_160k_ade20k-20200825_223255.log.json) | diff --git a/configs/mobilenet_v2/deeplabv3_m-v2-d8_512x1024_80k_cityscapes.py b/configs/mobilenet_v2/deeplabv3_m-v2-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..267483d88 --- /dev/null +++ b/configs/mobilenet_v2/deeplabv3_m-v2-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,12 @@ +_base_ = '../deeplabv3/deeplabv3_r101-d8_512x1024_80k_cityscapes.py' +model = dict( + pretrained='mmcls://mobilenet_v2', + backbone=dict( + _delete_=True, + type='MobileNetV2', + widen_factor=1., + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6)), + decode_head=dict(in_channels=320), + auxiliary_head=dict(in_channels=96)) diff --git a/configs/mobilenet_v2/deeplabv3_m-v2-d8_512x512_160k_ade20k.py b/configs/mobilenet_v2/deeplabv3_m-v2-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..e15b8cc82 --- /dev/null +++ b/configs/mobilenet_v2/deeplabv3_m-v2-d8_512x512_160k_ade20k.py @@ -0,0 +1,12 @@ +_base_ = '../deeplabv3/deeplabv3_r101-d8_512x512_160k_ade20k.py' +model = dict( + pretrained='mmcls://mobilenet_v2', + backbone=dict( + _delete_=True, + type='MobileNetV2', + widen_factor=1., + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6)), + decode_head=dict(in_channels=320), + auxiliary_head=dict(in_channels=96)) diff --git a/configs/mobilenet_v2/deeplabv3plus_m-v2-d8_512x1024_80k_cityscapes.py b/configs/mobilenet_v2/deeplabv3plus_m-v2-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..d4533d79a --- /dev/null +++ b/configs/mobilenet_v2/deeplabv3plus_m-v2-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,12 @@ +_base_ = '../deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py' +model = dict( + pretrained='mmcls://mobilenet_v2', + backbone=dict( + _delete_=True, + type='MobileNetV2', + widen_factor=1., + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6)), + decode_head=dict(in_channels=320, c1_in_channels=24), + auxiliary_head=dict(in_channels=96)) diff --git a/configs/mobilenet_v2/deeplabv3plus_m-v2-d8_512x512_160k_ade20k.py b/configs/mobilenet_v2/deeplabv3plus_m-v2-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..7615a7c19 --- /dev/null +++ b/configs/mobilenet_v2/deeplabv3plus_m-v2-d8_512x512_160k_ade20k.py @@ -0,0 +1,12 @@ +_base_ = '../deeplabv3plus/deeplabv3plus_r101-d8_512x512_160k_ade20k.py' +model = dict( + pretrained='mmcls://mobilenet_v2', + backbone=dict( + _delete_=True, + type='MobileNetV2', + widen_factor=1., + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6)), + decode_head=dict(in_channels=320, c1_in_channels=24), + auxiliary_head=dict(in_channels=96)) diff --git a/configs/mobilenet_v2/fcn_m-v2-d8_512x1024_80k_cityscapes.py b/configs/mobilenet_v2/fcn_m-v2-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..a535bd0ed --- /dev/null +++ b/configs/mobilenet_v2/fcn_m-v2-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,12 @@ +_base_ = '../fcn/fcn_r101-d8_512x1024_80k_cityscapes.py' +model = dict( + pretrained='mmcls://mobilenet_v2', + backbone=dict( + _delete_=True, + type='MobileNetV2', + widen_factor=1., + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6)), + decode_head=dict(in_channels=320), + auxiliary_head=dict(in_channels=96)) diff --git a/configs/mobilenet_v2/fcn_m-v2-d8_512x512_160k_ade20k.py b/configs/mobilenet_v2/fcn_m-v2-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..c5f6ab0d6 --- /dev/null +++ b/configs/mobilenet_v2/fcn_m-v2-d8_512x512_160k_ade20k.py @@ -0,0 +1,12 @@ +_base_ = '../fcn/fcn_r101-d8_512x512_160k_ade20k.py' +model = dict( + pretrained='mmcls://mobilenet_v2', + backbone=dict( + _delete_=True, + type='MobileNetV2', + widen_factor=1., + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6)), + decode_head=dict(in_channels=320), + auxiliary_head=dict(in_channels=96)) diff --git a/configs/mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py b/configs/mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..7403bee86 --- /dev/null +++ b/configs/mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,12 @@ +_base_ = '../pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py' +model = dict( + pretrained='mmcls://mobilenet_v2', + backbone=dict( + _delete_=True, + type='MobileNetV2', + widen_factor=1., + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6)), + decode_head=dict(in_channels=320), + auxiliary_head=dict(in_channels=96)) diff --git a/configs/mobilenet_v2/pspnet_m-v2-d8_512x512_160k_ade20k.py b/configs/mobilenet_v2/pspnet_m-v2-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..5b72ac830 --- /dev/null +++ b/configs/mobilenet_v2/pspnet_m-v2-d8_512x512_160k_ade20k.py @@ -0,0 +1,12 @@ +_base_ = '../pspnet/pspnet_r101-d8_512x512_160k_ade20k.py' +model = dict( + pretrained='mmcls://mobilenet_v2', + backbone=dict( + _delete_=True, + type='MobileNetV2', + widen_factor=1., + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6)), + decode_head=dict(in_channels=320), + auxiliary_head=dict(in_channels=96)) diff --git a/mmseg/__init__.py b/mmseg/__init__.py index 0ab33a8ed..20bce069a 100644 --- a/mmseg/__init__.py +++ b/mmseg/__init__.py @@ -2,8 +2,8 @@ import mmcv from .version import __version__, version_info -MMCV_MIN = '1.0.5' -MMCV_MAX = '1.1.1' +MMCV_MIN = '1.1.2' +MMCV_MAX = '1.2.0' def digit_version(version_str): diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py index 0cb2ec17b..6253bab42 100644 --- a/mmseg/models/backbones/__init__.py +++ b/mmseg/models/backbones/__init__.py @@ -1,10 +1,11 @@ from .fast_scnn import FastSCNN from .hrnet import HRNet +from .mobilenet_v2 import MobileNetV2 from .resnest import ResNeSt from .resnet import ResNet, ResNetV1c, ResNetV1d from .resnext import ResNeXt __all__ = [ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', - 'ResNeSt' + 'ResNeSt', 'MobileNetV2' ] diff --git a/mmseg/models/backbones/mobilenet_v2.py b/mmseg/models/backbones/mobilenet_v2.py new file mode 100644 index 000000000..5fff485f0 --- /dev/null +++ b/mmseg/models/backbones/mobilenet_v2.py @@ -0,0 +1,270 @@ +import logging + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, constant_init, kaiming_init +from mmcv.runner import load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from ..utils import make_divisible + + +class InvertedResidual(nn.Module): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + dilation (int): Dilation rate of depthwise conv. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@BACKBONES.register_module() +class MobileNetV2(nn.Module): + """MobileNetV2 backbone. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + strides (Sequence[int], optional): Strides of the first block of each + layer. If not specified, default config in ``arch_setting`` will + be used. + dilations (Sequence[int]): Dilation of each layer. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks. + arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4], + [6, 96, 3], [6, 160, 3], [6, 320, 1]] + + def __init__(self, + widen_factor=1., + strides=(1, 2, 2, 2, 1, 2, 1), + dilations=(1, 1, 1, 1, 1, 1, 1), + out_indices=(1, 2, 4, 6), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False): + super(MobileNetV2, self).__init__() + self.widen_factor = widen_factor + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == len(self.arch_settings) + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 7): + raise ValueError('the item in out_indices must in ' + f'range(0, 8). But received {index}') + + if frozen_stages not in range(-1, 7): + raise ValueError('frozen_stages must be in range(-1, 7). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks = layer_cfg + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + def make_layer(self, out_channels, num_blocks, stride, dilation, + expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. + dilation (int): Dilation of the first block. + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. + """ + layers = [] + for i in range(num_blocks): + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride if i == 0 else 1, + expand_ratio=expand_ratio, + dilation=dilation if i == 0 else 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py index 71d3f423c..bea300c3a 100644 --- a/mmseg/models/utils/__init__.py +++ b/mmseg/models/utils/__init__.py @@ -1,4 +1,5 @@ +from .make_divisible import make_divisible from .res_layer import ResLayer from .self_attention_block import SelfAttentionBlock -__all__ = ['ResLayer', 'SelfAttentionBlock'] +__all__ = ['ResLayer', 'SelfAttentionBlock', 'make_divisible'] diff --git a/mmseg/models/utils/make_divisible.py b/mmseg/models/utils/make_divisible.py new file mode 100644 index 000000000..02ee047c5 --- /dev/null +++ b/mmseg/models/utils/make_divisible.py @@ -0,0 +1,24 @@ +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number down to the nearest value that can + be divisible by the divisor. + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int, optional): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float, optional): The minimum ratio of the rounded channel + number to the original channel number. Default: 0.9. + Returns: + int: The modified output channel number + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index fffe23e74..ca2f7cb27 100644 --- a/tests/test_models/test_forward.py +++ b/tests/test_models/test_forward.py @@ -157,6 +157,11 @@ def test_sem_fpn_forward(): _test_encoder_decoder_forward('sem_fpn/fpn_r50_512x1024_80k_cityscapes.py') +def test_mobilenet_v2_forward(): + _test_encoder_decoder_forward( + 'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py') + + def get_world_size(process_group): return 1