From 1456a48a0edd1d376219932ce9919a4899af6d5e Mon Sep 17 00:00:00 2001 From: johnzja Date: Tue, 11 Aug 2020 14:35:27 +0800 Subject: [PATCH] relevant files modified according to Jerry's instructions --- mmseg/models/backbones/fast_scnn.py | 6 +- mmseg/models/decode_heads/sep_fcn_head.py | 11 +-- mmseg/ops/separable_conv_module.py | 110 ++++++++++++++-------- 3 files changed, 78 insertions(+), 49 deletions(-) diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index 2937c82e2..9a1284d5f 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -52,14 +52,16 @@ class LearningToDownsample(nn.Module): self.dsconv1 = DepthwiseSeparableConvModule( dw_channels1, dw_channels2, + kernel_size=3, stride=2, - relu_first=False, + padding=1, norm_cfg=self.norm_cfg) self.dsconv2 = DepthwiseSeparableConvModule( dw_channels2, out_channels, + kernel_size=3, stride=2, - relu_first=False, + padding=1, norm_cfg=self.norm_cfg) def forward(self, x): diff --git a/mmseg/models/decode_heads/sep_fcn_head.py b/mmseg/models/decode_heads/sep_fcn_head.py index a4689ed08..4b263cc47 100644 --- a/mmseg/models/decode_heads/sep_fcn_head.py +++ b/mmseg/models/decode_heads/sep_fcn_head.py @@ -35,19 +35,18 @@ class DepthwiseSeparableFCNHead(FCNHead): self.convs[0] = DepthwiseSeparableConvModule( self.in_channels, self.channels, - norm_cfg=self.norm_cfg, - relu_first=False) + kernel_size=3, + norm_cfg=self.norm_cfg) for i in range(1, self.num_convs): self.convs[i] = DepthwiseSeparableConvModule( self.channels, self.channels, - norm_cfg=self.norm_cfg, - relu_first=False) + kernel_size=3, + norm_cfg=self.norm_cfg) if self.concat_input: self.conv_cat = DepthwiseSeparableConvModule( self.in_channels + self.channels, self.channels, self.channels, - norm_cfg=self.norm_cfg, - relu_first=False) + norm_cfg=self.norm_cfg) diff --git a/mmseg/ops/separable_conv_module.py b/mmseg/ops/separable_conv_module.py index e11365400..4e5922cc4 100644 --- a/mmseg/ops/separable_conv_module.py +++ b/mmseg/ops/separable_conv_module.py @@ -1,60 +1,88 @@ -from mmcv.cnn import build_norm_layer -from torch import nn +import torch.nn as nn +from mmcv.cnn import ConvModule class DepthwiseSeparableConvModule(nn.Module): + """Depthwise separable convolution module. + + See https://arxiv.org/pdf/1704.04861.pdf for details. + + This module can replace a ConvModule with the conv block replaced by two + conv block: depthwise conv block and pointwise conv block. The depthwise + conv block contains depthwise-conv/norm/activation layers. The pointwise + conv block contains pointwise-conv/norm/activation layers. It should be + noted that there will be norm/activation layer in the depthwise conv block + if `norm_cfg` and `act_cfg` are specified. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. Default: 1. + padding (int or tuple[int]): Same as nn.Conv2d. Default: 0. + dilation (int or tuple[int]): Same as nn.Conv2d. Default: 1. + norm_cfg (dict): Default norm config for both depthwise ConvModule and + pointwise ConvModule. Default: None. + act_cfg (dict): Default activation config for both depthwise ConvModule + and pointwise ConvModule. Default: dict(type='ReLU'). + dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is + 'default', it will be the same as `norm_cfg`. Default: 'default'. + dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: 'default'. + pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is + 'default', it will be the same as `norm_cfg`. Default: 'default'. + pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: 'default'. + kwargs (optional): Other shared arguments for depthwise and pointwise + ConvModule. See ConvModule for ref. + """ def __init__(self, in_channels, out_channels, - kernel_size=3, + kernel_size, stride=1, + padding=0, dilation=1, - relu_first=True, - bias=False, - norm_cfg=dict(type='BN')): + norm_cfg=None, + act_cfg=dict(type='ReLU'), + dw_norm_cfg='default', + dw_act_cfg='default', + pw_norm_cfg='default', + pw_act_cfg='default', + **kwargs): super(DepthwiseSeparableConvModule, self).__init__() - self.depthwise = nn.Conv2d( + assert 'groups' not in kwargs, 'groups should not be specified' + + # if norm/activation config of depthwise/pointwise ConvModule is not + # specified, use default config. + dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg + dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg + pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg + pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg + + # depthwise convolution + self.depthwise_conv = ConvModule( in_channels, in_channels, kernel_size, stride=stride, - padding=dilation, + padding=padding, dilation=dilation, groups=in_channels, - bias=bias) - self.norm_depth_name, norm_depth = build_norm_layer( - norm_cfg, in_channels, postfix='_depth') - self.add_module(self.norm_depth_name, norm_depth) + norm_cfg=dw_norm_cfg, + act_cfg=dw_act_cfg, + **kwargs) - self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias) - self.norm_point_name, norm_point = build_norm_layer( - norm_cfg, out_channels, postfix='_point') - self.add_module(self.norm_point_name, norm_point) - - self.relu_first = relu_first - self.relu = nn.ReLU(inplace=not relu_first) - - @property - def norm_depth(self): - return getattr(self, self.norm_depth_name) - - @property - def norm_point(self): - return getattr(self, self.norm_point_name) + self.pointwise_conv = ConvModule( + in_channels, + out_channels, + 1, + norm_cfg=pw_norm_cfg, + act_cfg=pw_act_cfg, + **kwargs) def forward(self, x): - if self.relu_first: - out = self.relu(x) - out = self.depthwise(out) - out = self.norm_depth(out) - out = self.pointwise(out) - out = self.norm_point(out) - else: - out = self.depthwise(x) - out = self.norm_depth(out) - out = self.relu(out) - out = self.pointwise(out) - out = self.norm_point(out) - out = self.relu(out) - return out + x = self.depthwise_conv(x) + x = self.pointwise_conv(x) + return x