mmclassification/mmpretrain/models/backbones/efficientformer.py

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Optional, Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
build_norm_layer)
from mmengine.model import BaseModule, ModuleList, Sequential
from mmpretrain.registry import MODELS
from ..utils import LayerScale
from .base_backbone import BaseBackbone
from .poolformer import Pooling
class AttentionWithBias(BaseModule):
"""Multi-head Attention Module with attention_bias.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads. Defaults to 8.
key_dim (int): The dimension of q, k. Defaults to 32.
attn_ratio (float): The dimension of v equals
``key_dim * attn_ratio``. Defaults to 4.
resolution (int): The height and width of attention_bias.
Defaults to 7.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads=8,
key_dim=32,
attn_ratio=4.,
resolution=7,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.num_heads = num_heads
self.scale = key_dim**-0.5
self.attn_ratio = attn_ratio
self.key_dim = key_dim
self.nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
h = self.dh + self.nh_kd * 2
self.qkv = nn.Linear(embed_dims, h)
self.proj = nn.Linear(self.dh, embed_dims)
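# Build a learnable relative-position bias table: every pair of positions
# in a ``resolution x resolution`` grid is bucketed by its absolute
# (dy, dx) offset, so mirrored pairs share one bias value per head.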
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(
torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs',
torch.LongTensor(idxs).view(N, N))
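# ``train`` toggles a cached copy of the indexed biases: the cache
# (``self.ab``) is built in eval mode to avoid re-indexing the bias table
# on every forward pass, and deleted again in training mode so gradients
# keep flowing into ``attention_biases``.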
@torch.no_grad()
def train(self, mode=True):
"""change the mode of model."""
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x):
"""forward function.
Args:
x (tensor): input features with shape of (B, N, C)
"""
B, N, _ = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1)
attn = ((q @ k.transpose(-2, -1)) * self.scale +
(self.attention_biases[:, self.attention_bias_idxs]
if self.training else self.ab))
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
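# Illustrative shape check (not part of the original file): with the
# default ``resolution=7`` the token count must be 49, e.g.
#   attn = AttentionWithBias(embed_dims=448)
#   out = attn(torch.rand(2, 49, 448))  # -> (2, 49, 448)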
class Flat(nn.Module):
"""Flat the input from (B, C, H, W) to (B, H*W, C)."""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor):
x = x.flatten(2).transpose(1, 2)
return x
class LinearMlp(BaseModule):
"""Mlp implemented with linear.
The shape of input and output tensor are (B, N, C).
Args:
in_features (int): Dimension of input features.
hidden_features (int): Dimension of hidden features.
out_features (int): Dimension of output features.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='BN')``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
drop (float): Dropout rate. Defaults to 0.0.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_cfg=dict(type='GELU'),
drop=0.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = build_activation_layer(act_cfg)
self.drop1 = nn.Dropout(drop)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop)
def forward(self, x):
"""
Args:
x (torch.Tensor): input tensor with shape (B, N, C).
Returns:
torch.Tensor: output tensor with shape (B, N, C).
"""
x = self.drop1(self.act(self.fc1(x)))
x = self.drop2(self.fc2(x))
return x
class ConvMlp(BaseModule):
"""Mlp implemented with 1*1 convolutions.
Args:
in_features (int): Dimension of input features.
hidden_features (int): Dimension of hidden features.
out_features (int): Dimension of output features.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='BN')``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
drop (float): Dropout rate. Defaults to 0.0.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='GELU'),
drop=0.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.act = build_activation_layer(act_cfg)
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
self.norm2 = build_norm_layer(norm_cfg, out_features)[1]
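# Each 1x1 convolution is followed by a normalization layer (BatchNorm
# by default); the activation is applied only after the first one
# (see ``forward``).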
self.drop = nn.Dropout(drop)
def forward(self, x):
"""
Args:
x (torch.Tensor): input tensor with shape (B, C, H, W).
Returns:
torch.Tensor: output tensor with shape (B, C, H, W).
"""
x = self.act(self.norm1(self.fc1(x)))
x = self.drop(x)
x = self.norm2(self.fc2(x))
x = self.drop(x)
return x
class Meta3D(BaseModule):
"""Meta Former block using 3 dimensions inputs, ``torch.Tensor`` with shape
(B, N, C)."""
def __init__(self,
dim,
mlp_ratio=4.,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
drop=0.,
drop_path=0.,
use_layer_scale=True,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, dim)[1]
self.token_mixer = AttentionWithBias(dim)
self.norm2 = build_norm_layer(norm_cfg, dim)[1]
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = LinearMlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_cfg=act_cfg,
drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
if use_layer_scale:
self.ls1 = LayerScale(dim)
self.ls2 = LayerScale(dim)
else:
self.ls1, self.ls2 = nn.Identity(), nn.Identity()
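# Pre-norm residual branches: token mixing (attention) then MLP, each
# optionally rescaled by LayerScale and regularized with DropPath.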
def forward(self, x):
x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x))))
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
return x
class Meta4D(BaseModule):
"""Meta Former block using 4 dimensions inputs, ``torch.Tensor`` with shape
(B, C, H, W)."""
def __init__(self,
dim,
pool_size=3,
mlp_ratio=4.,
act_cfg=dict(type='GELU'),
drop=0.,
drop_path=0.,
use_layer_scale=True,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.token_mixer = Pooling(pool_size=pool_size)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ConvMlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_cfg=act_cfg,
drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
if use_layer_scale:
self.ls1 = LayerScale(dim, data_format='channels_first')
self.ls2 = LayerScale(dim, data_format='channels_first')
else:
self.ls1, self.ls2 = nn.Identity(), nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.ls1(self.token_mixer(x)))
x = x + self.drop_path(self.ls2(self.mlp(x)))
return x
def basic_blocks(in_channels,
out_channels,
index,
layers,
pool_size=3,
mlp_ratio=4.,
act_cfg=dict(type='GELU'),
drop_rate=.0,
drop_path_rate=0.,
use_layer_scale=True,
vit_num=1,
has_downsamper=False):
"""generate EfficientFormer blocks for a stage."""
blocks = []
if has_downsamper:
blocks.append(
ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=True,
norm_cfg=dict(type='BN'),
act_cfg=None))
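# In the last stage, ``Flat`` converts (B, C, H, W) features to
# (B, N, C) tokens right before the first ``Meta3D`` block; if the
# whole stage consists of ``Meta3D`` blocks, flatten immediately.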
if index == 3 and vit_num == layers[index]:
blocks.append(Flat())
for block_idx in range(layers[index]):
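# The stochastic depth rate grows linearly with the global block
# index across all stages.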
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
sum(layers) - 1)
if index == 3 and layers[index] - block_idx <= vit_num:
blocks.append(
Meta3D(
out_channels,
mlp_ratio=mlp_ratio,
act_cfg=act_cfg,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale,
))
else:
blocks.append(
Meta4D(
out_channels,
pool_size=pool_size,
act_cfg=act_cfg,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale))
if index == 3 and layers[index] - block_idx - 1 == vit_num:
blocks.append(Flat())
blocks = nn.Sequential(*blocks)
return blocks
@MODELS.register_module()
class EfficientFormer(BaseBackbone):
"""EfficientFormer.
A PyTorch implementation of EfficientFormer introduced by:
`EfficientFormer: Vision Transformers at MobileNet Speed <https://arxiv.org/abs/2206.01191>`_
Modified from the `official repo
<https://github.com/snap-research/EfficientFormer>`_.
Args:
arch (str | dict): The model's architecture. If string, it should be
one of architecture in ``EfficientFormer.arch_settings``. And if dict,
it should include the following 4 keys:
- layers (list[int]): Number of blocks at each stage.
- embed_dims (list[int]): The number of channels at each stage.
- downsamples (list[bool]): Whether to downsample at the beginning
of each of the four stages.
- vit_num (int): The number of ViT (``Meta3D``) blocks in the
last stage.
Defaults to 'l1'.
in_channels (int): The number of input channels. Defaults to 3.
pool_size (int): The pooling size of ``Meta4D`` blocks. Defaults to 3.
mlp_ratios (int): The expansion ratio of the MLP hidden dimension in
``Meta4D`` and ``Meta3D`` blocks. Defaults to 4.
reshape_last_feat (bool): Whether to reshape the feature map from
(B, N, C) to (B, C, H, W) in the last stage, when the ``vit_num``
in ``arch`` is not 0. Defaults to False. Usually set to True
in downstream tasks.
out_indices (Sequence[int] | int): Output from which stages.
Defaults to -1, which means the last stage.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
drop_rate (float): Dropout rate. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
use_layer_scale (bool): Whether to use LayerScale in the MetaFormer
blocks. Defaults to True.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
Example:
>>> from mmpretrain.models import EfficientFormer
>>> import torch
>>> inputs = torch.rand((1, 3, 224, 224))
>>> # build EfficientFormer backbone for classification task
>>> model = EfficientFormer(arch="l1")
>>> model.eval()
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 448, 49)
>>> # build EfficientFormer backbone for downstream task
>>> model = EfficientFormer(
>>> arch="l3",
>>> out_indices=(0, 1, 2, 3),
>>> reshape_last_feat=True)
>>> model.eval()
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 64, 56, 56)
(1, 128, 28, 28)
(1, 320, 14, 14)
(1, 512, 7, 7)
""" # noqa: E501
# --layers: [x,x,x,x], numbers of layers for the four stages
# --embed_dims: [x,x,x,x], embedding dims for the four stages
# --downsamples: [x,x,x,x], has downsample or not in the four stages
# --vit_num(int), the num of vit blocks in the last stage
arch_settings = {
'l1': {
'layers': [3, 2, 6, 4],
'embed_dims': [48, 96, 224, 448],
'downsamples': [False, True, True, True],
'vit_num': 1,
},
'l3': {
'layers': [4, 4, 12, 6],
'embed_dims': [64, 128, 320, 512],
'downsamples': [False, True, True, True],
'vit_num': 4,
},
'l7': {
'layers': [6, 6, 18, 8],
'embed_dims': [96, 192, 384, 768],
'downsamples': [False, True, True, True],
'vit_num': 8,
},
}
def __init__(self,
arch='l1',
in_channels=3,
pool_size=3,
mlp_ratios=4,
reshape_last_feat=False,
out_indices=-1,
frozen_stages=-1,
act_cfg=dict(type='GELU'),
drop_rate=0.,
drop_path_rate=0.,
use_layer_scale=True,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.num_extra_tokens = 0 # no cls_token, no dist_token
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
default_keys = set(self.arch_settings['l1'].keys())
assert set(arch.keys()) == default_keys, \
f'The arch dict must have {default_keys}, ' \
f'but got {list(arch.keys())}.'
self.layers = arch['layers']
self.embed_dims = arch['embed_dims']
self.downsamples = arch['downsamples']
assert isinstance(self.layers, list) and isinstance(
self.embed_dims, list) and isinstance(self.downsamples, list)
assert len(self.layers) == len(self.embed_dims) == len(
self.downsamples)
self.vit_num = arch['vit_num']
self.reshape_last_feat = reshape_last_feat
assert self.vit_num >= 0, "'vit_num' must be an integer " \
'greater than or equal to 0.'
assert self.vit_num <= self.layers[-1], (
"'vit_num' must not exceed the number of blocks in the last stage.")
self._make_stem(in_channels, self.embed_dims[0])
# Build the main stages of the network.
network = []
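# Stage 0 keeps the stem output channels; later stages take the
# previous stage's embedding dim as input, and the optional downsample
# conv inside ``basic_blocks`` performs the channel change.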
for i in range(len(self.layers)):
if i != 0:
in_channels = self.embed_dims[i - 1]
else:
in_channels = self.embed_dims[i]
out_channels = self.embed_dims[i]
stage = basic_blocks(
in_channels,
out_channels,
i,
self.layers,
pool_size=pool_size,
mlp_ratio=mlp_ratios,
act_cfg=act_cfg,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
vit_num=self.vit_num,
use_layer_scale=use_layer_scale,
has_downsamper=self.downsamples[i])
network.append(stage)
self.network = ModuleList(network)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must be a sequence or int, ' \
f'got {type(out_indices)} instead.'
# Copy to a mutable list so negative indices can be normalized in place.
out_indices = list(out_indices)
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = 4 + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
for i_layer in self.out_indices:
if not self.reshape_last_feat and \
i_layer == 3 and self.vit_num > 0:
layer = build_norm_layer(
dict(type='LN'), self.embed_dims[i_layer])[1]
else:
# use GN with 1 group as channel-first LN2D
layer = build_norm_layer(
dict(type='GN', num_groups=1), self.embed_dims[i_layer])[1]
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
self.frozen_stages = frozen_stages
self._freeze_stages()
def _make_stem(self, in_channels: int, stem_channels: int):
"""make 2-ConvBNReLu stem layer."""
self.patch_embed = Sequential(
ConvModule(
in_channels,
stem_channels // 2,
kernel_size=3,
stride=2,
padding=1,
bias=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
inplace=True),
ConvModule(
stem_channels // 2,
stem_channels,
kernel_size=3,
stride=2,
padding=1,
bias=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
inplace=True))
def forward_tokens(self, x):
outs = []
for idx, block in enumerate(self.network):
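# Before running the last stage, record the spatial size so that the
# flattened (B, N, C) tokens can later be reshaped back to (B, C, H, W).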
if idx == len(self.network) - 1:
N, _, H, W = x.shape
if self.downsamples[idx]:
H, W = H // 2, W // 2
x = block(x)
if idx in self.out_indices:
norm_layer = getattr(self, f'norm{idx}')
if idx == len(self.network) - 1 and x.dim() == 3:
# When ``vit_num`` > 0 and this is the last stage:
# if ``self.reshape_last_feat`` is True, reshape the features
# to `BCHW` format before the final normalization;
# otherwise, apply the normalization directly and permute the
# features to `BCN` format.
if self.reshape_last_feat:
x = x.permute((0, 2, 1)).reshape(N, -1, H, W)
x_out = norm_layer(x)
else:
x_out = norm_layer(x).permute((0, 2, 1))
else:
x_out = norm_layer(x)
outs.append(x_out.contiguous())
return tuple(outs)
def forward(self, x):
# input embedding
x = self.patch_embed(x)
# through stages
x = self.forward_tokens(x)
return x
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(self.frozen_stages):
# Include both block and downsample layer.
module = self.network[i]
module.eval()
for param in module.parameters():
param.requires_grad = False
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
norm_layer.eval()
for param in norm_layer.parameters():
param.requires_grad = False
def train(self, mode=True):
super(EfficientFormer, self).train(mode)
self._freeze_stages()
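# A minimal usage sketch via the registry (illustrative only; assumes an
# installed mmpretrain environment, mirroring the docstring example above):
#
#   import torch
#   from mmpretrain.registry import MODELS
#
#   cfg = dict(type='EfficientFormer', arch='l1', out_indices=(3, ))
#   model = MODELS.build(cfg)
#   model.eval()
#   feats = model(torch.rand(1, 3, 224, 224))
#   print(tuple(feats[0].shape))  # (1, 448, 49)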