mmpretrain/mmcls/models/backbones/van.py

435 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmcv.runner import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
class MixFFN(BaseModule):
    """Convolutional feed-forward network used by VAN blocks.

    Modelled on the MixFFN in mmdetection/mmdet/models/backbones/pvt.py.
    Compared with a plain FFN it

    1. uses 1x1 convolutions in place of the Linear layers, and
    2. inserts a 3x3 depth-wise convolution between them to encode
       positional information.

    Args:
        embed_dims (int): Number of input and output channels.
        feedforward_channels (int): Number of hidden channels.
        act_cfg (dict, optional): Config of the activation layer.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Dropout probability. The same dropout
            layer is applied after the activation and after the output
            projection. Default: 0.0.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg

        hidden = feedforward_channels
        # Point-wise expansion: embed_dims -> hidden channels.
        self.fc1 = Conv2d(embed_dims, hidden, kernel_size=1)
        # Depth-wise 3x3 convolution (one group per hidden channel);
        # stride 1 with padding 1 keeps the spatial size.
        self.dwconv = Conv2d(
            hidden,
            hidden,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=hidden)
        self.act = build_activation_layer(act_cfg)
        # Point-wise projection back: hidden -> embed_dims channels.
        self.fc2 = Conv2d(hidden, embed_dims, kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

    def forward(self, x):
        """Run expand -> depth-wise conv -> activation -> project.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Output of the second 1x1 convolution after dropout.
        """
        hidden = self.act(self.dwconv(self.fc1(x)))
        hidden = self.drop(hidden)
        return self.drop(self.fc2(hidden))
class LKA(BaseModule):
    """Large Kernel Attention (LKA) of VAN.

    An attention map is computed from the input by the chain of
    convolutions below and is then multiplied element-wise with that same
    input.

    .. code:: text

            DW_conv (depth-wise convolution)
                        |
                        |
        DW_D_conv (depth-wise dilation convolution)
                        |
                        |
        Transition Convolution (1x1 convolution)

    Args:
        embed_dims (int): Number of input channels.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, embed_dims, init_cfg=None):
        super(LKA, self).__init__(init_cfg=init_cfg)

        # a spatial local convolution (depth-wise convolution);
        # kernel 5 with padding 2 keeps the spatial size
        self.DW_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=5,
            padding=2,
            groups=embed_dims)

        # a spatial long-range convolution (depth-wise dilation convolution);
        # kernel 7 with dilation 3 spans 19 pixels, so padding 9 keeps the
        # spatial size
        self.DW_D_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=7,
            stride=1,
            padding=9,
            groups=embed_dims,
            dilation=3)

        # channel-mixing 1x1 convolution
        self.conv1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        """Gate the input with its own large-kernel attention map.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ``x`` multiplied element-wise by the attention map.
        """
        attn = self.DW_conv(x)
        attn = self.DW_D_conv(attn)
        attn = self.conv1(attn)

        # Every convolution above returns a new tensor and leaves ``x``
        # untouched, so ``x`` can be used for the gating directly; cloning
        # it first would only add a full-tensor copy per forward pass.
        return x * attn
