mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
relevant files modified according to Jerry's instructions
This commit is contained in:
parent
164e038a5d
commit
1456a48a0e
@ -52,14 +52,16 @@ class LearningToDownsample(nn.Module):
|
|||||||
self.dsconv1 = DepthwiseSeparableConvModule(
|
self.dsconv1 = DepthwiseSeparableConvModule(
|
||||||
dw_channels1,
|
dw_channels1,
|
||||||
dw_channels2,
|
dw_channels2,
|
||||||
|
kernel_size=3,
|
||||||
stride=2,
|
stride=2,
|
||||||
relu_first=False,
|
padding=1,
|
||||||
norm_cfg=self.norm_cfg)
|
norm_cfg=self.norm_cfg)
|
||||||
self.dsconv2 = DepthwiseSeparableConvModule(
|
self.dsconv2 = DepthwiseSeparableConvModule(
|
||||||
dw_channels2,
|
dw_channels2,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
stride=2,
|
stride=2,
|
||||||
relu_first=False,
|
padding=1,
|
||||||
norm_cfg=self.norm_cfg)
|
norm_cfg=self.norm_cfg)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -35,19 +35,18 @@ class DepthwiseSeparableFCNHead(FCNHead):
|
|||||||
self.convs[0] = DepthwiseSeparableConvModule(
|
self.convs[0] = DepthwiseSeparableConvModule(
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
self.channels,
|
self.channels,
|
||||||
norm_cfg=self.norm_cfg,
|
kernel_size=3,
|
||||||
relu_first=False)
|
norm_cfg=self.norm_cfg)
|
||||||
for i in range(1, self.num_convs):
|
for i in range(1, self.num_convs):
|
||||||
self.convs[i] = DepthwiseSeparableConvModule(
|
self.convs[i] = DepthwiseSeparableConvModule(
|
||||||
self.channels,
|
self.channels,
|
||||||
self.channels,
|
self.channels,
|
||||||
norm_cfg=self.norm_cfg,
|
kernel_size=3,
|
||||||
relu_first=False)
|
norm_cfg=self.norm_cfg)
|
||||||
|
|
||||||
if self.concat_input:
|
if self.concat_input:
|
||||||
self.conv_cat = DepthwiseSeparableConvModule(
|
self.conv_cat = DepthwiseSeparableConvModule(
|
||||||
self.in_channels + self.channels,
|
self.in_channels + self.channels,
|
||||||
self.channels,
|
self.channels,
|
||||||
self.channels,
|
self.channels,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg)
|
||||||
relu_first=False)
|
|
||||||
|
@ -1,60 +1,88 @@
|
|||||||
from mmcv.cnn import build_norm_layer
|
import torch.nn as nn
|
||||||
from torch import nn
|
from mmcv.cnn import ConvModule
|
||||||
|
|
||||||
|
|
||||||
class DepthwiseSeparableConvModule(nn.Module):
|
class DepthwiseSeparableConvModule(nn.Module):
|
||||||
|
"""Depthwise separable convolution module.
|
||||||
|
|
||||||
|
See https://arxiv.org/pdf/1704.04861.pdf for details.
|
||||||
|
|
||||||
|
This module can replace a ConvModule with the conv block replaced by two
|
||||||
|
conv block: depthwise conv block and pointwise conv block. The depthwise
|
||||||
|
conv block contains depthwise-conv/norm/activation layers. The pointwise
|
||||||
|
conv block contains pointwise-conv/norm/activation layers. It should be
|
||||||
|
noted that there will be norm/activation layer in the depthwise conv block
|
||||||
|
if `norm_cfg` and `act_cfg` are specified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Same as nn.Conv2d.
|
||||||
|
out_channels (int): Same as nn.Conv2d.
|
||||||
|
kernel_size (int or tuple[int]): Same as nn.Conv2d.
|
||||||
|
stride (int or tuple[int]): Same as nn.Conv2d. Default: 1.
|
||||||
|
padding (int or tuple[int]): Same as nn.Conv2d. Default: 0.
|
||||||
|
dilation (int or tuple[int]): Same as nn.Conv2d. Default: 1.
|
||||||
|
norm_cfg (dict): Default norm config for both depthwise ConvModule and
|
||||||
|
pointwise ConvModule. Default: None.
|
||||||
|
act_cfg (dict): Default activation config for both depthwise ConvModule
|
||||||
|
and pointwise ConvModule. Default: dict(type='ReLU').
|
||||||
|
dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
|
||||||
|
'default', it will be the same as `norm_cfg`. Default: 'default'.
|
||||||
|
dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
|
||||||
|
'default', it will be the same as `act_cfg`. Default: 'default'.
|
||||||
|
pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
|
||||||
|
'default', it will be the same as `norm_cfg`. Default: 'default'.
|
||||||
|
pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
|
||||||
|
'default', it will be the same as `act_cfg`. Default: 'default'.
|
||||||
|
kwargs (optional): Other shared arguments for depthwise and pointwise
|
||||||
|
ConvModule. See ConvModule for ref.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size,
|
||||||
stride=1,
|
stride=1,
|
||||||
|
padding=0,
|
||||||
dilation=1,
|
dilation=1,
|
||||||
relu_first=True,
|
norm_cfg=None,
|
||||||
bias=False,
|
act_cfg=dict(type='ReLU'),
|
||||||
norm_cfg=dict(type='BN')):
|
dw_norm_cfg='default',
|
||||||
|
dw_act_cfg='default',
|
||||||
|
pw_norm_cfg='default',
|
||||||
|
pw_act_cfg='default',
|
||||||
|
**kwargs):
|
||||||
super(DepthwiseSeparableConvModule, self).__init__()
|
super(DepthwiseSeparableConvModule, self).__init__()
|
||||||
self.depthwise = nn.Conv2d(
|
assert 'groups' not in kwargs, 'groups should not be specified'
|
||||||
|
|
||||||
|
# if norm/activation config of depthwise/pointwise ConvModule is not
|
||||||
|
# specified, use default config.
|
||||||
|
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
|
||||||
|
dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
|
||||||
|
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
|
||||||
|
pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
|
||||||
|
|
||||||
|
# depthwise convolution
|
||||||
|
self.depthwise_conv = ConvModule(
|
||||||
in_channels,
|
in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=dilation,
|
padding=padding,
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
groups=in_channels,
|
groups=in_channels,
|
||||||
bias=bias)
|
norm_cfg=dw_norm_cfg,
|
||||||
self.norm_depth_name, norm_depth = build_norm_layer(
|
act_cfg=dw_act_cfg,
|
||||||
norm_cfg, in_channels, postfix='_depth')
|
**kwargs)
|
||||||
self.add_module(self.norm_depth_name, norm_depth)
|
|
||||||
|
|
||||||
self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
|
self.pointwise_conv = ConvModule(
|
||||||
self.norm_point_name, norm_point = build_norm_layer(
|
in_channels,
|
||||||
norm_cfg, out_channels, postfix='_point')
|
out_channels,
|
||||||
self.add_module(self.norm_point_name, norm_point)
|
1,
|
||||||
|
norm_cfg=pw_norm_cfg,
|
||||||
self.relu_first = relu_first
|
act_cfg=pw_act_cfg,
|
||||||
self.relu = nn.ReLU(inplace=not relu_first)
|
**kwargs)
|
||||||
|
|
||||||
@property
|
|
||||||
def norm_depth(self):
|
|
||||||
return getattr(self, self.norm_depth_name)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def norm_point(self):
|
|
||||||
return getattr(self, self.norm_point_name)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.relu_first:
|
x = self.depthwise_conv(x)
|
||||||
out = self.relu(x)
|
x = self.pointwise_conv(x)
|
||||||
out = self.depthwise(out)
|
return x
|
||||||
out = self.norm_depth(out)
|
|
||||||
out = self.pointwise(out)
|
|
||||||
out = self.norm_point(out)
|
|
||||||
else:
|
|
||||||
out = self.depthwise(x)
|
|
||||||
out = self.norm_depth(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
out = self.pointwise(out)
|
|
||||||
out = self.norm_point(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
return out
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user