diff --git a/README.md b/README.md index 4b1eade1d..bbf24ec85 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ Supported backbones: - [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md) - [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md) - [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md) +- [x] Vision Transformer (ICLR'2021) Supported methods: @@ -89,6 +90,7 @@ Supported methods: - [x] [DNLNet (ECCV'2020)](configs/dnlnet) - [x] [PointRend (CVPR'2020)](configs/point_rend) - [x] [CGNet (TIP'2020)](configs/cgnet) +- [x] [SETR (CVPR'2021)](configs/setr) ## Installation diff --git a/README_zh-CN.md b/README_zh-CN.md index 283a045b9..2341e4768 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -62,6 +62,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O - [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md) - [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md) - [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md) +- [x] Vision Transformer (ICLR'2021) 已支持的算法: @@ -87,6 +88,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O - [x] [DNLNet (ECCV'2020)](configs/dnlnet) - [x] [PointRend (CVPR'2020)](configs/point_rend) - [x] [CGNet (TIP'2020)](configs/cgnet) +- [x] [SETR (CVPR'2021)](configs/setr) ## 安装 diff --git a/mmseg/models/decode_heads/setr_up_head.py b/mmseg/models/decode_heads/setr_up_head.py index 2088ec7d7..322a56dc7 100644 --- a/mmseg/models/decode_heads/setr_up_head.py +++ b/mmseg/models/decode_heads/setr_up_head.py @@ -1,5 +1,5 @@ import torch.nn as nn -from mmcv.cnn import ConvModule, build_norm_layer, constant_init +from mmcv.cnn import ConvModule, build_norm_layer from ..builder import HEADS from .decode_head import BaseDecodeHead @@ -18,6 +18,9 @@ class SETRUPHead(BaseDecodeHead): up_scale (int): The scale factor of interpolate. Default:4. kernel_size (int): The kernel size of convolution when decoding feature information from backbone. Default: 3. 
+ init_cfg (dict | list[dict] | None): Initialization config dict. + Default: Constant init (val=1.0, bias=0) for 'LayerNorm' layers + and Normal init (std=0.01) overriding 'conv_seg'. """ def __init__(self, @@ -25,11 +28,18 @@ class SETRUPHead(BaseDecodeHead): num_convs=1, up_scale=4, kernel_size=3, + init_cfg=[ + dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'), + dict( + type='Normal', + std=0.01, + override=dict(name='conv_seg')) + ], **kwargs): assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' - super(SETRUPHead, self).__init__(**kwargs) + super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs) assert isinstance(self.in_channels, int) @@ -38,7 +48,7 @@ class SETRUPHead(BaseDecodeHead): self.up_convs = nn.ModuleList() in_channels = self.in_channels out_channels = self.channels - for i in range(num_convs): + for _ in range(num_convs): self.up_convs.append( nn.Sequential( ConvModule( @@ -55,12 +65,6 @@ class SETRUPHead(BaseDecodeHead): align_corners=self.align_corners))) in_channels = out_channels - def init_weights(self): - for m in self.modules(): - if isinstance(m, nn.LayerNorm): - constant_init(m.bias, 0) - constant_init(m.weight, 1.0) - def forward(self, x): x = self._transform_inputs(x) diff --git a/tests/test_models/test_heads/test_setr_up_head.py b/tests/test_models/test_heads/test_setr_up_head.py index 4b89621da..ad6ca56d2 100644 --- a/tests/test_models/test_heads/test_setr_up_head.py +++ b/tests/test_models/test_heads/test_setr_up_head.py @@ -16,13 +16,14 @@ def test_setr_up_head(capsys): # as embed_dim. SETRUPHead(in_channels=(32, 32), channels=16, num_classes=19) - # test init_weights of head + # test init_cfg of head head = SETRUPHead( in_channels=32, channels=16, norm_cfg=dict(type='SyncBN'), - num_classes=19) - head.init_weights() + num_classes=19, + init_cfg=dict(type='Kaiming')) + super(SETRUPHead, head).init_weights() # test inference of Naive head # the auxiliary head of Naive head is same as Naive head