69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
import numpy as np
|
|
import torch.nn as nn
|
|
from mmcv.cnn import ConvModule
|
|
|
|
from mmseg.ops import resize
|
|
from ..builder import HEADS
|
|
from .decode_head import BaseDecodeHead
|
|
|
|
|
|
@HEADS.register_module()
class FPNHead(BaseDecodeHead):
    """Panoptic Feature Pyramid Networks.

    This head is the implementation of `Semantic FPN
    <https://arxiv.org/abs/1901.02446>`_.

    Every selected input level is passed through its own stack of
    3x3 ``ConvModule`` blocks (each followed by a x2 bilinear upsample for
    levels whose stride differs from the first one), the per-level results
    are resized to the spatial size of the first level and summed, and the
    sum is fed to ``cls_seg``.

    Args:
        feature_strides (tuple[int]): The strides of the input feature maps,
            one per entry of ``in_channels``. All strides are supposed to be
            powers of 2. The first one must be the smallest stride, i.e. the
            feature map of largest resolution.
        **kwargs: Forwarded to ``BaseDecodeHead.__init__``.
    """

    def __init__(self, feature_strides, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        base_stride = feature_strides[0]
        self.scale_heads = nn.ModuleList()
        for level, stride in enumerate(feature_strides):
            # Number of conv stages equals the number of x2 steps separating
            # this level from the first level, with at least one stage.
            num_doublings = int(np.log2(stride) - np.log2(base_stride))
            num_stages = max(1, num_doublings)
            needs_upsample = stride != base_stride

            stages = []
            for stage_idx in range(num_stages):
                # Only the first stage sees the backbone channel count; later
                # stages operate on the head's own channel width.
                if stage_idx == 0:
                    stage_in_channels = self.in_channels[level]
                else:
                    stage_in_channels = self.channels
                stages.append(
                    ConvModule(
                        stage_in_channels,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if needs_upsample:
                    # One x2 upsample after every conv stage of a coarser
                    # level.
                    stages.append(
                        nn.Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
            self.scale_heads.append(nn.Sequential(*stages))

    def forward(self, inputs):
        """Fuse the per-level scale-head outputs and classify the result.

        Args:
            inputs: Backbone features; they are passed through
                ``self._transform_inputs`` and then indexed per level.

        Returns:
            The output of ``self.cls_seg`` applied to the summed features.
        """
        feats = self._transform_inputs(inputs)

        fused = self.scale_heads[0](feats[0])
        for level in range(1, len(self.feature_strides)):
            branch = self.scale_heads[level](feats[level])
            # Deliberately out-of-place addition (the original marks this as
            # "non inplace"); each branch is first resized to the spatial
            # size of the running sum.
            fused = fused + resize(
                branch,
                size=fused.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        return self.cls_seg(fused)
|