# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_init
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.utils import to_2tuple

from mmseg.models import BACKBONES
from mmseg.models.utils.embed import AdaptivePadding


class SimplifiedPatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement SimplifiedPatchEmbed.

    Args:
        in_channels (int): The number of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (str): The type of conv layer used for embedding.
            Default: "Conv2d".
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int, optional): The slide stride of embedding conv.
            Default: None (would be set to `kernel_size`).
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int): The dilation rate of embedding conv. Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of input, which will be
            used to calculate the output size (`init_out_size`).
            Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=768,
                 conv_type='Conv2d',
                 kernel_size=16,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=True,
                 norm_cfg=None,
                 input_size=None,
                 init_cfg=None):
        super(SimplifiedPatchEmbed, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # `trunc_normal_init` expects the module itself (it initializes
            # `m.weight` internally); passing `m.weight` directly would be a
            # silent no-op.
            trunc_normal_init(m, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

                - x (Tensor): Has shape (B, embed_dims, out_h * out_w).
                - out_size (tuple[int]): Spatial shape of x, arranged as
                  (out_h, out_w).
        """
        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        if self.norm is not None:
            x = self.norm(x)
        x = x.flatten(2)
        return x, out_size
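

# Hedged usage sketch (not part of the original file): it illustrates the
# (B, C, H, W) -> (B, embed_dims, out_h * out_w) flattening performed by
# `SimplifiedPatchEmbed`. The concrete sizes and the `BN` norm (used instead
# of the distributed-only `SyncBN`) are assumptions made so the sketch can run
# on CPU; the function is illustrative only and is never called by the module.
def _example_simplified_patch_embed():
    patch_embed = SimplifiedPatchEmbed(
        in_channels=3,
        embed_dims=32,
        kernel_size=7,
        stride=4,
        padding=3,
        norm_cfg=dict(type='BN'))
    x = torch.randn(2, 3, 224, 224)
    out, out_size = patch_embed(x)
    # out.shape == (2, 32, 56 * 56) and out_size == (56, 56), since the
    # 7x7/stride-4 conv downsamples 224 -> 56 before flattening.
    return out, out_size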


class DWConv(nn.Module):
    """Depth-wise 3x3 convolution used in MixFFN.

    It operates on the flattened token layout (B, C, N) by temporarily
    reshaping back to (B, C, H, W).
    """

    def __init__(self, dims):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dims, dims, 3, 1, 1, bias=True, groups=dims)

    def forward(self, x, H, W):
        B, C, N = x.shape
        x = x.reshape(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2)
        return x
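

# Hedged shape check (not part of the original file): a minimal sketch showing
# that `DWConv` preserves the (B, C, N) layout as long as N == H * W. All
# concrete sizes are illustrative assumptions; the function is never called.
def _example_dwconv():
    dwconv = DWConv(dims=64)
    x = torch.randn(2, 64, 32 * 32)
    out = dwconv(x, H=32, W=32)
    # out.shape == (2, 64, 1024), identical to the input layout.
    return out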


class MixFFN(nn.Module):
    """An implementation of MixFFN of DEST.

    The differences between MixFFN & FFN:
        1. Use 1X1 Conv to replace Linear layer.
        2. Introduce 3X3 Conv to encode positional information.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default: 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).
        dropout_layer (obj:`ConfigDict`, optional): The dropout_layer used
            when adding the shortcut. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`, optional): The Config for
            initialization. Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='ReLU'),
                 ffn_drop=0.,
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 dropout_layer=None,
                 init_cfg=None):
        super(MixFFN, self).__init__()

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1)
        norm1 = build_norm_layer(norm_cfg, feedforward_channels)[1]
        self.dwconv = DWConv(feedforward_channels)
        norm2 = build_norm_layer(norm_cfg, feedforward_channels)[1]
        fc2 = nn.Conv1d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1)
        drop = nn.Dropout(ffn_drop)
        pre_layers = [fc1, norm1]
        post_layers = [norm2, activate, drop, fc2, drop]
        self.pre_layers = Sequential(*pre_layers)
        self.post_layers = Sequential(*post_layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv1d):
            trunc_normal_init(m, std=.02, bias=0.)

    def forward(self, x, hw_shape, identity):
        out = self.pre_layers(x)
        out = self.dwconv(out, hw_shape[0], hw_shape[1])
        out = self.post_layers(out)
        return identity + self.dropout_layer(out)
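

# Hedged usage sketch (not part of the original file): it shows that MixFFN
# consumes and returns tokens in the (B, C, N) layout, with `hw_shape` only
# needed by the inner depth-wise conv. `BN1d` replaces the default `SyncBN`
# purely so the sketch can run outside distributed training; all sizes are
# illustrative assumptions and the function is never called.
def _example_mix_ffn():
    ffn = MixFFN(
        embed_dims=32,
        feedforward_channels=128,
        norm_cfg=dict(type='BN1d'))
    x = torch.randn(2, 32, 56 * 56)
    out = ffn(x, hw_shape=(56, 56), identity=x)
    # out.shape == (2, 32, 3136): the residual keeps the input layout.
    return out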


class SimplifiedAttention(nn.Module):
    """An implementation of Simplified Multi-head Attention of DEST.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after the output projection.
            Default: 0.0.
        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
            Attention of Segformer. Default: 1.
        qkv_bias (bool): Enable bias for the query and key projections if
            True. Default: False.
        qk_scale (float, optional): Scale factor for query and key.
            Default: None (uses `head_dim ** -0.5`).
        dropout_layer (obj:`ConfigDict`, optional): The dropout_layer used
            when adding the shortcut. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1,
                 qkv_bias=False,
                 qk_scale=None,
                 dropout_layer=None,
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads

        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Conv1d(embed_dims, embed_dims, 1, bias=qkv_bias)
        self.k = nn.Conv1d(embed_dims, embed_dims, 1, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv1d(embed_dims, embed_dims, 1)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(
                embed_dims, embed_dims, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_init(m, std=.02, bias=0.)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Conv1d):
            trunc_normal_init(m, std=.02, bias=0.)

    def forward(self, x, hw_shape, identity):
        H, W = hw_shape
        B, C, N = x.shape
        q = self.q(x)
        q = q.reshape(B, self.num_heads, C // self.num_heads, N)
        q = q.permute(0, 1, 3, 2)

        if self.sr_ratio > 1:
            # Spatially reduce the keys before the 1x1 conv projection.
            x_ = x.reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1)
            x_ = self.norm1(x_)
            k = self.k(x_).reshape(B, self.num_heads, C // self.num_heads, -1)
        else:
            k = self.k(x).reshape(B, self.num_heads, C // self.num_heads, -1)

        # The value is simplified to the per-channel mean of the input tokens,
        # and the attention map is reduced by a max over the key dimension.
        v = torch.mean(x, 2, True).repeat(1, 1,
                                          self.num_heads).transpose(-2, -1)
        attn = (q @ k) * self.scale
        attn, _ = torch.max(attn, -1)
        out = (attn.transpose(-2, -1) @ v)
        out = out.transpose(-2, -1)
        out = self.proj(out)
        return identity + self.dropout_layer(out)
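

# Hedged usage sketch (not part of the original file): it exercises the
# simplified attention with a spatial-reduction ratio of 4. `BN1d` replaces
# the default `SyncBN` only so the sketch can run outside distributed
# training; the sizes are illustrative assumptions and the function is never
# called.
def _example_simplified_attention():
    attn = SimplifiedAttention(
        embed_dims=32,
        num_heads=2,
        sr_ratio=4,
        norm_cfg=dict(type='BN1d'))
    x = torch.randn(2, 32, 56 * 56)
    out = attn(x, hw_shape=(56, 56), identity=x)
    # out.shape == (2, 32, 3136): the residual keeps the (B, C, N) layout.
    return out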


class SimpliefiedTransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in DEST.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
        qkv_bias (bool): Enable bias for qkv if True.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN').
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim) or (n, batch, embed_dim). Default: True.
        qk_scale (float, optional): Scale factor for query and key.
            Default: None.
        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
            Attention of Segformer. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='SyncBN'),
                 batch_first=True,
                 qk_scale=None,
                 sr_ratio=1,
                 with_cp=False):
        super(SimpliefiedTransformerEncoderLayer, self).__init__()

        # The ret[0] of build_norm_layer is norm name.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = SimplifiedAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            sr_ratio=sr_ratio,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_init(m, std=.02, bias=0.)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, hw_shape):
        x = self.attn(self.norm1(x), hw_shape, identity=x)
        x = self.ffn(self.norm2(x), hw_shape, identity=x)
        return x
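

# Hedged construction sketch (not part of the original file): it shows how one
# encoder layer is parameterized, mirroring the arguments the backbone below
# passes per stage. Only construction is demonstrated, because the inner
# attention and MixFFN default to `SyncBN`, which needs a distributed GPU
# setup for a forward pass; the concrete numbers are illustrative assumptions
# and the function is never called.
def _example_encoder_layer():
    layer = SimpliefiedTransformerEncoderLayer(
        embed_dims=32,
        num_heads=1,
        feedforward_channels=8 * 32,
        drop_path_rate=0.1,
        qkv_bias=True,
        sr_ratio=8)
    # In a forward pass, `layer(x, hw_shape)` keeps the (B, C, N) layout,
    # with N == hw_shape[0] * hw_shape[1].
    return layer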


@BACKBONES.register_module()
class SimplifiedMixTransformer(BaseModule):
    """The backbone of DEST.

    This backbone is a simplified variant of the Mix Transformer introduced
    in `SegFormer: Simple and Efficient Design for Semantic Segmentation
    with Transformers <https://arxiv.org/abs/2105.15203>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (Sequence[int]): Embedding dimensions of each transformer
            encode layer. Default: [32, 64, 160, 256].
        num_stages (int): The num of stages. Default: 4.
        num_layers (Sequence[int]): The layer number of each transformer
            encode layer. Default: [2, 2, 2, 2].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 4, 8].
        patch_sizes (Sequence[int]): The patch_size of each overlapped patch
            embedding. Default: [7, 3, 3, 3].
        strides (Sequence[int]): The stride of each overlapped patch
            embedding. Default: [4, 2, 2, 2].
        sr_ratios (Sequence[int]): The spatial reduction rate of each
            transformer encode layer. Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Output from which stages.
            Default: (0, 1, 2, 3).
        mlp_ratios (Sequence[int]): Ratios of mlp hidden dim to embedding dim.
            Default: [8, 8, 4, 4].
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[32, 64, 160, 256],
                 num_stages=4,
                 num_layers=[2, 2, 2, 2],
                 num_heads=[1, 2, 4, 8],
                 patch_sizes=[7, 3, 3, 3],
                 strides=[4, 2, 2, 2],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratios=[8, 8, 4, 4],
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 pretrained=None,
                 init_cfg=None,
                 with_cp=False):
        super(SimplifiedMixTransformer, self).__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.embed_dims = embed_dims
        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        self.with_cp = with_cp
        assert num_stages == len(num_layers) == len(num_heads) \
            == len(patch_sizes) == len(strides) == len(sr_ratios)

        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages

        # transformer encoder
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]  # stochastic depth decay rule

        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            patch_embed = SimplifiedPatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims[i],
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=patch_sizes[i] // 2,
                norm_cfg=norm_cfg)
            layer = ModuleList([
                SimpliefiedTransformerEncoderLayer(
                    embed_dims=embed_dims[i],
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratios[i] * embed_dims[i],
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    with_cp=with_cp,
                    sr_ratio=sr_ratios[i]) for idx in range(num_layer)
            ])
            in_channels = embed_dims[i]
            # The ret[0] of build_norm_layer is norm name.
            norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
            self.layers.append(ModuleList([patch_embed, layer, norm]))
            cur += num_layer

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_init(m, std=.02, bias=0.)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        outs = []
        for i, layer in enumerate(self.layers):
            x, (H, W) = layer[0](x)
            for block in layer[1]:
                x = block(x, (H, W))
            x = layer[2](x)
            N, C, L = x.shape
            x = x.reshape(N, C, H, W)
            # Only collect the stages requested via `out_indices`.
            if i in self.out_indices:
                outs.append(x)
        return outs
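

# Hedged construction sketch (not part of the original file): it builds the
# registered backbone with its default configuration and documents the
# expected multi-scale outputs. Construction only, because the default
# `SyncBN` layers need a distributed GPU setup for a forward pass; the input
# resolution mentioned in the comment is an illustrative assumption and the
# function is never called.
def _example_simplified_mix_transformer():
    backbone = SimplifiedMixTransformer()
    # For a (B, 3, 512, 512) input, `backbone(x)` returns four feature maps
    # with strides 4, 8, 16 and 32, i.e. shapes (B, 32, 128, 128),
    # (B, 64, 64, 64), (B, 160, 32, 32) and (B, 256, 16, 16), one per entry
    # in `out_indices`.
    return backbone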