import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import Upsample
from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class SETRMLAHead(BaseDecodeHead):
    """Multi level feature aggregation head of SETR.

    MLA head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.

    Args:
        mla_channels (int): Channels of conv-conv-4x of multi-level feature
            aggregation. Default: 128.
        up_scale (int): The scale factor of interpolate. Default: 4.
    """

    def __init__(self, mla_channels=128, up_scale=4, **kwargs):
        super(SETRMLAHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        self.mla_channels = mla_channels

        num_inputs = len(self.in_channels)

        # Refer to the self.cls_seg settings of BaseDecodeHead: the branch
        # outputs are concatenated, so their total channel count must match
        # the channels expected by the classification conv.
        assert self.channels == num_inputs * mla_channels

        self.up_convs = nn.ModuleList()
        # Each input level gets its own conv-conv-upsample branch that
        # projects to mla_channels and enlarges the map by `up_scale`.
        for i in range(num_inputs):
            self.up_convs.append(
                nn.Sequential(
                    ConvModule(
                        in_channels=self.in_channels[i],
                        out_channels=mla_channels,
                        kernel_size=3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    ConvModule(
                        in_channels=mla_channels,
                        out_channels=mla_channels,
                        kernel_size=3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    Upsample(
                        scale_factor=up_scale,
                        mode='bilinear',
                        align_corners=self.align_corners)))

    def forward(self, inputs):
        inputs = self._transform_inputs(inputs)
        outs = []
        # Run each selected feature map through its own branch, then fuse by
        # channel-wise concatenation before the final classification conv.
        for x, up_conv in zip(inputs, self.up_convs):
            outs.append(up_conv(x))
        out = torch.cat(outs, dim=1)
        out = self.cls_seg(out)
        return out
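

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It instantiates the
# head on four dummy multi-level feature maps and checks the output shape.
# The concrete values below (batch size 2, 256-channel inputs, 32x32 maps,
# num_classes=19) are illustrative assumptions, not values fixed by SETR.
# Because this module uses relative imports, the sketch is meant to be run
# with the class imported from within the mmsegmentation package, not by
# executing this file directly.
if __name__ == '__main__':
    head = SETRMLAHead(
        in_channels=[256, 256, 256, 256],
        in_index=[0, 1, 2, 3],
        channels=512,  # must equal len(in_channels) * mla_channels
        mla_channels=128,
        up_scale=4,
        num_classes=19,
        norm_cfg=dict(type='BN'),
        align_corners=False)
    # Four feature maps at 1/16 resolution of a 512x512 input image.
    feats = [torch.randn(2, 256, 32, 32) for _ in range(4)]
    seg_logits = head(feats)
    # Each branch upsamples by up_scale=4, so spatial size grows 32 -> 128.
    print(seg_logits.shape)  # expected: torch.Size([2, 19, 128, 128])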