mmsegmentation/mmseg/ops/separable_conv_module.py

61 lines
1.8 KiB
Python
Raw Normal View History

2020-07-31 14:16:00 +08:00
from mmcv.cnn import build_norm_layer
from torch import nn
2020-07-07 20:52:19 +08:00
class DepthwiseSeparableConvModule(nn.Module):
    """Depthwise separable convolution block.

    A depthwise conv (``groups=in_channels``, so every input channel is
    convolved independently) is followed by a 1x1 pointwise conv. Each conv
    is followed by a norm layer built with mmcv's ``build_norm_layer``.

    The depthwise conv is padded with ``dilation * (kernel_size - 1) // 2``.
    For odd kernel sizes and stride 1 this keeps the spatial size unchanged.

    Args:
        in_channels (int): Number of input channels. Also the number of
            groups of the depthwise conv.
        out_channels (int): Number of output channels of the pointwise conv.
        kernel_size (int | tuple[int]): Kernel size of the depthwise conv.
            Default: 3.
        stride (int | tuple[int]): Stride of the depthwise conv. Default: 1.
        dilation (int | tuple[int]): Dilation of the depthwise conv.
            Default: 1.
        relu_first (bool): If True, the forward pass is
            ReLU -> depthwise -> norm -> pointwise -> norm.
            If False, it is
            depthwise -> norm -> ReLU -> pointwise -> norm -> ReLU.
            Default: True.
        bias (bool): Whether both convs use a bias term. Default: False.
        norm_cfg (dict | None): Norm config passed to ``build_norm_layer``
            for both norm layers. ``None`` means ``dict(type='BN')``.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 relu_first=True,
                 bias=False,
                 norm_cfg=None):
        super().__init__()
        if norm_cfg is None:
            # Built per call so instances never share a mutable default dict.
            norm_cfg = dict(type='BN')

        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            # Previously hard-coded to ``dilation``, which is only correct
            # for a 3x3 kernel; the helper generalizes it to any kernel size.
            padding=self._same_padding(kernel_size, dilation),
            dilation=dilation,
            groups=in_channels,
            bias=bias)
        # build_norm_layer returns (attribute name, layer). The name is
        # derived from the norm type plus ``postfix``, e.g. presumably
        # 'bn_depth' for BN (confirm against the mmcv version in use).
        self.norm_depth_name, norm_depth = build_norm_layer(
            norm_cfg, in_channels, postfix='_depth')
        self.add_module(self.norm_depth_name, norm_depth)

        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
        self.norm_point_name, norm_point = build_norm_layer(
            norm_cfg, out_channels, postfix='_point')
        self.add_module(self.norm_point_name, norm_point)

        self.relu_first = relu_first
        # In relu_first mode the ReLU is applied directly to the module input
        # ``x``. An in-place ReLU would overwrite the caller's tensor, so it
        # is in-place only when it runs on intermediate results.
        self.relu = nn.ReLU(inplace=not relu_first)

    @staticmethod
    def _same_padding(kernel_size, dilation):
        """Return ``dilation * (kernel_size - 1) // 2``.

        Accepts ints or 2-tuples. Returns an int when both arguments are
        ints, otherwise a tuple with one entry per spatial dimension.
        """
        if isinstance(kernel_size, int) and isinstance(dilation, int):
            return dilation * (kernel_size - 1) // 2
        kernel = (kernel_size if isinstance(kernel_size, (tuple, list))
                  else (kernel_size, kernel_size))
        dil = (dilation if isinstance(dilation, (tuple, list))
               else (dilation, dilation))
        return tuple(d * (k - 1) // 2 for k, d in zip(kernel, dil))

    @property
    def norm_depth(self):
        """Norm layer applied after the depthwise conv."""
        return getattr(self, self.norm_depth_name)

    @property
    def norm_point(self):
        """Norm layer applied after the pointwise conv."""
        return getattr(self, self.norm_point_name)

    def forward(self, x):
        """Apply the depthwise separable convolution to ``x``."""
        if self.relu_first:
            # Single pre-activation; no ReLU between or after the convs.
            out = self.relu(x)
            out = self.depthwise(out)
            out = self.norm_depth(out)
            out = self.pointwise(out)
            out = self.norm_point(out)
        else:
            out = self.depthwise(x)
            out = self.norm_depth(out)
            out = self.relu(out)
            out = self.pointwise(out)
            out = self.norm_point(out)
            out = self.relu(out)
        return out