mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
add upsample neck (#512)
* init * upsample v1.0 * fix errors * change to in_channels list * add unittest, docstring, norm/act config and rename Co-authored-by: xiexinch <test767803@foxmail.com>
This commit is contained in:
parent
84fb600d47
commit
98ef5ac705
@ -1,3 +1,4 @@
|
|||||||
from .fpn import FPN
|
from .fpn import FPN
|
||||||
|
from .multilevel_neck import MultiLevelNeck
|
||||||
|
|
||||||
__all__ = ['FPN']
|
__all__ = ['FPN', 'MultiLevelNeck']
|
||||||
|
70
mmseg/models/necks/multilevel_neck.py
Normal file
70
mmseg/models/necks/multilevel_neck.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from mmcv.cnn import ConvModule
|
||||||
|
|
||||||
|
from ..builder import NECKS
|
||||||
|
|
||||||
|
|
||||||
|
@NECKS.register_module()
|
||||||
|
class MultiLevelNeck(nn.Module):
|
||||||
|
"""MultiLevelNeck.
|
||||||
|
|
||||||
|
A neck structure connect vit backbone and decoder_heads.
|
||||||
|
Args:
|
||||||
|
in_channels (List[int]): Number of input channels per scale.
|
||||||
|
out_channels (int): Number of output channels (used at each scale).
|
||||||
|
scales (List[int]): Scale factors for each input feature map.
|
||||||
|
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||||
|
act_cfg (dict): Config dict for activation layer in ConvModule.
|
||||||
|
Default: None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
scales=[0.5, 1, 2, 4],
|
||||||
|
norm_cfg=None,
|
||||||
|
act_cfg=None):
|
||||||
|
super(MultiLevelNeck, self).__init__()
|
||||||
|
assert isinstance(in_channels, list)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.scales = scales
|
||||||
|
self.num_outs = len(scales)
|
||||||
|
self.lateral_convs = nn.ModuleList()
|
||||||
|
self.convs = nn.ModuleList()
|
||||||
|
for in_channel in in_channels:
|
||||||
|
self.lateral_convs.append(
|
||||||
|
ConvModule(
|
||||||
|
in_channel,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
act_cfg=act_cfg))
|
||||||
|
for _ in range(self.num_outs):
|
||||||
|
self.convs.append(
|
||||||
|
ConvModule(
|
||||||
|
out_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
stride=1,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
act_cfg=act_cfg))
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
assert len(inputs) == len(self.in_channels)
|
||||||
|
print(inputs[0].shape)
|
||||||
|
inputs = [
|
||||||
|
lateral_conv(inputs[i])
|
||||||
|
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||||
|
]
|
||||||
|
# for len(inputs) not equal to self.num_outs
|
||||||
|
if len(inputs) == 1:
|
||||||
|
inputs = [inputs[0] for _ in range(self.num_outs)]
|
||||||
|
outs = []
|
||||||
|
for i in range(self.num_outs):
|
||||||
|
x_resize = F.interpolate(
|
||||||
|
inputs[i], scale_factor=self.scales[i], mode='bilinear')
|
||||||
|
outs.append(self.convs[i](x_resize))
|
||||||
|
return tuple(outs)
|
28
tests/test_models/test_necks/test_multilevel_neck.py
Normal file
28
tests/test_models/test_necks/test_multilevel_neck.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from mmseg.models import MultiLevelNeck
|
||||||
|
|
||||||
|
|
||||||
|
def test_multilevel_neck():
|
||||||
|
|
||||||
|
# Test multi feature maps
|
||||||
|
in_channels = [256, 512, 1024, 2048]
|
||||||
|
inputs = [torch.randn(1, c, 14, 14) for i, c in enumerate(in_channels)]
|
||||||
|
|
||||||
|
neck = MultiLevelNeck(in_channels, 256)
|
||||||
|
outputs = neck(inputs)
|
||||||
|
assert outputs[0].shape == torch.Size([1, 256, 7, 7])
|
||||||
|
assert outputs[1].shape == torch.Size([1, 256, 14, 14])
|
||||||
|
assert outputs[2].shape == torch.Size([1, 256, 28, 28])
|
||||||
|
assert outputs[3].shape == torch.Size([1, 256, 56, 56])
|
||||||
|
|
||||||
|
# Test one feature map
|
||||||
|
in_channels = [768]
|
||||||
|
inputs = [torch.randn(1, 768, 14, 14)]
|
||||||
|
|
||||||
|
neck = MultiLevelNeck(in_channels, 256)
|
||||||
|
outputs = neck(inputs)
|
||||||
|
assert outputs[0].shape == torch.Size([1, 256, 7, 7])
|
||||||
|
assert outputs[1].shape == torch.Size([1, 256, 14, 14])
|
||||||
|
assert outputs[2].shape == torch.Size([1, 256, 28, 28])
|
||||||
|
assert outputs[3].shape == torch.Size([1, 256, 56, 56])
|
Loading…
x
Reference in New Issue
Block a user