Mirror of https://github.com/open-mmlab/mmsegmentation.git
Synced 2025-06-03 22:03:48 +08:00
[Fix] Add setr & vit msg. (#635)
* [Fix] Add setr & vit msg.
* Fix init bug
* Modify init_cfg arg
* Add conv_seg init
Parent: 5876868a48
Commit: 60baa4e841
@@ -63,6 +63,7 @@ Supported backbones:
 - [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
 - [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
 - [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
+- [x] [Vision Transformer (ICLR'2021)]

 Supported methods:

@@ -89,6 +90,7 @@ Supported methods:
 - [x] [DNLNet (ECCV'2020)](configs/dnlnet)
 - [x] [PointRend (CVPR'2020)](configs/point_rend)
 - [x] [CGNet (TIP'2020)](configs/cgnet)
+- [x] [SETR (CVPR'2021)](configs/setr)

 ## Installation

@@ -62,6 +62,7 @@ MMSegmentation is an open source semantic segmentation toolbox based on PyTorch.
 - [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
 - [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
 - [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
+- [x] [Vision Transformer (ICLR'2021)]

 Supported methods:

@@ -87,6 +88,7 @@ MMSegmentation is an open source semantic segmentation toolbox based on PyTorch.
 - [x] [DNLNet (ECCV'2020)](configs/dnlnet)
 - [x] [PointRend (CVPR'2020)](configs/point_rend)
 - [x] [CGNet (TIP'2020)](configs/cgnet)
+- [x] [SETR (CVPR'2021)](configs/setr)

 ## Installation

@@ -1,5 +1,5 @@
 import torch.nn as nn
-from mmcv.cnn import ConvModule, build_norm_layer, constant_init
+from mmcv.cnn import ConvModule, build_norm_layer

 from ..builder import HEADS
 from .decode_head import BaseDecodeHead
@@ -18,6 +18,9 @@ class SETRUPHead(BaseDecodeHead):
         up_scale (int): The scale factor of interpolate. Default:4.
         kernel_size (int): The kernel size of convolution when decoding
             feature information from backbone. Default: 3.
+        init_cfg (dict | list[dict] | None): Initialization config dict.
+            Default: dict(
+                type='Constant', val=1.0, bias=0, layer='LayerNorm').
     """

     def __init__(self,
@@ -25,11 +28,18 @@ class SETRUPHead(BaseDecodeHead):
                  num_convs=1,
                  up_scale=4,
                  kernel_size=3,
+                 init_cfg=[
+                     dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
+                     dict(
+                         type='Normal',
+                         std=0.01,
+                         override=dict(name='conv_seg'))
+                 ],
                  **kwargs):

         assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'

-        super(SETRUPHead, self).__init__(**kwargs)
+        super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs)

         assert isinstance(self.in_channels, int)

@@ -38,7 +48,7 @@ class SETRUPHead(BaseDecodeHead):
         self.up_convs = nn.ModuleList()
         in_channels = self.in_channels
         out_channels = self.channels
-        for i in range(num_convs):
+        for _ in range(num_convs):
             self.up_convs.append(
                 nn.Sequential(
                     ConvModule(
@@ -55,12 +65,6 @@ class SETRUPHead(BaseDecodeHead):
                         align_corners=self.align_corners)))
             in_channels = out_channels

-    def init_weights(self):
-        for m in self.modules():
-            if isinstance(m, nn.LayerNorm):
-                constant_init(m.bias, 0)
-                constant_init(m.weight, 1.0)
-
     def forward(self, x):
         x = self._transform_inputs(x)

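The removed init_weights() hand-rolled the LayerNorm initialization, and it passed tensors (m.bias, m.weight) to constant_init, which expects modules — this appears to be the "init bug" mentioned in the commit message. The new default init_cfg expresses the same intent declaratively and also covers the conv_seg classifier. Below is a minimal sketch, not part of the commit, using a hypothetical ToyUPHead stand-in; it assumes mmcv >= 1.3, where BaseModule applies init_cfg inside init_weights().

# Sketch only: ToyUPHead is a hypothetical stand-in for SETRUPHead that just
# illustrates the new default initialization: LayerNorm layers get weight=1.0
# and bias=0, and conv_seg gets a Normal(std=0.01) init via the override entry.
import torch
import torch.nn as nn
from mmcv.runner import BaseModule


class ToyUPHead(BaseModule):

    def __init__(self,
                 init_cfg=[
                     dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
                     dict(
                         type='Normal',
                         std=0.01,
                         override=dict(name='conv_seg'))
                 ]):
        super().__init__(init_cfg=init_cfg)
        self.norm = nn.LayerNorm(16)          # matched by layer='LayerNorm'
        self.conv_seg = nn.Conv2d(16, 19, 1)  # matched by the override name


head = ToyUPHead()
head.init_weights()  # declarative init; no manual constant_init calls needed
assert torch.all(head.norm.weight == 1.0) and torch.all(head.norm.bias == 0)
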
@@ -16,13 +16,14 @@ def test_setr_up_head(capsys):
         # as embed_dim.
         SETRUPHead(in_channels=(32, 32), channels=16, num_classes=19)

-    # test init_weights of head
+    # test init_cfg of head
     head = SETRUPHead(
         in_channels=32,
         channels=16,
         norm_cfg=dict(type='SyncBN'),
-        num_classes=19)
-    head.init_weights()
+        num_classes=19,
+        init_cfg=dict(type='Kaiming'))
+    super(SETRUPHead, head).init_weights()

     # test inference of Naive head
     # the auxiliary head of Naive head is same as Naive head
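For completeness, a hedged usage sketch mirroring the updated test (it assumes an mmsegmentation checkout that contains this commit): the default Constant/LayerNorm plus Normal/conv_seg pair can be overridden per instance through init_cfg, and ordinary callers then trigger initialization with the usual init_weights() call rather than the super(...) form used in the test.

# Usage sketch; argument values mirror the updated test above.
from mmseg.models.decode_heads import SETRUPHead

head = SETRUPHead(
    in_channels=32,
    channels=16,
    norm_cfg=dict(type='SyncBN'),
    num_classes=19,
    init_cfg=dict(type='Kaiming'))  # replaces the default Constant/Normal pair
head.init_weights()                 # BaseModule applies init_cfg recursively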