[Fix] Add setr & vit msg. (#635)

* [Fix] Add setr & vit msg.

* Fix init bug

* Modify init_cfg arg

* Add conv_seg init
This commit is contained in:
sennnnn 2021-06-24 13:25:06 +08:00 committed by GitHub
parent 5876868a48
commit 60baa4e841
4 changed files with 21 additions and 12 deletions

View File

@ -63,6 +63,7 @@ Supported backbones:
- [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
- [x] Vision Transformer (ICLR'2021)
Supported methods:
@ -89,6 +90,7 @@ Supported methods:
- [x] [DNLNet (ECCV'2020)](configs/dnlnet)
- [x] [PointRend (CVPR'2020)](configs/point_rend)
- [x] [CGNet (TIP'2020)](configs/cgnet)
- [x] [SETR (CVPR'2021)](configs/setr)
## Installation

View File

@ -62,6 +62,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
- [x] Vision Transformer (ICLR'2021)
已支持的算法:
@ -87,6 +88,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [DNLNet (ECCV'2020)](configs/dnlnet)
- [x] [PointRend (CVPR'2020)](configs/point_rend)
- [x] [CGNet (TIP'2020)](configs/cgnet)
- [x] [SETR (CVPR'2021)](configs/setr)
## 安装

View File

@ -1,5 +1,5 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer, constant_init
from mmcv.cnn import ConvModule, build_norm_layer
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@ -18,6 +18,9 @@ class SETRUPHead(BaseDecodeHead):
up_scale (int): The scale factor of interpolate. Default:4.
kernel_size (int): The kernel size of convolution when decoding
feature information from backbone. Default: 3.
init_cfg (dict | list[dict] | None): Initialization config dict.
Default: [dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
dict(type='Normal', std=0.01, override=dict(name='conv_seg'))].
"""
def __init__(self,
@ -25,11 +28,18 @@ class SETRUPHead(BaseDecodeHead):
num_convs=1,
up_scale=4,
kernel_size=3,
init_cfg=[
dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
dict(
type='Normal',
std=0.01,
override=dict(name='conv_seg'))
],
**kwargs):
assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'
super(SETRUPHead, self).__init__(**kwargs)
super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs)
assert isinstance(self.in_channels, int)
@ -38,7 +48,7 @@ class SETRUPHead(BaseDecodeHead):
self.up_convs = nn.ModuleList()
in_channels = self.in_channels
out_channels = self.channels
for i in range(num_convs):
for _ in range(num_convs):
self.up_convs.append(
nn.Sequential(
ConvModule(
@ -55,12 +65,6 @@ class SETRUPHead(BaseDecodeHead):
align_corners=self.align_corners)))
in_channels = out_channels
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
def forward(self, x):
x = self._transform_inputs(x)

View File

@ -16,13 +16,14 @@ def test_setr_up_head(capsys):
# as embed_dim.
SETRUPHead(in_channels=(32, 32), channels=16, num_classes=19)
# test init_weights of head
# test init_cfg of head
head = SETRUPHead(
in_channels=32,
channels=16,
norm_cfg=dict(type='SyncBN'),
num_classes=19)
head.init_weights()
num_classes=19,
init_cfg=dict(type='Kaiming'))
super(SETRUPHead, head).init_weights()
# test inference of Naive head
# the auxiliary head of Naive head is same as Naive head