mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
relevant files modified according to Jerry's instructions
This commit is contained in:
parent
164e038a5d
commit
1456a48a0e
@ -52,14 +52,16 @@ class LearningToDownsample(nn.Module):
|
||||
self.dsconv1 = DepthwiseSeparableConvModule(
|
||||
dw_channels1,
|
||||
dw_channels2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
relu_first=False,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg)
|
||||
self.dsconv2 = DepthwiseSeparableConvModule(
|
||||
dw_channels2,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
relu_first=False,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -35,19 +35,18 @@ class DepthwiseSeparableFCNHead(FCNHead):
|
||||
self.convs[0] = DepthwiseSeparableConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
norm_cfg=self.norm_cfg,
|
||||
relu_first=False)
|
||||
kernel_size=3,
|
||||
norm_cfg=self.norm_cfg)
|
||||
for i in range(1, self.num_convs):
|
||||
self.convs[i] = DepthwiseSeparableConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
norm_cfg=self.norm_cfg,
|
||||
relu_first=False)
|
||||
kernel_size=3,
|
||||
norm_cfg=self.norm_cfg)
|
||||
|
||||
if self.concat_input:
|
||||
self.conv_cat = DepthwiseSeparableConvModule(
|
||||
self.in_channels + self.channels,
|
||||
self.channels,
|
||||
self.channels,
|
||||
norm_cfg=self.norm_cfg,
|
||||
relu_first=False)
|
||||
norm_cfg=self.norm_cfg)
|
||||
|
@ -1,60 +1,88 @@
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from torch import nn
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
|
||||
class DepthwiseSeparableConvModule(nn.Module):
|
||||
"""Depthwise separable convolution module.
|
||||
|
||||
See https://arxiv.org/pdf/1704.04861.pdf for details.
|
||||
|
||||
This module can replace a ConvModule with the conv block replaced by two
|
||||
conv block: depthwise conv block and pointwise conv block. The depthwise
|
||||
conv block contains depthwise-conv/norm/activation layers. The pointwise
|
||||
conv block contains pointwise-conv/norm/activation layers. It should be
|
||||
noted that there will be norm/activation layer in the depthwise conv block
|
||||
if `norm_cfg` and `act_cfg` are specified.
|
||||
|
||||
Args:
|
||||
in_channels (int): Same as nn.Conv2d.
|
||||
out_channels (int): Same as nn.Conv2d.
|
||||
kernel_size (int or tuple[int]): Same as nn.Conv2d.
|
||||
stride (int or tuple[int]): Same as nn.Conv2d. Default: 1.
|
||||
padding (int or tuple[int]): Same as nn.Conv2d. Default: 0.
|
||||
dilation (int or tuple[int]): Same as nn.Conv2d. Default: 1.
|
||||
norm_cfg (dict): Default norm config for both depthwise ConvModule and
|
||||
pointwise ConvModule. Default: None.
|
||||
act_cfg (dict): Default activation config for both depthwise ConvModule
|
||||
and pointwise ConvModule. Default: dict(type='ReLU').
|
||||
dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
|
||||
'default', it will be the same as `norm_cfg`. Default: 'default'.
|
||||
dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
|
||||
'default', it will be the same as `act_cfg`. Default: 'default'.
|
||||
pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
|
||||
'default', it will be the same as `norm_cfg`. Default: 'default'.
|
||||
pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
|
||||
'default', it will be the same as `act_cfg`. Default: 'default'.
|
||||
kwargs (optional): Other shared arguments for depthwise and pointwise
|
||||
ConvModule. See ConvModule for ref.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
relu_first=True,
|
||||
bias=False,
|
||||
norm_cfg=dict(type='BN')):
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
dw_norm_cfg='default',
|
||||
dw_act_cfg='default',
|
||||
pw_norm_cfg='default',
|
||||
pw_act_cfg='default',
|
||||
**kwargs):
|
||||
super(DepthwiseSeparableConvModule, self).__init__()
|
||||
self.depthwise = nn.Conv2d(
|
||||
assert 'groups' not in kwargs, 'groups should not be specified'
|
||||
|
||||
# if norm/activation config of depthwise/pointwise ConvModule is not
|
||||
# specified, use default config.
|
||||
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
|
||||
dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
|
||||
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
|
||||
pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
|
||||
|
||||
# depthwise convolution
|
||||
self.depthwise_conv = ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=in_channels,
|
||||
bias=bias)
|
||||
self.norm_depth_name, norm_depth = build_norm_layer(
|
||||
norm_cfg, in_channels, postfix='_depth')
|
||||
self.add_module(self.norm_depth_name, norm_depth)
|
||||
norm_cfg=dw_norm_cfg,
|
||||
act_cfg=dw_act_cfg,
|
||||
**kwargs)
|
||||
|
||||
self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
|
||||
self.norm_point_name, norm_point = build_norm_layer(
|
||||
norm_cfg, out_channels, postfix='_point')
|
||||
self.add_module(self.norm_point_name, norm_point)
|
||||
|
||||
self.relu_first = relu_first
|
||||
self.relu = nn.ReLU(inplace=not relu_first)
|
||||
|
||||
@property
|
||||
def norm_depth(self):
|
||||
return getattr(self, self.norm_depth_name)
|
||||
|
||||
@property
|
||||
def norm_point(self):
|
||||
return getattr(self, self.norm_point_name)
|
||||
self.pointwise_conv = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
norm_cfg=pw_norm_cfg,
|
||||
act_cfg=pw_act_cfg,
|
||||
**kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
if self.relu_first:
|
||||
out = self.relu(x)
|
||||
out = self.depthwise(out)
|
||||
out = self.norm_depth(out)
|
||||
out = self.pointwise(out)
|
||||
out = self.norm_point(out)
|
||||
else:
|
||||
out = self.depthwise(x)
|
||||
out = self.norm_depth(out)
|
||||
out = self.relu(out)
|
||||
out = self.pointwise(out)
|
||||
out = self.norm_point(out)
|
||||
out = self.relu(out)
|
||||
return out
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.pointwise_conv(x)
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user