relevant files modified according to Jerry's instructions

johnzja 2020-08-11 14:35:27 +08:00
parent 164e038a5d
commit 1456a48a0e
3 changed files with 78 additions and 49 deletions

@@ -52,14 +52,16 @@ class LearningToDownsample(nn.Module):
         self.dsconv1 = DepthwiseSeparableConvModule(
             dw_channels1,
             dw_channels2,
+            kernel_size=3,
             stride=2,
-            relu_first=False,
+            padding=1,
             norm_cfg=self.norm_cfg)
         self.dsconv2 = DepthwiseSeparableConvModule(
             dw_channels2,
             out_channels,
+            kernel_size=3,
             stride=2,
-            relu_first=False,
+            padding=1,
             norm_cfg=self.norm_cfg)

     def forward(self, x):

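Note: the refactored module no longer hard-codes the depthwise padding (the old implementation used padding=dilation, see the last hunk below), so these call sites now pass kernel_size=3 and padding=1 explicitly to keep the stride-2 halving behaviour. A minimal sketch of the shape arithmetic with a plain nn.Conv2d (the channel count is arbitrary, not taken from the config):

import torch
import torch.nn as nn

# out = floor((in + 2 * padding - kernel_size) / stride) + 1, so a 3x3
# conv with stride=2 and padding=1 halves the spatial resolution.
dw = nn.Conv2d(48, 48, kernel_size=3, stride=2, padding=1, groups=48)
x = torch.randn(1, 48, 64, 64)
print(dw(x).shape)  # torch.Size([1, 48, 32, 32])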

@@ -35,19 +35,18 @@ class DepthwiseSeparableFCNHead(FCNHead):
         self.convs[0] = DepthwiseSeparableConvModule(
             self.in_channels,
             self.channels,
-            norm_cfg=self.norm_cfg,
-            relu_first=False)
+            kernel_size=3,
+            norm_cfg=self.norm_cfg)
         for i in range(1, self.num_convs):
             self.convs[i] = DepthwiseSeparableConvModule(
                 self.channels,
                 self.channels,
-                norm_cfg=self.norm_cfg,
-                relu_first=False)
+                kernel_size=3,
+                norm_cfg=self.norm_cfg)

         if self.concat_input:
             self.conv_cat = DepthwiseSeparableConvModule(
                 self.in_channels + self.channels,
                 self.channels,
-                self.channels,
-                norm_cfg=self.norm_cfg,
-                relu_first=False)
+                norm_cfg=self.norm_cfg)

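For context on why this head builds its convs from DepthwiseSeparableConvModule at all: a depthwise 3x3 conv plus a 1x1 pointwise conv needs far fewer parameters than a standard 3x3 conv (see the MobileNet reference in the module docstring below). A rough comparison with illustrative channel counts, norm layers and biases ignored:

import torch.nn as nn


def count_params(m):
    return sum(p.numel() for p in m.parameters())


cin, cout, k = 128, 128, 3
standard = nn.Conv2d(cin, cout, k, padding=1, bias=False)
depthwise = nn.Conv2d(cin, cin, k, padding=1, groups=cin, bias=False)
pointwise = nn.Conv2d(cin, cout, 1, bias=False)

print(count_params(standard))                             # 147456 = 128*128*3*3
print(count_params(depthwise) + count_params(pointwise))  # 17536 = 128*3*3 + 128*128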

@@ -1,60 +1,88 @@
-from mmcv.cnn import build_norm_layer
-from torch import nn
+import torch.nn as nn
+from mmcv.cnn import ConvModule


 class DepthwiseSeparableConvModule(nn.Module):
+    """Depthwise separable convolution module.
+
+    See https://arxiv.org/pdf/1704.04861.pdf for details.
+
+    This module can replace a ConvModule with the conv block replaced by two
+    conv block: depthwise conv block and pointwise conv block. The depthwise
+    conv block contains depthwise-conv/norm/activation layers. The pointwise
+    conv block contains pointwise-conv/norm/activation layers. It should be
+    noted that there will be norm/activation layer in the depthwise conv block
+    if `norm_cfg` and `act_cfg` are specified.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d. Default: 1.
+        padding (int or tuple[int]): Same as nn.Conv2d. Default: 0.
+        dilation (int or tuple[int]): Same as nn.Conv2d. Default: 1.
+        norm_cfg (dict): Default norm config for both depthwise ConvModule and
+            pointwise ConvModule. Default: None.
+        act_cfg (dict): Default activation config for both depthwise ConvModule
+            and pointwise ConvModule. Default: dict(type='ReLU').
+        dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
+            'default', it will be the same as `norm_cfg`. Default: 'default'.
+        dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
+            'default', it will be the same as `act_cfg`. Default: 'default'.
+        pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
+            'default', it will be the same as `norm_cfg`. Default: 'default'.
+        pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
+            'default', it will be the same as `act_cfg`. Default: 'default'.
+        kwargs (optional): Other shared arguments for depthwise and pointwise
+            ConvModule. See ConvModule for ref.
+    """

     def __init__(self,
                  in_channels,
                  out_channels,
-                 kernel_size=3,
+                 kernel_size,
                  stride=1,
+                 padding=0,
                  dilation=1,
-                 relu_first=True,
-                 bias=False,
-                 norm_cfg=dict(type='BN')):
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 dw_norm_cfg='default',
+                 dw_act_cfg='default',
+                 pw_norm_cfg='default',
+                 pw_act_cfg='default',
+                 **kwargs):
         super(DepthwiseSeparableConvModule, self).__init__()
-        self.depthwise = nn.Conv2d(
+        assert 'groups' not in kwargs, 'groups should not be specified'
+
+        # if norm/activation config of depthwise/pointwise ConvModule is not
+        # specified, use default config.
+        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
+        dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
+        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
+        pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
+
+        # depthwise convolution
+        self.depthwise_conv = ConvModule(
             in_channels,
             in_channels,
             kernel_size,
             stride=stride,
-            padding=dilation,
+            padding=padding,
             dilation=dilation,
             groups=in_channels,
-            bias=bias)
-        self.norm_depth_name, norm_depth = build_norm_layer(
-            norm_cfg, in_channels, postfix='_depth')
-        self.add_module(self.norm_depth_name, norm_depth)
+            norm_cfg=dw_norm_cfg,
+            act_cfg=dw_act_cfg,
+            **kwargs)

-        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
-        self.norm_point_name, norm_point = build_norm_layer(
-            norm_cfg, out_channels, postfix='_point')
-        self.add_module(self.norm_point_name, norm_point)
-
-        self.relu_first = relu_first
-        self.relu = nn.ReLU(inplace=not relu_first)
-
-    @property
-    def norm_depth(self):
-        return getattr(self, self.norm_depth_name)
-
-    @property
-    def norm_point(self):
-        return getattr(self, self.norm_point_name)
+        self.pointwise_conv = ConvModule(
+            in_channels,
+            out_channels,
+            1,
+            norm_cfg=pw_norm_cfg,
+            act_cfg=pw_act_cfg,
+            **kwargs)

     def forward(self, x):
-        if self.relu_first:
-            out = self.relu(x)
-            out = self.depthwise(out)
-            out = self.norm_depth(out)
-            out = self.pointwise(out)
-            out = self.norm_point(out)
-        else:
-            out = self.depthwise(x)
-            out = self.norm_depth(out)
-            out = self.relu(out)
-            out = self.pointwise(out)
-            out = self.norm_point(out)
-            out = self.relu(out)
-        return out
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv(x)
+        return x
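
A usage sketch of the refactored module, based only on the code in this diff; the import path is an assumption, since the file path is not shown in this view. Each inner ConvModule takes the shared norm_cfg/act_cfg unless a dw_*/pw_* override is passed ('default' is the sentinel meaning "inherit"):

import torch

from mmseg.ops import DepthwiseSeparableConvModule  # assumed import path

# BN + ReLU on the depthwise block, BN but no activation on the pointwise
# block (pw_act_cfg=None overrides the shared act_cfg).
conv = DepthwiseSeparableConvModule(
    64,
    128,
    kernel_size=3,
    padding=1,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'),
    pw_act_cfg=None)

x = torch.randn(1, 64, 32, 32)
out = conv(x)
print(out.shape)  # torch.Size([1, 128, 32, 32])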