mmsegmentation/mmseg/models/decode_heads/segmenter_mask_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)
from mmcv.runner import ModuleList

from mmseg.models.backbones.vit import TransformerEncoderLayer

from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead):
"""Segmenter: Transformer for Semantic Segmentation.
This head is the implementation of
`Segmenter: <https://arxiv.org/abs/2105.05633>`_.
Args:
backbone_cfg:(dict): Config of backbone of
Context Path.
in_channels (int): The number of channels of input image.
num_layers (int): The depth of transformer.
num_heads (int): The number of attention heads.
embed_dims (int): The number of embedding dimension.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
drop_path_rate (float): stochastic depth rate. Default 0.1.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
init_std (float): The value of std in weight initialization.
Default: 0.02.
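
    Example:
        A minimal forward-pass sketch; the channel and class numbers below
        are illustrative, not taken from any released config:

        >>> import torch
        >>> head = SegmenterMaskTransformerHead(
        ...     in_channels=768,
        ...     channels=768,
        ...     num_classes=150,
        ...     num_layers=2,
        ...     num_heads=12,
        ...     embed_dims=768)
        >>> feat = torch.randn(1, 768, 32, 32)
        >>> head([feat]).shape
        torch.Size([1, 150, 32, 32])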
"""

    def __init__(
self,
in_channels,
num_layers,
num_heads,
embed_dims,
mlp_ratio=4,
drop_path_rate=0.1,
drop_rate=0.0,
attn_drop_rate=0.0,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_std=0.02,
**kwargs,
):
super(SegmenterMaskTransformerHead, self).__init__(
in_channels=in_channels, **kwargs)
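        # Per-layer stochastic depth (drop path) rates, increasing linearly
        # from 0 to `drop_path_rate` across the decoder layers.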
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
self.layers = ModuleList()
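        # Stack of plain transformer encoder layers, reused from the ViT
        # backbone implementation.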
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio * embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
batch_first=True,
))
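        # Linear projection of backbone features into the decoder embedding
        # space, plus one learnable class embedding per class.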
self.dec_proj = nn.Linear(in_channels, embed_dims)
self.cls_emb = nn.Parameter(
torch.randn(1, self.num_classes, embed_dims))
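        # Bias-free projections applied separately to patch tokens and class
        # tokens, and LayerNorms for the decoder output and the mask logits.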
self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)
self.decoder_norm = build_norm_layer(
norm_cfg, embed_dims, postfix=1)[1]
self.mask_norm = build_norm_layer(
norm_cfg, self.num_classes, postfix=2)[1]
self.init_std = init_std
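        # The masks are produced by the patch/class token dot product below,
        # so the default 1x1 `conv_seg` classifier from BaseDecodeHead is
        # unused and removed.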
delattr(self, 'conv_seg')

    def init_weights(self):
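        # Truncated-normal init for the class embeddings and all linear
        # projections; LayerNorms are reset to weight 1 and bias 0.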
trunc_normal_(self.cls_emb, std=self.init_std)
trunc_normal_init(self.patch_proj, std=self.init_std)
trunc_normal_init(self.classes_proj, std=self.init_std)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=self.init_std, bias=0)
elif isinstance(m, nn.LayerNorm):
constant_init(m, val=1.0, bias=0.0)

    def forward(self, inputs):
x = self._transform_inputs(inputs)
b, c, h, w = x.shape
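        # Flatten spatial dimensions into a token sequence:
        # (B, C, H, W) -> (B, H*W, C).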
x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)
x = self.dec_proj(x)
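        # Append the learnable class embeddings so that patch tokens and
        # class tokens attend to each other in the decoder layers.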
cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
x = torch.cat((x, cls_emb), 1)
for layer in self.layers:
x = layer(x)
x = self.decoder_norm(x)
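        # Split the sequence back into patch tokens and class tokens and
        # project each with its own linear layer.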
patches = self.patch_proj(x[:, :-self.num_classes])
cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])
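        # Cosine similarity between L2-normalized patch and class tokens
        # yields one mask logit per class for every patch.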
patches = F.normalize(patches, dim=2, p=2)
cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)
masks = patches @ cls_seg_feat.transpose(1, 2)
masks = self.mask_norm(masks)
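        # Reshape (B, H*W, num_classes) -> (B, num_classes, H, W).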
masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)
return masks