class SpatialAttention(BaseModule):
    """Basic attention module in VANBlock.

    Applies ``1x1 conv -> activation -> LKA -> 1x1 conv`` and adds the
    module input back as a residual.

    Args:
        embed_dims (int): Number of input channels.
        act_cfg (dict, optional): The activation config used between the
            first projection and the LKA unit.
            Default: dict(type='GELU').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
        super(SpatialAttention, self).__init__(init_cfg=init_cfg)

        self.proj_1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.activation = build_activation_layer(act_cfg)
        self.spatial_gating_unit = LKA(embed_dims)
        self.proj_2 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        """Compute the attention branch and add the residual.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Attention output plus ``x``.
        """
        # ``proj_1`` returns a new tensor, so nothing below writes to ``x``
        # and it can serve as the residual without being cloned first.
        shortcut = x

        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        return x + shortcut
class VANBlock(BaseModule):
    """A block of VAN.

    The block consists of two residual branches, each optionally scaled by
    a learnable per-channel factor (layer scale) and passed through
    stochastic depth: ``norm -> SpatialAttention`` followed by
    ``norm -> MixFFN``.

    Args:
        embed_dims (int): Number of input channels.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_rate (float): Dropout rate used inside the MixFFN.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        act_cfg (dict, optional): The activation config shared by the
            attention module and the MixFFN.
            Default: dict(type='GELU').
        norm_cfg (dict): Config of the two normalization layers.
            Default: dict(type='BN', eps=1e-5).
        layer_scale_init_value (float): Init value for Layer Scale. Layer
            scale is disabled when the value is not positive.
            Defaults to 1e-2.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 ffn_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN', eps=1e-5),
                 layer_scale_init_value=1e-2,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.out_channels = embed_dims

        # Attention branch.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg)

        # One stochastic-depth module is shared by both branches.
        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

        # Feed-forward branch.
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.mlp = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=int(embed_dims * ffn_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop_rate)

        # Learnable per-channel scales, one vector per branch.
        if layer_scale_init_value > 0:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones(embed_dims),
                requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones(embed_dims),
                requires_grad=True)
        else:
            self.layer_scale_1 = None
            self.layer_scale_2 = None

    def forward(self, x):
        """Apply the attention branch, then the feed-forward branch.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Output feature map after both residual branches.
        """
        branch = self.attn(self.norm1(x))
        if self.layer_scale_1 is not None:
            # (C,) -> (C, 1, 1) so the scale broadcasts over H and W.
            branch = self.layer_scale_1[:, None, None] * branch
        x = x + self.drop_path(branch)

        branch = self.mlp(self.norm2(x))
        if self.layer_scale_2 is not None:
            branch = self.layer_scale_2[:, None, None] * branch
        return x + self.drop_path(branch)
class VANPatchEmbed(PatchEmbed):
    """Image to Patch Embedding of VAN.

    Differs from the parent ``PatchEmbed`` in two ways:

    1. The default normalization is BN.
    2. The projected feature map is returned as is, without 'flatten' and
       'transpose'.
    """

    def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
        super().__init__(*args, norm_cfg=norm_cfg, **kwargs)

    def forward(self, x):
        """Embed the input and report the resulting spatial size.

        Args:
            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.

        Returns:
            tuple: Contains the embedded feature map and its spatial shape.

            - x (Tensor): The un-flattened feature map, i.e. laid out as
              (B, embed_dims, out_h, out_w).
            - out_size (tuple[int]): Spatial shape of x, arrange as
              (out_h, out_w).
        """
        if self.adaptive_padding:
            x = self.adaptive_padding(x)

        x = self.projection(x)
        # Read the size right after the projection, before normalization.
        out_h, out_w = x.shape[2], x.shape[3]
        out_size = (out_h, out_w)

        if self.norm is None:
            return x, out_size
        return self.norm(x), out_size
