From 3bc95a433223e261dcc9cba4c329eca644cbaded Mon Sep 17 00:00:00 2001
From: johnzja
Date: Wed, 12 Aug 2020 15:58:48 +0800
Subject: [PATCH] Keep InvertedResidual coherent with mmcls.

---
 .../fast_scnn_4x8_80k_lr0.12_cityscapes.py |  1 -
 mmseg/ops/inverted_residual_module.py      | 85 +++++++++++--------
 2 files changed, 51 insertions(+), 35 deletions(-)

diff --git a/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py b/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py
index 0da288735..53fcfc420 100644
--- a/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py
+++ b/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py
@@ -2,7 +2,6 @@ _base_ = [
     '../_base_/models/fast_scnn.py', '../_base_/datasets/cityscapes.py',
     '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
 ]
-cudnn_benchmark = True
 
 # Re-config the data sampler.
 data = dict(samples_per_gpu=8, workers_per_gpu=4)

diff --git a/mmseg/ops/inverted_residual_module.py b/mmseg/ops/inverted_residual_module.py
index 3be849982..53f6744e8 100644
--- a/mmseg/ops/inverted_residual_module.py
+++ b/mmseg/ops/inverted_residual_module.py
@@ -1,71 +1,88 @@
-from mmcv.cnn import ConvModule, build_norm_layer
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule
 from torch import nn
 
 
 class InvertedResidual(nn.Module):
-    """Inverted residual module.
+    """InvertedResidual block for MobileNetV2.
 
     Args:
-        inp (int): input channels.
-        oup (int): output channels.
-        stride (int): downsampling factor.
-        expand_ratio (int): 1 or 2.
-        dilation (int): Dilated conv. Default: 1.
-        conv_cfg (dict | None): Config of conv layers. Default: None.
-        norm_cfg (dict | None): Config of norm layers. Default:
-            dict(type='BN').
-        act_cfg (dict): Config of activation layers. Default:
-            dict(type='ReLU6').
+        in_channels (int): The input channels of the InvertedResidual block.
+        out_channels (int): The output channels of the InvertedResidual block.
+        stride (int): Stride of the middle (first) 3x3 convolution.
+        expand_ratio (int): Adjusts the number of channels of the hidden
+            layer in InvertedResidual by this amount.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU6').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    Returns:
+        Tensor: The output tensor.
     """
 
     def __init__(self,
-                 inp,
-                 oup,
+                 in_channels,
+                 out_channels,
                  stride,
                  expand_ratio,
-                 dilation=1,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU6')):
+                 act_cfg=dict(type='ReLU6'),
+                 with_cp=False):
         super(InvertedResidual, self).__init__()
         self.stride = stride
-        assert stride in [1, 2]
-
-        hidden_dim = int(round(inp * expand_ratio))
-        self.use_res_connect = self.stride == 1 and inp == oup
+        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
+            f'But received {stride}.'
+        self.with_cp = with_cp
+        self.use_res_connect = self.stride == 1 and in_channels == out_channels
+        hidden_dim = int(round(in_channels * expand_ratio))
 
         layers = []
         if expand_ratio != 1:
-            # pw
             layers.append(
                 ConvModule(
-                    inp,
-                    hidden_dim,
+                    in_channels=in_channels,
+                    out_channels=hidden_dim,
                     kernel_size=1,
                     conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg))
         layers.extend([
-            # dw
             ConvModule(
-                hidden_dim,
-                hidden_dim,
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
                 kernel_size=3,
-                padding=dilation,
                 stride=stride,
-                dilation=dilation,
+                padding=1,
                 groups=hidden_dim,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
                 act_cfg=act_cfg),
-            # pw-linear
-            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
-            build_norm_layer(norm_cfg, oup)[1],
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=out_channels,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
         ])
         self.conv = nn.Sequential(*layers)
 
     def forward(self, x):
-        if self.use_res_connect:
-            return x + self.conv(x)
+
+        def _inner_forward(x):
+            if self.use_res_connect:
+                return x + self.conv(x)
+            else:
+                return self.conv(x)
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
         else:
-            return self.conv(x)
+            out = _inner_forward(x)
+
+        return out
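
A minimal usage sketch of the updated block follows (not part of the patch).
The channel counts and input shape below are illustrative assumptions; the
import path matches the module touched above.

    import torch

    from mmseg.ops.inverted_residual_module import InvertedResidual

    # Expanding block: 32 -> 64 channels, stride-2 depthwise conv, and a
    # 6x hidden expansion, as is typical for MobileNetV2.
    block = InvertedResidual(
        in_channels=32,
        out_channels=64,
        stride=2,
        expand_ratio=6,
        with_cp=False)  # with_cp=True trades training speed for memory

    x = torch.randn(1, 32, 56, 56)
    out = block(x)
    assert out.shape == (1, 64, 28, 28)  # spatial size halved by stride=2

    # The residual shortcut is applied only when stride == 1 and
    # in_channels == out_channels.
    res_block = InvertedResidual(64, 64, stride=1, expand_ratio=6)
    assert res_block.use_res_connect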