# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                        trunc_normal_init)
from mmcv.runner import BaseModule, ModuleList
from torch.nn.modules.batchnorm import _BatchNorm

from mmseg.models.backbones.mit import EfficientMultiheadAttention
from mmseg.models.builder import BACKBONES
from ..utils.embed import PatchEmbed


class GlobalSubsampledAttention(EfficientMultiheadAttention):
    """Global Sub-sampled Attention (Spatial Reduction Attention).

    This module is modified from EfficientMultiheadAttention, which is
    defined in mmseg.models.backbones.mit. Functionally,
    `GlobalSubsampledAttention` is identical to
    `EfficientMultiheadAttention`; it is kept as a separate class only
    because the operation is named Global Sub-sampled Attention (GSA)
    in the Twins paper.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dims) or (n, batch, embed_dims).
            Default: True.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction in GSA of PCPVT.
            Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 batch_first=True,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 init_cfg=None):
        super(GlobalSubsampledAttention, self).__init__(
            embed_dims,
            num_heads,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            dropout_layer=dropout_layer,
            batch_first=batch_first,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio,
            init_cfg=init_cfg)
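
# Minimal usage sketch of GlobalSubsampledAttention (illustrative only; the
# tensor sizes and hyper-parameters below are assumptions, not values taken
# from any config):
#
#     gsa = GlobalSubsampledAttention(embed_dims=64, num_heads=1, sr_ratio=8)
#     x = torch.randn(2, 56 * 56, 64)   # (batch, num_tokens, embed_dims)
#     out = gsa(x, hw_shape=(56, 56))   # keys/values are sub-sampled by
#                                       # sr_ratio before attention; the
#                                       # output keeps the input shape.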


class GSAEncoderLayer(BaseModule):
    """Implements one encoder layer with GSA.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The dropout rate for the attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (float): The ratio of spatial reduction in the GSA module,
            used as the kernel size and stride of the reduction conv.
            Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1.,
                 init_cfg=None):
        super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = GlobalSubsampledAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x
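
# Minimal sketch of a single GSA encoder layer (illustrative; the sizes are
# assumptions chosen only to show the pre-norm residual structure):
#
#     layer = GSAEncoderLayer(
#         embed_dims=64, num_heads=1, feedforward_channels=256,
#         drop_path_rate=0.1, sr_ratio=8)
#     x = torch.randn(2, 56 * 56, 64)
#     out = layer(x, hw_shape=(56, 56))  # same shape: attention and FFN are
#                                        # each added back through DropPath.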


class LocallyGroupedSelfAttention(BaseModule):
    """Locally-grouped Self Attention (LSA) module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0.
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        window_size (int): Window size of LSA. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 window_size=1,
                 init_cfg=None):
        super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg)

        assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \
                                            f'divided by num_heads ' \
                                            f'{num_heads}.'
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)
        self.window_size = window_size

    def forward(self, x, hw_shape):
        b, n, c = x.shape
        h, w = hw_shape
        x = x.view(b, h, w, c)

        # pad the feature map so that both H and W are multiples of the
        # window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - w % self.window_size) % self.window_size
        pad_b = (self.window_size - h % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))

        # calculate the attention mask for LSA. Positions in the padded
        # border are marked with 1 so that attention between real and padded
        # tokens is suppressed. The guards avoid `-0:` slices, which would
        # otherwise mark the whole map when only one side is padded.
        Hp, Wp = x.shape[1:-1]
        _h, _w = Hp // self.window_size, Wp // self.window_size
        mask = torch.zeros((1, Hp, Wp), device=x.device)
        if pad_b > 0:
            mask[:, -pad_b:, :].fill_(1)
        if pad_r > 0:
            mask[:, :, -pad_r:].fill_(1)

        # [B, _h, _w, window_size, window_size, C]
        x = x.reshape(b, _h, self.window_size, _w, self.window_size,
                      c).transpose(2, 3)
        mask = mask.reshape(1, _h, self.window_size, _w,
                            self.window_size).transpose(2, 3).reshape(
                                1, _h * _w,
                                self.window_size * self.window_size)
        # [1, _h*_w, window_size*window_size, window_size*window_size]
        attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-1000.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        # [3, B, _h*_w, nhead, window_size*window_size, head_dim]
        qkv = self.qkv(x).reshape(b, _h * _w,
                                  self.window_size * self.window_size, 3,
                                  self.num_heads, c // self.num_heads).permute(
                                      3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # [B, _h*_w, n_head, window_size*window_size, window_size*window_size]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn + attn_mask.unsqueeze(2)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size,
                                                  self.window_size, c)
        x = attn.transpose(2, 3).reshape(b, _h * self.window_size,
                                         _w * self.window_size, c)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :h, :w, :].contiguous()

        x = x.reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
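
# Minimal sketch of LocallyGroupedSelfAttention (illustrative; shapes are
# assumptions): tokens are split into non-overlapping window_size x
# window_size groups and attention is computed inside each group only.
#
#     lsa = LocallyGroupedSelfAttention(embed_dims=64, num_heads=2,
#                                       window_size=7)
#     x = torch.randn(2, 56 * 56, 64)
#     out = lsa(x, hw_shape=(56, 56))  # (2, 3136, 64); 56 is already a
#                                      # multiple of 7, so no padding is
#                                      # needed here.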


class LSAEncoderLayer(BaseModule):
    """Implements one encoder layer with LSA in Twins-SVT.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        window_size (int): Window size of LSA. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 window_size=1,
                 init_cfg=None):

        super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
                                                qkv_bias, qk_scale,
                                                attn_drop_rate, drop_rate,
                                                window_size)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x
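
# Minimal sketch of an LSA encoder layer (illustrative; sizes are
# assumptions). It mirrors GSAEncoderLayer but restricts attention to local
# windows:
#
#     layer = LSAEncoderLayer(embed_dims=64, num_heads=2,
#                             feedforward_channels=256, window_size=7)
#     x = torch.randn(2, 56 * 56, 64)
#     out = layer(x, hw_shape=(56, 56))  # shape is preserved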


class ConditionalPositionEncoding(BaseModule):
    """The Conditional Position Encoding (CPE) module.

    The CPE is the implementation of `Conditional Positional Encodings
    for Vision Transformers <https://arxiv.org/abs/2102.10882>`_.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg)
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)
        self.stride = stride

    def forward(self, x, hw_shape):
        b, n, c = x.shape
        h, w = hw_shape
        feat_token = x
        # reshape the token sequence back to a 2D feature map so that the
        # depth-wise conv can encode positions from local neighborhoods
        cnn_feat = feat_token.transpose(1, 2).view(b, c, h, w)
        if self.stride == 1:
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        return x
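
# Minimal sketch of ConditionalPositionEncoding (illustrative; sizes are
# assumptions). The position encoding comes from a 3x3 depth-wise conv over
# the reshaped feature map, so it adapts to the input resolution:
#
#     cpe = ConditionalPositionEncoding(in_channels=64, embed_dims=64)
#     x = torch.randn(2, 56 * 56, 64)
#     out = cpe(x, hw_shape=(56, 56))  # (2, 3136, 64), tokens plus the
#                                      # conv-derived positional information.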


@BACKBONES.register_module()
class PCPVT(BaseModule):
    """The backbone of Twins-PCPVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512].
        patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
        strides (list): The strides. Default: [4, 2, 2, 2].
        num_heads (list): Number of attention heads. Default: [1, 2, 4, 8].
        mlp_ratios (list): Ratio of mlp hidden dim to embedding dim.
            Default: [4, 4, 4, 4].
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.
        attn_drop_rate (float): The dropout rate for the attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        depths (list): Depths of each stage. Default: [3, 4, 6, 3].
        sr_ratios (list): Kernel_size of conv in each Attn module in
            Transformer encoder layer. Default: [8, 4, 2, 1].
        norm_after_stage (bool): Add an extra norm after each stage.
            Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 depths=[3, 4, 6, 3],
                 sr_ratios=[8, 4, 2, 1],
                 norm_after_stage=False,
                 pretrained=None,
                 init_cfg=None):
        super(PCPVT, self).__init__(init_cfg=init_cfg)
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')
        self.depths = depths

        # patch_embed
        self.patch_embeds = ModuleList()
        self.position_encoding_drops = ModuleList()
        self.layers = ModuleList()

        for i in range(len(depths)):
            self.patch_embeds.append(
                PatchEmbed(
                    in_channels=in_channels if i == 0 else embed_dims[i - 1],
                    embed_dims=embed_dims[i],
                    conv_type='Conv2d',
                    kernel_size=patch_sizes[i],
                    stride=strides[i],
                    padding='corner',
                    norm_cfg=norm_cfg))

            self.position_encoding_drops.append(nn.Dropout(p=drop_rate))

        self.position_encodings = ModuleList([
            ConditionalPositionEncoding(embed_dim, embed_dim)
            for embed_dim in embed_dims
        ])

        # transformer encoder
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0

        for k in range(len(depths)):
            _block = ModuleList([
                GSAEncoderLayer(
                    embed_dims=embed_dims[k],
                    num_heads=num_heads[k],
                    feedforward_channels=mlp_ratios[k] * embed_dims[k],
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[cur + i],
                    num_fcs=2,
                    qkv_bias=qkv_bias,
                    act_cfg=dict(type='GELU'),
                    norm_cfg=dict(type='LN'),
                    sr_ratio=sr_ratios[k]) for i in range(depths[k])
            ])
            self.layers.append(_block)
            cur += depths[k]

        self.norm_name, norm = build_norm_layer(
            norm_cfg, embed_dims[-1], postfix=1)

        self.out_indices = out_indices
        self.norm_after_stage = norm_after_stage
        if self.norm_after_stage:
            self.norm_list = ModuleList()
            for dim in embed_dims:
                self.norm_list.append(build_norm_layer(norm_cfg, dim)[1])

    def init_weights(self):
        if self.init_cfg is not None:
            super(PCPVT, self).init_weights()
        else:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)

    def forward(self, x):
        outputs = list()

        b = x.shape[0]

        for i in range(len(self.depths)):
            x, hw_shape = self.patch_embeds[i](x)
            h, w = hw_shape
            x = self.position_encoding_drops[i](x)
            for j, blk in enumerate(self.layers[i]):
                x = blk(x, hw_shape)
                if j == 0:
                    # CPE is applied after the first block of each stage
                    x = self.position_encodings[i](x, hw_shape)
            if self.norm_after_stage:
                x = self.norm_list[i](x)
            x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()

            if i in self.out_indices:
                outputs.append(x)

        return tuple(outputs)
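
# Minimal usage sketch of the PCPVT backbone (illustrative; the input size is
# an assumption). Each stage downsamples by its patch-embedding stride, so a
# 512x512 input yields feature maps of strides 4, 8, 16 and 32:
#
#     model = PCPVT()
#     img = torch.randn(1, 3, 512, 512)
#     feats = model(img)
#     # (torch.Size([1, 64, 128, 128]), torch.Size([1, 128, 64, 64]),
#     #  torch.Size([1, 256, 32, 32]), torch.Size([1, 512, 16, 16]))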


@BACKBONES.register_module()
class SVT(PCPVT):
    """The backbone of Twins-SVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (list): Embedding dimension. Default: [64, 128, 256].
        patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
        strides (list): The strides. Default: [4, 2, 2, 2].
        num_heads (list): Number of attention heads. Default: [1, 2, 4].
        mlp_ratios (list): Ratio of mlp hidden dim to embedding dim.
            Default: [4, 4, 4].
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Dropout rate. Default: 0.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        depths (list): Depths of each stage. Default: [4, 4, 4].
        sr_ratios (list): Kernel_size of conv in each Attn module in
            Transformer encoder layer. Default: [4, 2, 1].
        window_sizes (list): Window size of LSA in each stage.
            Default: [7, 7, 7].
        norm_after_stage (bool): Add an extra norm after each stage.
            Default: True.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 num_heads=[1, 2, 4],
                 mlp_ratios=[4, 4, 4],
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_cfg=dict(type='LN'),
                 depths=[4, 4, 4],
                 sr_ratios=[4, 2, 1],
                 window_sizes=[7, 7, 7],
                 norm_after_stage=True,
                 pretrained=None,
                 init_cfg=None):
        super(SVT, self).__init__(in_channels, embed_dims, patch_sizes,
                                  strides, num_heads, mlp_ratios, out_indices,
                                  qkv_bias, drop_rate, attn_drop_rate,
                                  drop_path_rate, norm_cfg, depths, sr_ratios,
                                  norm_after_stage, pretrained, init_cfg)
        # transformer encoder
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # replace every even-indexed GSA layer built by PCPVT with an LSA
        # layer, so each stage alternates LSA and GSA blocks
        for k in range(len(depths)):
            for i in range(depths[k]):
                if i % 2 == 0:
                    self.layers[k][i] = \
                        LSAEncoderLayer(
                            embed_dims=embed_dims[k],
                            num_heads=num_heads[k],
                            feedforward_channels=mlp_ratios[k] * embed_dims[k],
                            drop_rate=drop_rate,
                            attn_drop_rate=attn_drop_rate,
                            drop_path_rate=dpr[sum(depths[:k]) + i],
                            qkv_bias=qkv_bias,
                            window_size=window_sizes[k])
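
# Minimal usage sketch of the SVT backbone (illustrative; the input size is
# an assumption). With the three default stages, a 224x224 input produces
# feature maps of strides 4, 8 and 16:
#
#     model = SVT()
#     img = torch.randn(1, 3, 224, 224)
#     feats = model(img)
#     # (torch.Size([1, 64, 56, 56]), torch.Size([1, 128, 28, 28]),
#     #  torch.Size([1, 256, 14, 14]))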