mmsegmentation/mmseg/models/decode_heads/segmenter_mask_head.py
rstrudel cb1bf9f372
[Feature] Support Segmenter (#955)
* segmenter: add model

* update

* readme: update

* config: update

* segmenter: update readme

* segmenter: update

* segmenter: update

* segmenter: update

* configs: set checkpoint path to pretrain folder

* segmenter: modify vit-s/lin, remove data config

* readme: update

* configs: transfer from _base_ to segmenter

* configs: add 8x1 suffix

* configs: remove redundant lines

* configs: cleanup

* first attempt

* swipe CI error

* Update mmseg/models/decode_heads/__init__.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>

* segmenter_linear: use fcn backbone

* segmenter_mask: update

* models: add segmenter vit

* decoders: yapf+remove unused imports

* apply precommit

* segmenter/linear_head: fix

* segmenter/linear_header: fix

* segmenter: fix mask transformer

* fix error

* segmenter/mask_head: use trunc_normal init

* refactor segmenter head

* Fetch upstream (#1)

* [Feature] Change options to cfg-option (#1129)

* [Feature] Change option to cfg-option

* add expire date and fix the docs

* modify docstring

* [Fix] Add <!-- [ABSTRACT] --> in metafile #1127

* [Fix] Fix correct num_classes of HRNet in LoveDA dataset #1136

* Bump to v0.20.1 (#1138)

* bump version 0.20.1

* bump version 0.20.1

* [Fix] revise --option to --options #1140

Co-authored-by: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>

* decode_head: switch from linear to fcn

* fix init list formatting

* configs: remove variants, keep only vit-s on ade

* align inference metric of vit-s-mask

* configs: add vit t/b/l

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* model_converters: use torch instead of einops

* setup: remove einops

* segmenter_mask: fix missing imports

* add necessary imported init function

* segmenter/seg-l: set resolution to 640

* segmenter/seg-l: fix test size

* fix vitjax2mmseg

* add README and unittest

* fix unittest

* add docstring

* refactor config and add pretrained link

* fix typo

* add paper name in readme

* change segmenter config names

* fix typo in readme

* fix typos in readme

* fix segmenter typo

* fix segmenter typo

* delete redundant comma in config files

* delete redundant comma in config files

* fix convert script

* update latest master version

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
Co-authored-by: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
2022-01-26 13:50:51 +08:00

134 lines
4.8 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
trunc_normal_init)
from mmcv.runner import ModuleList
from mmseg.models.backbones.vit import TransformerEncoderLayer
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead):
    """Segmenter: Transformer for Semantic Segmentation.

    This head is the implementation of
    `Segmenter: <https://arxiv.org/abs/2105.05633>`_.

    Patch features are projected to ``embed_dims``, concatenated with one
    learnable embedding per class, passed through a stack of transformer
    encoder layers, and the class masks are obtained as the product of
    the L2-normalized patch features and L2-normalized class features.

    Args:
        in_channels (int): The number of channels of the input feature
            map (the last dimension fed to ``dec_proj``).
        num_layers (int): The depth of transformer.
        num_heads (int): The number of attention heads.
        embed_dims (int): The number of embedding dimension.
        mlp_ratio (int | float): ratio of mlp hidden dim to embedding dim.
            The product ``mlp_ratio * embed_dims`` is truncated to int.
            Default: 4.
        drop_path_rate (float): stochastic depth rate. The per-layer rate
            grows linearly from 0 to this value. Default 0.1.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer. Used for
            the transformer layers, the decoder norm and the mask norm.
            Default: dict(type='LN')
        init_std (float): The value of std in weight initialization.
            Default: 0.02.
        **kwargs: Remaining keyword arguments, forwarded unchanged to
            ``BaseDecodeHead.__init__``.
    """

    def __init__(
            self,
            in_channels,
            num_layers,
            num_heads,
            embed_dims,
            mlp_ratio=4,
            drop_path_rate=0.1,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            num_fcs=2,
            qkv_bias=True,
            act_cfg=dict(type='GELU'),
            norm_cfg=dict(type='LN'),
            init_std=0.02,
            **kwargs,
    ):
        super(SegmenterMaskTransformerHead, self).__init__(
            in_channels=in_channels, **kwargs)

        # Stochastic depth: one drop-path rate per layer, linearly spaced
        # between 0 and ``drop_path_rate``.
        dpr = [
            rate.item()
            for rate in torch.linspace(0, drop_path_rate, num_layers)
        ]
        # Cast so that a float ``mlp_ratio`` (e.g. 4.0) still yields an
        # integer layer width; identical to the plain product for ints.
        feedforward_channels = int(mlp_ratio * embed_dims)

        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    feedforward_channels=feedforward_channels,
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=num_fcs,
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    batch_first=True,
                ))

        # Projects backbone features to the decoder width.
        self.dec_proj = nn.Linear(in_channels, embed_dims)

        # One learnable token per class, appended after the patch tokens.
        self.cls_emb = nn.Parameter(
            torch.randn(1, self.num_classes, embed_dims))
        self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
        self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)

        self.decoder_norm = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)[1]
        self.mask_norm = build_norm_layer(
            norm_cfg, self.num_classes, postfix=2)[1]

        self.init_std = init_std

        # ``forward`` never uses ``conv_seg``, so the attribute
        # (presumably created by the base class) is removed.
        delattr(self, 'conv_seg')

    def init_weights(self):
        """Initialize class embeddings, linear layers and layer norms.

        Class embeddings and every ``nn.Linear`` weight get a truncated
        normal init with std ``self.init_std`` (linear biases set to 0);
        every ``nn.LayerNorm`` gets weight 1 and bias 0.
        """
        trunc_normal_(self.cls_emb, std=self.init_std)
        # NOTE: these two projections are initialized again by the loop
        # below. The explicit calls are kept so the sequence of random
        # draws (and thus seeded initial weights) stays unchanged.
        trunc_normal_init(self.patch_proj, std=self.init_std)
        trunc_normal_init(self.classes_proj, std=self.init_std)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=self.init_std, bias=0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.0)

    def forward(self, inputs):
        """Compute per-class mask logits.

        Args:
            inputs: Backbone outputs, passed to ``_transform_inputs``;
                the result is unpacked as a 4-D ``(b, c, h, w)`` tensor.

        Returns:
            Tensor: Mask logits of shape ``(b, num_classes, h, w)``.
        """
        x = self._transform_inputs(inputs)
        b, c, h, w = x.shape

        # (b, c, h, w) -> (b, h * w, c): one token per spatial location.
        x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)
        x = self.dec_proj(x)

        # Append the class tokens after the patch tokens.
        cls_emb = self.cls_emb.expand(b, -1, -1)
        x = torch.cat((x, cls_emb), 1)

        for layer in self.layers:
            x = layer(x)
        x = self.decoder_norm(x)

        # The last ``num_classes`` tokens are the class tokens.
        patches = self.patch_proj(x[:, :-self.num_classes])
        cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])

        # Scalar product of unit vectors (cosine similarity):
        # (b, h * w, d) @ (b, d, num_classes) -> (b, h * w, num_classes).
        patches = F.normalize(patches, dim=2, p=2)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)
        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)

        # (b, h * w, num_classes) -> (b, num_classes, h, w).
        masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)
        return masks