mmsegmentation/mmseg/models/decode_heads/segmenter_mask_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)
from mmcv.runner import ModuleList

from mmseg.models.backbones.vit import TransformerEncoderLayer
from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead):
"""Segmenter: Transformer for Semantic Segmentation.
This head is the implementation of
`Segmenter: <https://arxiv.org/abs/2105.05633>`_.
Args:
backbone_cfg:(dict): Config of backbone of
Context Path.
in_channels (int): The number of channels of input image.
num_layers (int): The depth of transformer.
num_heads (int): The number of attention heads.
embed_dims (int): The number of embedding dimension.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
drop_path_rate (float): stochastic depth rate. Default 0.1.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
init_std (float): The value of std in weight initialization.
Default: 0.02.
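
    Example:
        >>> # Minimal illustrative usage; the channel, class and layer
        >>> # numbers below are hypothetical, not required by the module.
        >>> import torch
        >>> head = SegmenterMaskTransformerHead(
        ...     in_channels=384,
        ...     channels=384,
        ...     num_classes=150,
        ...     num_layers=2,
        ...     num_heads=6,
        ...     embed_dims=384)
        >>> inputs = [torch.randn(1, 384, 32, 32)]
        >>> masks = head(inputs)
        >>> masks.shape
        torch.Size([1, 150, 32, 32])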
"""
def __init__(
self,
in_channels,
num_layers,
num_heads,
embed_dims,
mlp_ratio=4,
drop_path_rate=0.1,
drop_rate=0.0,
attn_drop_rate=0.0,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_std=0.02,
**kwargs,
):
super(SegmenterMaskTransformerHead, self).__init__(
in_channels=in_channels, **kwargs)
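
        # Stochastic depth: the drop-path rate grows linearly from 0 to
        # drop_path_rate across the decoder layers.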
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio * embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
batch_first=True,
))
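
        # Project backbone features to the decoder width, and create one
        # learnable class embedding (token) per output class, following the
        # mask transformer decoder of the Segmenter paper.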
self.dec_proj = nn.Linear(in_channels, embed_dims)
self.cls_emb = nn.Parameter(
torch.randn(1, self.num_classes, embed_dims))
self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)
self.decoder_norm = build_norm_layer(
norm_cfg, embed_dims, postfix=1)[1]
self.mask_norm = build_norm_layer(
norm_cfg, self.num_classes, postfix=2)[1]
self.init_std = init_std
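
        # The default conv_seg classifier from BaseDecodeHead is unused:
        # masks are produced from the class tokens instead.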
        delattr(self, 'conv_seg')

    def init_weights(self):
trunc_normal_(self.cls_emb, std=self.init_std)
trunc_normal_init(self.patch_proj, std=self.init_std)
trunc_normal_init(self.classes_proj, std=self.init_std)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=self.init_std, bias=0)
elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.0)

    def forward(self, inputs):
x = self._transform_inputs(inputs)
b, c, h, w = x.shape
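        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C) token sequence.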
x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)
x = self.dec_proj(x)
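
        # Append one learnable class token per class to the patch tokens.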
cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
x = torch.cat((x, cls_emb), 1)
for layer in self.layers:
x = layer(x)
x = self.decoder_norm(x)
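
        # Split the sequence back into patch tokens and class tokens, and
        # project each with its own linear layer.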
patches = self.patch_proj(x[:, :-self.num_classes])
cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])
patches = F.normalize(patches, dim=2, p=2)
cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)
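
        # Class masks are the dot-product similarity between the L2-normalized
        # patch tokens and class tokens.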
masks = patches @ cls_seg_feat.transpose(1, 2)
masks = self.mask_norm(masks)
masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)
return masks