# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Optional, Sequence

import torch
import torch.nn as nn
from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
                             build_norm_layer)
from mmengine.model import BaseModule, ModuleList, Sequential

from mmcls.registry import MODELS
from ..utils import LayerScale
from .base_backbone import BaseBackbone
from .poolformer import Pooling


class AttentionWithBias(BaseModule):
    """Multi-head Attention Module with attention_bias.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Defaults to 8.
        key_dim (int): The dimension of q, k. Defaults to 32.
        attn_ratio (float): The dimension of v equals
            ``key_dim * attn_ratio``. Defaults to 4.
        resolution (int): The height and width of attention_bias.
            Defaults to 7.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 key_dim=32,
                 attn_ratio=4.,
                 resolution=7,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.attn_ratio = attn_ratio
        self.key_dim = key_dim
        self.nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        h = self.dh + self.nh_kd * 2
        self.qkv = nn.Linear(embed_dims, h)
        self.proj = nn.Linear(self.dh, embed_dims)

        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))

    @torch.no_grad()
    def train(self, mode=True):
        """Switch between training and evaluation modes.

        In evaluation mode, the relative position biases are gathered once
        into ``self.ab`` to skip the indexing at every forward pass.
        """
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): input features with shape (B, N, C).
        """
        B, N, _ = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1)

        attn = ((q @ k.transpose(-2, -1)) * self.scale +
                (self.attention_biases[:, self.attention_bias_idxs]
                 if self.training else self.ab))
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x
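

# The sketch below is illustrative and not part of the original module: it
# shows the expected shapes when using ``AttentionWithBias`` directly. The
# relative-position bias table is built for a fixed ``resolution``, so the
# token count must equal ``resolution ** 2`` (49 for the default of 7).
def _attention_with_bias_example():  # hypothetical helper, for illustration
    attn = AttentionWithBias(
        embed_dims=448, num_heads=8, key_dim=32, attn_ratio=4., resolution=7)
    tokens = torch.rand(2, 49, 448)  # (B, N, C) with N == 7 * 7
    return attn(tokens)  # same shape as the input: (2, 49, 448)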


class Flat(nn.Module):
    """Flatten the input from (B, C, H, W) to (B, H*W, C)."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        x = x.flatten(2).transpose(1, 2)
        return x


class LinearMlp(BaseModule):
    """MLP implemented with linear layers.

    The shapes of the input and output tensors are (B, N, C).

    Args:
        in_features (int): Dimension of input features.
        hidden_features (int): Dimension of hidden features.
        out_features (int): Dimension of output features.
        act_cfg (dict): The config dict for the activation between the two
            linear layers. Defaults to ``dict(type='GELU')``.
        drop (float): Dropout rate. Defaults to 0.0.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = build_activation_layer(act_cfg)
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input tensor with shape (B, N, C).

        Returns:
            torch.Tensor: output tensor with shape (B, N, C).
        """
        x = self.drop1(self.act(self.fc1(x)))
        x = self.drop2(self.fc2(x))
        return x


class ConvMlp(BaseModule):
    """MLP implemented with 1x1 convolutions.

    Args:
        in_features (int): Dimension of input features.
        hidden_features (int): Dimension of hidden features.
        out_features (int): Dimension of output features.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for the activation between the
            pointwise convolutions. Defaults to ``dict(type='GELU')``.
        drop (float): Dropout rate. Defaults to 0.0.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
        self.norm2 = build_norm_layer(norm_cfg, out_features)[1]

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input tensor with shape (B, C, H, W).

        Returns:
            torch.Tensor: output tensor with shape (B, C, H, W).
        """
        x = self.act(self.norm1(self.fc1(x)))
        x = self.drop(x)
        x = self.norm2(self.fc2(x))
        x = self.drop(x)
        return x
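

# Illustrative shape check (the helper name below is hypothetical and not in
# the original file): ``LinearMlp`` operates on token sequences (B, N, C),
# while ``ConvMlp`` operates on feature maps (B, C, H, W) and inserts a
# normalization layer after each 1x1 convolution.
def _mlp_shape_example():  # hypothetical helper, for illustration
    token_mlp = LinearMlp(in_features=448, hidden_features=1792)
    conv_mlp = ConvMlp(in_features=96, hidden_features=384)
    tokens = torch.rand(2, 49, 448)  # (B, N, C)
    feature_map = torch.rand(2, 96, 28, 28)  # (B, C, H, W)
    # Both MLPs keep the shape of their input.
    return token_mlp(tokens), conv_mlp(feature_map)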


class Meta3D(BaseModule):
    """MetaFormer block for 3-dimensional inputs, i.e. ``torch.Tensor`` with
    shape (B, N, C)."""

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 drop_path=0.,
                 use_layer_scale=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.norm1 = build_norm_layer(norm_cfg, dim)[1]
        self.token_mixer = AttentionWithBias(dim)
        self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = LinearMlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        if use_layer_scale:
            self.ls1 = LayerScale(dim)
            self.ls2 = LayerScale(dim)
        else:
            self.ls1, self.ls2 = nn.Identity(), nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x))))
        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
        return x
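

# Illustrative sketch (hypothetical helper, not part of the original file):
# ``Meta3D`` uses ``AttentionWithBias`` as its token mixer, whose bias table
# is built for a 7x7 grid by default, so the sequence length must be 49.
def _meta3d_example():  # hypothetical helper, for illustration
    block = Meta3D(dim=448, mlp_ratio=4.)
    tokens = torch.rand(2, 49, 448)  # (B, N, C) with N == 7 * 7
    return block(tokens)  # same shape as the input: (2, 49, 448)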


class Meta4D(BaseModule):
    """MetaFormer block for 4-dimensional inputs, i.e. ``torch.Tensor`` with
    shape (B, C, H, W)."""

    def __init__(self,
                 dim,
                 pool_size=3,
                 mlp_ratio=4.,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 drop_path=0.,
                 use_layer_scale=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.token_mixer = Pooling(pool_size=pool_size)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ConvMlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        if use_layer_scale:
            self.ls1 = LayerScale(dim, data_format='channels_first')
            self.ls2 = LayerScale(dim, data_format='channels_first')
        else:
            self.ls1, self.ls2 = nn.Identity(), nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.ls1(self.token_mixer(x)))
        x = x + self.drop_path(self.ls2(self.mlp(x)))
        return x
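

# Illustrative sketch (hypothetical helper, not part of the original file):
# ``Meta4D`` keeps the spatial layout, mixing tokens with the pooling
# operator and using the 1x1-convolution MLP as the channel MLP.
def _meta4d_example():  # hypothetical helper, for illustration
    block = Meta4D(dim=96, pool_size=3, mlp_ratio=4.)
    feature_map = torch.rand(2, 96, 28, 28)  # (B, C, H, W)
    return block(feature_map)  # same shape as the input: (2, 96, 28, 28)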


def basic_blocks(in_channels,
                 out_channels,
                 index,
                 layers,
                 pool_size=3,
                 mlp_ratio=4.,
                 act_cfg=dict(type='GELU'),
                 drop_rate=.0,
                 drop_path_rate=0.,
                 use_layer_scale=True,
                 vit_num=1,
                 has_downsamper=False):
    """Generate EfficientFormer blocks for a stage."""
    blocks = []
    if has_downsamper:
        blocks.append(
            ConvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
                norm_cfg=dict(type='BN'),
                act_cfg=None))
    if index == 3 and vit_num == layers[index]:
        blocks.append(Flat())
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
            sum(layers) - 1)
        if index == 3 and layers[index] - block_idx <= vit_num:
            blocks.append(
                Meta3D(
                    out_channels,
                    mlp_ratio=mlp_ratio,
                    act_cfg=act_cfg,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                ))
        else:
            blocks.append(
                Meta4D(
                    out_channels,
                    pool_size=pool_size,
                    act_cfg=act_cfg,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale))
            if index == 3 and layers[index] - block_idx - 1 == vit_num:
                blocks.append(Flat())
    blocks = nn.Sequential(*blocks)
    return blocks
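

# Illustrative sketch of how one stage is assembled (hypothetical helper; the
# values follow the 'l1' arch defined below). For the last stage, the
# downsampler halves the resolution, the ``Meta4D`` blocks run first, then a
# ``Flat`` converts to tokens for the trailing ``vit_num`` ``Meta3D`` blocks.
def _stage_example():  # hypothetical helper, for illustration
    stage = basic_blocks(
        in_channels=224,
        out_channels=448,
        index=3,
        layers=[3, 2, 6, 4],
        vit_num=1,
        has_downsamper=True)
    feature_map = torch.rand(1, 224, 14, 14)  # output of the previous stage
    return stage(feature_map)  # (1, 49, 448) after the downsample and Flat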


@MODELS.register_module()
class EfficientFormer(BaseBackbone):
    """EfficientFormer.

    A PyTorch implementation of EfficientFormer introduced by:
    `EfficientFormer: Vision Transformers at MobileNet Speed <https://arxiv.org/abs/2206.01191>`_

    Modified from the `official repo
    <https://github.com/snap-research/EfficientFormer>`_.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of the architectures in ``EfficientFormer.arch_settings``.
            If it is a dict, it should include the following 4 keys:

            - layers (list[int]): Number of blocks at each stage.
            - embed_dims (list[int]): The number of channels at each stage.
            - downsamples (list[bool]): Whether to downsample in each of the
              four stages.
            - vit_num (int): The number of ViT blocks in the last stage.

            Defaults to 'l1'.

        in_channels (int): The num of input channels. Defaults to 3.
        pool_size (int): The pooling size of ``Meta4D`` blocks. Defaults to 3.
        mlp_ratios (int): The expansion ratio of the MLP hidden dimension in
            ``Meta4D`` and ``Meta3D`` blocks. Defaults to 4.
        reshape_last_feat (bool): Whether to reshape the feature map from
            (B, N, C) to (B, C, H, W) in the last stage, when the ``vit_num``
            in ``arch`` is not 0. Defaults to False. Usually set to True
            in downstream tasks.
        out_indices (Sequence[int] | int): Output from which stages.
            Defaults to -1.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        act_cfg (dict): The config dict for the activation layers.
            Defaults to ``dict(type='GELU')``.
        drop_rate (float): Dropout rate. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        use_layer_scale (bool): Whether to use layer scale in MetaFormer
            blocks. Defaults to True.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Example:
        >>> from mmcls.models import EfficientFormer
        >>> import torch
        >>> inputs = torch.rand((1, 3, 224, 224))
        >>> # build EfficientFormer backbone for classification task
        >>> model = EfficientFormer(arch="l1")
        >>> model.eval()
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 448, 49)
        >>> # build EfficientFormer backbone for downstream task
        >>> model = EfficientFormer(
        >>>     arch="l3",
        >>>     out_indices=(0, 1, 2, 3),
        >>>     reshape_last_feat=True)
        >>> model.eval()
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 56, 56)
        (1, 128, 28, 28)
        (1, 320, 14, 14)
        (1, 512, 7, 7)
    """  # noqa: E501

    # --layers: [x,x,x,x], numbers of layers for the four stages
    # --embed_dims: [x,x,x,x], embedding dims for the four stages
    # --downsamples: [x,x,x,x], has downsample or not in the four stages
    # --vit_num:(int), the num of vit blocks in the last stage
    arch_settings = {
        'l1': {
            'layers': [3, 2, 6, 4],
            'embed_dims': [48, 96, 224, 448],
            'downsamples': [False, True, True, True],
            'vit_num': 1,
        },
        'l3': {
            'layers': [4, 4, 12, 6],
            'embed_dims': [64, 128, 320, 512],
            'downsamples': [False, True, True, True],
            'vit_num': 4,
        },
        'l7': {
            'layers': [6, 6, 18, 8],
            'embed_dims': [96, 192, 384, 768],
            'downsamples': [False, True, True, True],
            'vit_num': 8,
        },
    }

    def __init__(self,
                 arch='l1',
                 in_channels=3,
                 pool_size=3,
                 mlp_ratios=4,
                 reshape_last_feat=False,
                 out_indices=-1,
                 frozen_stages=-1,
                 act_cfg=dict(type='GELU'),
                 drop_rate=0.,
                 drop_path_rate=0.,
                 use_layer_scale=True,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)
        self.num_extra_tokens = 0  # no cls_token, no dist_token

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            default_keys = set(self.arch_settings['l1'].keys())
            assert set(arch.keys()) == default_keys, \
                f'The arch dict must have {default_keys}, ' \
                f'but got {list(arch.keys())}.'

        self.layers = arch['layers']
        self.embed_dims = arch['embed_dims']
        self.downsamples = arch['downsamples']
        assert isinstance(self.layers, list) and isinstance(
            self.embed_dims, list) and isinstance(self.downsamples, list)
        assert len(self.layers) == len(self.embed_dims) == len(
            self.downsamples)

        self.vit_num = arch['vit_num']
        self.reshape_last_feat = reshape_last_feat

        assert self.vit_num >= 0, "'vit_num' must be an integer " \
            'greater than or equal to 0.'
        assert self.vit_num <= self.layers[-1], (
            "'vit_num' must be an integer smaller than layer number")

        self._make_stem(in_channels, self.embed_dims[0])

        # set the main block in network
        network = []
        for i in range(len(self.layers)):
            if i != 0:
                in_channels = self.embed_dims[i - 1]
            else:
                in_channels = self.embed_dims[i]
            out_channels = self.embed_dims[i]
            stage = basic_blocks(
                in_channels,
                out_channels,
                i,
                self.layers,
                pool_size=pool_size,
                mlp_ratio=mlp_ratios,
                act_cfg=act_cfg,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
                vit_num=self.vit_num,
                use_layer_scale=use_layer_scale,
                has_downsamper=self.downsamples[i])
            network.append(stage)

        self.network = ModuleList(network)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = 4 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'

        self.out_indices = out_indices
        for i_layer in self.out_indices:
            if not self.reshape_last_feat and \
                    i_layer == 3 and self.vit_num > 0:
                layer = build_norm_layer(
                    dict(type='LN'), self.embed_dims[i_layer])[1]
            else:
                # use GN with 1 group as channel-first LN2D
                layer = build_norm_layer(
                    dict(type='GN', num_groups=1), self.embed_dims[i_layer])[1]

            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.frozen_stages = frozen_stages
        self._freeze_stages()

    def _make_stem(self, in_channels: int, stem_channels: int):
        """Make the two-layer Conv-BN-ReLU stem."""
        self.patch_embed = Sequential(
            ConvModule(
                in_channels,
                stem_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
                conv_cfg=None,
                norm_cfg=dict(type='BN'),
                inplace=True),
            ConvModule(
                stem_channels // 2,
                stem_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
                conv_cfg=None,
                norm_cfg=dict(type='BN'),
                inplace=True))

    def forward_tokens(self, x):
        outs = []
        for idx, block in enumerate(self.network):
            if idx == len(self.network) - 1:
                N, _, H, W = x.shape
                if self.downsamples[idx]:
                    H, W = H // 2, W // 2
            x = block(x)
            if idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')

                if idx == len(self.network) - 1 and x.dim() == 3:
                    # When ``vit_num`` > 0 and in the last stage:
                    # if ``self.reshape_last_feat`` is True, reshape the
                    # features to `BCHW` format before the final
                    # normalization; if ``self.reshape_last_feat`` is False,
                    # do normalization directly and permute the features
                    # to `BCN`.
                    if self.reshape_last_feat:
                        x = x.permute((0, 2, 1)).reshape(N, -1, H, W)
                        x_out = norm_layer(x)
                    else:
                        x_out = norm_layer(x).permute((0, 2, 1))
                else:
                    x_out = norm_layer(x)

                outs.append(x_out.contiguous())
        return tuple(outs)

    def forward(self, x):
        # input embedding
        x = self.patch_embed(x)
        # through stages
        x = self.forward_tokens(x)
        return x

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            for i in range(self.frozen_stages):
                # Include both block and downsample layer.
                module = self.network[i]
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
                if i in self.out_indices:
                    norm_layer = getattr(self, f'norm{i}')
                    norm_layer.eval()
                    for param in norm_layer.parameters():
                        param.requires_grad = False

    def train(self, mode=True):
        super(EfficientFormer, self).train(mode)
        self._freeze_stages()
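

# Registry-based construction sketch (hypothetical helper, not part of the
# original file). The config values follow the class docstring above; with
# ``reshape_last_feat=False`` the last-stage output stays in (B, C, N) form.
def _build_from_cfg_example():  # hypothetical helper, for illustration
    cfg = dict(
        type='EfficientFormer',
        arch='l1',
        drop_path_rate=0.1,
        out_indices=(3, ),
        reshape_last_feat=False)
    model = MODELS.build(cfg)
    model.eval()
    inputs = torch.rand(1, 3, 224, 224)
    outs = model(inputs)
    return [tuple(out.shape) for out in outs]  # [(1, 448, 49)]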