@BACKBONES.register_module()
class VAN(BaseBackbone):
    """Visual Attention Network.

    A PyTorch implement of : `Visual Attention Network
    <https://arxiv.org/pdf/2202.09741v2.pdf>`_

    Inspiration from
    https://github.com/Visual-Attention-Network/VAN-Classification

    Args:
        arch (str | dict): Visual Attention Network architecture.
            If use string, choose from 'tiny', 'small', 'base' and 'large'
            (or their one-letter aliases 't', 's', 'b', 'l').
            If use dict, it should have exactly the keys below:

            - **embed_dims** (List[int]): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **ffn_ratios** (List[int]): The number of expansion ratio of
              feedforward network hidden layer channels.

            Defaults to 'tiny'.
        patch_sizes (List[int | tuple]): The patch size in patch embeddings.
            Defaults to [7, 3, 3, 3].
        in_channels (int): The num of input channels. Defaults to 3.
        drop_rate (float): Dropout rate passed to every block.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. The per-block rate
            increases linearly from 0 to this value. Defaults to 0.
        out_indices (Sequence[int]): Output from which stages.
            Default: ``(3, )``.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``
        block_cfgs (dict): Extra keyword arguments forwarded unchanged to
            every ``VANBlock``. Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import VAN
        >>> import torch
        >>> cfg = dict(arch='tiny')
        >>> model = VAN(**cfg)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for out in outputs:
        ...     print(out.size())
        torch.Size([1, 256, 7, 7])
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': [32, 64, 160, 256],
                         'depths': [3, 3, 5, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [2, 2, 4, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 3, 12, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 5, 27, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
    }  # yapf: disable

    def __init__(self,
                 arch='tiny',
                 patch_sizes=[7, 3, 3, 3],
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 norm_eval=False,
                 norm_cfg=dict(type='LN'),
                 block_cfgs=dict(),
                 init_cfg=None):
        super(VAN, self).__init__(init_cfg=init_cfg)

        # Resolve the architecture: a named preset or a custom dict.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.ffn_ratios = self.arch_settings['ffn_ratios']
        self.num_stages = len(self.depths)
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval

        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        # Index of the first block of the current stage within ``dpr``.
        cur_block_idx = 0
        for i, depth in enumerate(self.depths):
            # Overlapping patch embedding: stride and padding are both
            # derived from the stage's patch size.
            patch_embed = VANPatchEmbed(
                in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
                input_size=None,
                embed_dims=self.embed_dims[i],
                kernel_size=patch_sizes[i],
                stride=patch_sizes[i] // 2 + 1,
                padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
                norm_cfg=dict(type='BN'))

            blocks = ModuleList([
                VANBlock(
                    embed_dims=self.embed_dims[i],
                    ffn_ratio=self.ffn_ratios[i],
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[cur_block_idx + j],
                    **block_cfgs) for j in range(depth)
            ])
            cur_block_idx += depth
            norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]

            # Stages are registered under 1-based names, e.g. ``blocks1``.
            self.add_module(f'patch_embed{i + 1}', patch_embed)
            self.add_module(f'blocks{i + 1}', blocks)
            self.add_module(f'norm{i + 1}', norm)

    def train(self, mode=True):
        """Switch train/eval mode while honouring the freezing options.

        Args:
            mode (bool): ``True`` for training mode, ``False`` for
                evaluation mode. Defaults to True.

        Returns:
            VAN: ``self``, matching the ``nn.Module.train`` contract so
            calls can be chained.
        """
        super(VAN, self).train(mode)
        # ``super().train`` put every sub-module into ``mode``; re-apply
        # the freeze so frozen stages stay in eval mode.
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def _freeze_stages(self):
        """Put stages ``0..frozen_stages`` into eval mode and stop their
        gradients."""
        for i in range(0, self.frozen_stages + 1):
            # freeze patch embed, blocks and norm of the stage, in that order
            for prefix in ('patch_embed', 'blocks', 'norm'):
                m = getattr(self, f'{prefix}{i + 1}')
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Run all stages and collect the requested feature maps.

        Args:
            x (Tensor): Input images.

        Returns:
            tuple[Tensor]: The normalized feature map of every stage whose
            index is in ``out_indices``, in stage order.
        """
        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f'patch_embed{i + 1}')
            blocks = getattr(self, f'blocks{i + 1}')
            norm = getattr(self, f'norm{i + 1}')

            x, hw_shape = patch_embed(x)
            for block in blocks:
                x = block(x)

            # Move channels last for the norm layer, then restore the
            # channel-first feature-map layout.
            x = x.flatten(2).transpose(1, 2)
            x = norm(x)
            # Use the stage's configured width instead of reading it from
            # the ``block`` loop variable, which is undefined (or stale
            # from the previous stage) when a stage has no blocks.
            x = x.reshape(-1, *hw_shape,
                          self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)