2020-07-31 14:16:00 +08:00
|
|
|
from mmcv.cnn import build_norm_layer
|
|
|
|
from torch import nn
|
2020-07-07 20:52:19 +08:00
|
|
|
|
|
|
|
|
|
|
|
class DepthwiseSeparableConvModule(nn.Module):
    """Depthwise-separable convolution with a norm layer after each conv.

    The module applies a depthwise ``nn.Conv2d`` (``groups=in_channels``)
    followed by a norm layer, then a 1x1 pointwise ``nn.Conv2d`` followed by
    a second norm layer. ReLU placement depends on ``relu_first``:

    * ``relu_first=True``: ReLU is applied once, to the input, before the
      depthwise conv; no activation follows either norm layer.
    * ``relu_first=False``: ReLU is applied after each of the two norm
      layers.

    Args:
        in_channels (int): Number of input channels. The depthwise conv keeps
            this channel count and uses it as its number of groups.
        out_channels (int): Number of output channels of the pointwise conv.
        kernel_size (int | tuple[int, int]): Depthwise conv kernel size.
            Default: 3.
        stride (int | tuple[int, int]): Depthwise conv stride. Default: 1.
        dilation (int | tuple[int, int]): Depthwise conv dilation.
            Default: 1.
        relu_first (bool): Apply ReLU before the depthwise conv instead of
            after each norm layer (see above). Default: True.
        bias (bool): Whether both convs have a bias term. Default: False.
        norm_cfg (dict | None): Config passed to ``build_norm_layer`` for both
            norm layers. ``None`` means ``dict(type='BN')``. Default: None.
        padding (int | tuple[int, int] | None): Depthwise conv padding.
            ``None`` computes ``dilation * (kernel_size - 1) // 2`` per
            spatial dimension, which keeps the spatial size unchanged at
            stride 1 for odd kernel sizes and equals ``dilation`` for the
            default ``kernel_size=3``. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 relu_first=True,
                 bias=False,
                 norm_cfg=None,
                 padding=None):
        super().__init__()
        if norm_cfg is None:
            # Build a fresh dict per call so instances never share a single
            # mutable default argument.
            norm_cfg = dict(type='BN')
        if padding is None:
            # Previously hard-coded to ``dilation``, which is only correct for
            # 3x3 kernels; derive it from the kernel size instead.
            padding = self._same_padding(kernel_size, dilation)

        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias)
        self.norm_depth_name, norm_depth = build_norm_layer(
            norm_cfg, in_channels, postfix='_depth')
        self.add_module(self.norm_depth_name, norm_depth)

        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
        self.norm_point_name, norm_point = build_norm_layer(
            norm_cfg, out_channels, postfix='_point')
        self.add_module(self.norm_point_name, norm_point)

        self.relu_first = relu_first
        # With relu_first the ReLU runs directly on the caller's input tensor,
        # so it must not modify it in place; otherwise it only sees freshly
        # computed norm outputs and can safely work in place.
        self.relu = nn.ReLU(inplace=not relu_first)

    @staticmethod
    def _same_padding(kernel_size, dilation):
        """Return ``dilation * (kernel_size - 1) // 2`` per spatial dim.

        Ints in, int out; if either argument is a tuple/list, both are
        broadcast to two spatial dimensions and a tuple is returned.
        """
        if isinstance(kernel_size, int) and isinstance(dilation, int):
            return dilation * (kernel_size - 1) // 2
        kernel_pair = (kernel_size if isinstance(kernel_size, (tuple, list))
                       else (kernel_size, kernel_size))
        dilation_pair = (dilation if isinstance(dilation, (tuple, list))
                         else (dilation, dilation))
        return tuple(d * (k - 1) // 2
                     for k, d in zip(kernel_pair, dilation_pair))

    @property
    def norm_depth(self):
        """The norm layer applied after the depthwise conv."""
        return getattr(self, self.norm_depth_name)

    @property
    def norm_point(self):
        """The norm layer applied after the pointwise conv."""
        return getattr(self, self.norm_point_name)

    def forward(self, x):
        """Run the separable conv on ``x``.

        ``x`` is expected to be a tensor accepted by ``nn.Conv2d`` with
        ``in_channels`` channels (presumably ``(N, C, H, W)`` — confirm
        against callers).
        """
        if self.relu_first:
            # ReLU -> depthwise -> norm -> pointwise -> norm
            out = self.relu(x)
            out = self.depthwise(out)
            out = self.norm_depth(out)
            out = self.pointwise(out)
            out = self.norm_point(out)
        else:
            # depthwise -> norm -> ReLU -> pointwise -> norm -> ReLU
            out = self.depthwise(x)
            out = self.norm_depth(out)
            out = self.relu(out)
            out = self.pointwise(out)
            out = self.norm_point(out)
            out = self.relu(out)
        return out
|