92 lines
3.0 KiB
Python
92 lines
3.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv import is_tuple_of
|
|
from mmcv.cnn import ConvModule
|
|
|
|
from mmseg.ops import resize
|
|
from ..builder import HEADS
|
|
from .decode_head import BaseDecodeHead
|
|
|
|
|
|
@HEADS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    The deepest selected feature map is projected by a 1x1 conv and gated by
    a sigmoid-activated, average-pooled copy of itself. The result is then
    progressively resized to, concatenated with, and fused into each
    shallower selected feature map (deepest branch first).

    Args:
        branch_channels (tuple[int]): The number of output channels of each
            branch. Its length must be ``len(in_channels) - 1``.
            Default: (32, 64).
        **kwargs: Forwarded unchanged to ``BaseDecodeHead.__init__``.

    Raises:
        ValueError: If ``input_transform`` is not ``'multiple_select'``.
        AssertionError: If ``branch_channels`` is not a tuple of ints, or if
            its length does not equal ``len(in_channels) - 1``.
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super().__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        # Kept as assertions (not ValueError) so callers that already catch
        # AssertionError keep working; messages added for easier debugging.
        assert is_tuple_of(branch_channels, int), (
            'branch_channels must be a tuple of int, '
            f'but received {branch_channels!r}')
        assert len(branch_channels) == len(self.in_channels) - 1, (
            'len(branch_channels) must equal len(in_channels) - 1, but got '
            f'{len(branch_channels)} and {len(self.in_channels)}')
        self.branch_channels = branch_channels

        # One pair of modules per shallow branch:
        #   convs[i]    : 1x1 projection of the i-th selected input.
        #   conv_ups[i] : 1x1 fusion of [running feature, projected input].
        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i, branch_ch in enumerate(branch_channels):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(self.in_channels[i], branch_ch, 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_ch,
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        # Both modules below consume the deepest selected input
        # (``inputs[-1]`` in ``forward``), so both take ``in_channels[-1]``.
        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        # NOTE(review): the pooling window (kernel 49, stride (16, 20)) is
        # hard-coded; presumably tuned for one input resolution -- confirm.
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                # Was ``self.in_channels[2]``, which only matches the tensor
                # fed in ``forward`` when exactly three levels are selected.
                self.in_channels[-1],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Backbone features; passed through
                ``self._transform_inputs`` before use. The transformed
                result is indexed per branch and is presumably a sequence
                of (N, C, H, W) tensors -- confirm against the backbone.

        Returns:
            The output of ``self.cls_seg`` applied to the fused feature map.
        """
        inputs = self._transform_inputs(inputs)

        deepest = inputs[-1]

        # Project the deepest feature, then gate it element-wise with the
        # pooled sigmoid branch resized back to the same spatial size.
        projected = self.aspp_conv(deepest)
        gate = resize(
            self.image_pool(deepest),
            size=deepest.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(projected * gate)

        # Walk the shallower branches from deepest to shallowest.
        for i in reversed(range(len(self.branch_channels))):
            skip = inputs[i]
            x = resize(
                x,
                size=skip.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](skip)], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
|