# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch.nn import functional as F

from mmcls.registry import MODELS
from ..utils import LeAttention
from .base_backbone import BaseBackbone


class ConvBN2d(Sequential):
    """An implementation of Conv2d + BatchNorm2d with support of fusion.

    Modified from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The size of the convolution kernel.
            Default: 1.
        stride (int): The stride of the convolution.
            Default: 1.
        padding (int): The padding of the convolution.
            Default: 0.
        dilation (int): The dilation of the convolution.
            Default: 1.
        groups (int): The number of groups in the convolution.
            Default: 1.
        bn_weight_init (float): The initial value of the weight of
            the nn.BatchNorm2d layer. Default: 1.0.
        init_cfg (dict): The initialization config of the module.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1.0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.add_module(
            'conv2d',
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=False))
        bn2d = nn.BatchNorm2d(num_features=out_channels)
        # bn initialization
        torch.nn.init.constant_(bn2d.weight, bn_weight_init)
        torch.nn.init.constant_(bn2d.bias, 0)
        self.add_module('bn2d', bn2d)

    @torch.no_grad()
    def fuse(self):
        conv2d, bn2d = self._modules.values()
        w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5
        w = conv2d.weight * w[:, None, None, None]
        b = bn2d.bias - bn2d.running_mean * bn2d.weight / \
            (bn2d.running_var + bn2d.eps)**0.5

        m = nn.Conv2d(
            in_channels=w.size(1) * self.conv2d.groups,
            out_channels=w.size(0),
            kernel_size=w.shape[2:],
            stride=self.conv2d.stride,
            padding=self.conv2d.padding,
            dilation=self.conv2d.dilation,
            groups=self.conv2d.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
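
# Usage sketch (illustrative): after switching to eval mode, `fuse()` folds
# the BatchNorm statistics into a plain Conv2d that produces the same output
# up to floating-point error, which is useful for inference-time deployment:
#
#   >>> conv_bn = ConvBN2d(16, 32, kernel_size=3, padding=1).eval()
#   >>> fused = conv_bn.fuse()
#   >>> x = torch.randn(2, 16, 8, 8)
#   >>> torch.allclose(conv_bn(x), fused(x), atol=1e-5)
#   True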


class PatchEmbed(BaseModule):
    """Patch Embedding for Vision Transformer.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module uses
    Conv2d and BatchNorm2d to implement patch embedding, and its output shape
    is (N, C, H, W).

    Args:
        in_channels (int): The number of input channels.
        embed_dim (int): The embedding dimension.
        resolution (Tuple[int, int]): The resolution of the input feature.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 embed_dim,
                 resolution,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        img_size: Tuple[int, int] = resolution
        self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
        self.num_patches = self.patches_resolution[0] * \
            self.patches_resolution[1]
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.seq = nn.Sequential(
            ConvBN2d(
                in_channels,
                embed_dim // 2,
                kernel_size=3,
                stride=2,
                padding=1),
            build_activation_layer(act_cfg),
            ConvBN2d(
                embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
        )

    def forward(self, x):
        return self.seq(x)
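
# Shape sketch (illustrative): the two stride-2 ConvBN2d layers downsample the
# input by a factor of 4 per spatial dimension, so a 224x224 image yields a
# 56x56 feature map with `embed_dim` channels:
#
#   >>> embed = PatchEmbed(in_channels=3, embed_dim=64, resolution=(224, 224))
#   >>> embed(torch.randn(1, 3, 224, 224)).shape
#   torch.Size([1, 64, 56, 56])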


class PatchMerging(nn.Module):
    """Patch Merging for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Different from `mmcls.models.utils.PatchMerging`, this module uses Conv2d
    and BatchNorm2d to implement PatchMerging.

    Args:
        resolution (Tuple[int, int]): The resolution of the input feature.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 resolution,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='GELU')):
        super().__init__()

        self.img_size = resolution

        self.act = build_activation_layer(act_cfg)
        self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1)
        self.conv2 = ConvBN2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=out_channels)
        self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1)
        self.out_resolution = (resolution[0] // 2, resolution[1] // 2)

    def forward(self, x):
        if len(x.shape) == 3:
            H, W = self.img_size
            B = x.shape[0]
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)

        x = x.flatten(2).transpose(1, 2)
        return x
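
# Shape sketch (illustrative): PatchMerging halves the spatial resolution and
# changes the channel width. It accepts either a (B, L, C) sequence or a
# (B, C, H, W) map and returns a flattened (B, L/4, out_channels) sequence:
#
#   >>> merge = PatchMerging((56, 56), in_channels=64, out_channels=128)
#   >>> merge(torch.randn(1, 56 * 56, 64)).shape
#   torch.Size([1, 784, 128])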


class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block for TinyViT. Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        expand_ratio (int): The expand ratio of the hidden channels.
        drop_path (float): The drop path rate of the block.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio,
                 drop_path,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        hidden_channels = int(in_channels * expand_ratio)

        # linear
        self.conv1 = ConvBN2d(in_channels, hidden_channels, kernel_size=1)
        self.act = build_activation_layer(act_cfg)
        # depthwise conv
        self.conv2 = ConvBN2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_channels)
        # linear
        self.conv3 = ConvBN2d(
            hidden_channels, out_channels, kernel_size=1, bn_weight_init=0.0)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x

        x = self.conv1(x)
        x = self.act(x)

        x = self.conv2(x)
        x = self.act(x)

        x = self.conv3(x)

        x = self.drop_path(x)

        x += shortcut
        x = self.act(x)

        return x
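
# Usage sketch (illustrative): the residual connection requires matching
# input and output channel counts, so the block preserves the feature shape:
#
#   >>> block = MBConvBlock(64, 64, expand_ratio=4, drop_path=0.)
#   >>> block(torch.randn(1, 64, 56, 56)).shape
#   torch.Size([1, 64, 56, 56])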


class ConvStage(BaseModule):
    """Convolution Stage for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of blocks in the stage.
        act_cfg (dict): The activation config of the module.
        drop_path (float | list[float]): The drop path rate(s) of the blocks.
            Default: 0.
        downsample (None | nn.Module): The downsample operation.
            Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        out_channels (int): The number of output channels.
        conv_expand_ratio (int): The expand ratio of the hidden channels.
            Default: 4.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 act_cfg,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 out_channels=None,
                 conv_expand_ratio=4.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.use_checkpoint = use_checkpoint
        # build blocks
        self.blocks = ModuleList([
            MBConvBlock(
                in_channels=in_channels,
                out_channels=in_channels,
                expand_ratio=conv_expand_ratio,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path)
            for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution
        else:
            self.downsample = None
            self.resolution = resolution

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)

        if self.downsample is not None:
            x = self.downsample(x)
        return x
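
# Usage sketch (illustrative): the first TinyViT stage is purely
# convolutional. With a PatchMerging downsample it returns a (B, L, C)
# sequence at half the input resolution with `out_channels` channels:
#
#   >>> stage = ConvStage(64, (56, 56), depth=2, act_cfg=dict(type='GELU'),
#   ...                   downsample=PatchMerging, out_channels=128)
#   >>> stage(torch.randn(1, 64, 56, 56)).shape
#   torch.Size([1, 784, 128])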


class MLP(BaseModule):
    """MLP module for TinyViT.

    Args:
        in_channels (int): The number of input channels.
        hidden_channels (int, optional): The number of hidden channels.
            Default: None.
        out_channels (int, optional): The number of output channels.
            Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels=None,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.norm = nn.LayerNorm(in_channels)
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.act = build_activation_layer(act_cfg)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.norm(x)

        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
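
# Design note (illustrative): unlike a plain feed-forward block, this MLP
# applies LayerNorm to its own input, so TinyViTBlock can add
# `drop_path(mlp(x))` to the residual without a separate pre-norm layer:
#
#   >>> mlp = MLP(in_channels=64, hidden_channels=256)
#   >>> mlp(torch.randn(1, 49, 64)).shape
#   torch.Size([1, 49, 64])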


class TinyViTBlock(BaseModule):
    """TinyViT Block.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        num_heads (int): The number of heads in the multi-head attention.
        window_size (int): The size of the window.
            Default: 7.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path (float): The drop path of the block.
            Default: 0.
        local_conv_size (int): The size of the local convolution.
            Default: 3.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 local_conv_size=3,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        self.img_size = resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        assert in_channels % num_heads == 0, \
            'dim must be divisible by num_heads'
        head_dim = in_channels // num_heads

        window_resolution = (window_size, window_size)
        self.attn = LeAttention(
            in_channels,
            head_dim,
            num_heads,
            attn_ratio=1,
            resolution=window_resolution)

        mlp_hidden_dim = int(in_channels * mlp_ratio)
        self.mlp = MLP(
            in_channels=in_channels,
            hidden_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        self.local_conv = ConvBN2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=local_conv_size,
            stride=1,
            padding=local_conv_size // 2,
            groups=in_channels)

    def forward(self, x):
        H, W = self.img_size
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'
        res_x = x
        if H == self.window_size and W == self.window_size:
            x = self.attn(x)
        else:
            x = x.view(B, H, W, C)
            pad_b = (self.window_size -
                     H % self.window_size) % self.window_size
            pad_r = (self.window_size -
                     W % self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0

            if padding:
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size
            # window partition
            x = x.view(B, nH, self.window_size, nW, self.window_size,
                       C).transpose(2, 3).reshape(
                           B * nH * nW, self.window_size * self.window_size,
                           C)
            x = self.attn(x)
            # window reverse
            x = x.view(B, nH, nW, self.window_size, self.window_size,
                       C).transpose(2, 3).reshape(B, pH, pW, C)

            if padding:
                x = x[:, :H, :W].contiguous()

            x = x.view(B, L, C)

        x = res_x + self.drop_path(x)

        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.local_conv(x)
        x = x.view(B, C, L).transpose(1, 2)

        x = x + self.drop_path(self.mlp(x))
        return x
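
# Usage sketch (illustrative): attention runs inside non-overlapping windows
# (with padding when H or W is not a multiple of `window_size`), then a
# depthwise local conv and the MLP refine the tokens; the (B, L, C) shape is
# preserved:
#
#   >>> block = TinyViTBlock(in_channels=128, resolution=(28, 28),
#   ...                      num_heads=4, window_size=7)
#   >>> block(torch.randn(1, 28 * 28, 128)).shape
#   torch.Size([1, 784, 128])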


class BasicStage(BaseModule):
    """Basic Stage for TinyViT.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of blocks in the stage.
        num_heads (int): The number of heads in the multi-head attention.
        window_size (int): The size of the window.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path (float): The drop path of the block.
            Default: 0.
        downsample (None | nn.Module): The downsample operation.
            Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        local_conv_size (int): The size of the local convolution.
            Default: 3.
        out_channels (int, optional): The number of output channels.
            Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 local_conv_size=3,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.use_checkpoint = use_checkpoint
        # build blocks
        self.blocks = ModuleList([
            TinyViTBlock(
                in_channels=in_channels,
                resolution=resolution,
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop,
                local_conv_size=local_conv_size,
                act_cfg=act_cfg,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path)
            for i in range(depth)
        ])

        # build patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution
        else:
            self.downsample = None
            self.resolution = resolution

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)

        if self.downsample is not None:
            x = self.downsample(x)
        return x


@MODELS.register_module()
class TinyViT(BaseBackbone):
    """TinyViT.
    A PyTorch implementation of `TinyViT: Fast Pretraining Distillation
    for Small Vision Transformers <https://arxiv.org/abs/2207.10666>`_

    Inspiration from
    https://github.com/microsoft/Cream/blob/main/TinyViT

    Args:
        arch (str | dict): The architecture of TinyViT.
            Default: '5m'.
        img_size (tuple | int): The resolution of the input image.
            Default: (224, 224).
        window_size (list): The size of the window.
            Default: [7, 7, 14, 7].
        in_channels (int): The number of input channels.
            Default: 3.
        depths (list[int]): The depth of each stage, determined by ``arch``.
            Default: [2, 2, 6, 2].
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path_rate (float): The drop path of the block.
            Default: 0.1.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        mbconv_expand_ratio (int): The expand ratio of the mbconv.
            Default: 4.0.
        local_conv_size (int): The size of the local conv.
            Default: 3.
        layer_lr_decay (float): The layer lr decay.
            Default: 1.0.
        out_indices (int | list[int]): Output from which stages.
            Default: -1.
        frozen_stages (int): Stages to be frozen (all parameters fixed).
            Default: 0.
        gap_before_final_norm (bool): Whether to apply global average pooling
            before the final norm. Default: True.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """
    arch_settings = {
        '5m': {
            'channels': [64, 128, 160, 320],
            'num_heads': [2, 4, 5, 10],
            'depths': [2, 2, 6, 2],
        },
        '11m': {
            'channels': [64, 128, 256, 448],
            'num_heads': [2, 4, 8, 14],
            'depths': [2, 2, 6, 2],
        },
        '21m': {
            'channels': [96, 192, 384, 576],
            'num_heads': [3, 6, 12, 18],
            'depths': [2, 2, 6, 2],
        },
    }

    def __init__(self,
                 arch='5m',
                 img_size=(224, 224),
                 window_size=[7, 7, 14, 7],
                 in_channels=3,
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 local_conv_size=3,
                 layer_lr_decay=1.0,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                'Unavailable arch, please choose from ' \
                f'{set(self.arch_settings)} or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'channels' in arch and 'num_heads' in arch and \
                'depths' in arch, 'The arch dict must have ' \
                '"channels", "num_heads" and "depths" keys, ' \
                f'but got {arch.keys()}'

        self.channels = arch['channels']
        self.num_heads = arch['num_heads']
        self.window_sizes = window_size
        self.img_size = img_size
        self.depths = arch['depths']

        self.num_stages = len(self.channels)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            '"out_indices" must be a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = 4 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm
        self.layer_lr_decay = layer_lr_decay

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dim=self.channels[0],
            resolution=self.img_size,
            act_cfg=dict(type='GELU'))
        patches_resolution = self.patch_embed.patches_resolution

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]

        # build stages
        self.stages = ModuleList()
        for i in range(self.num_stages):
            depth = self.depths[i]
            channel = self.channels[i]
            curr_resolution = (patches_resolution[0] // (2**i),
                               patches_resolution[1] // (2**i))
            drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])]
            downsample = PatchMerging if (i < self.num_stages - 1) else None
            out_channels = self.channels[min(i + 1, self.num_stages - 1)]
            if i >= 1:
                stage = BasicStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    num_heads=self.num_heads[i],
                    window_size=self.window_sizes[i],
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    local_conv_size=local_conv_size,
                    out_channels=out_channels,
                    act_cfg=act_cfg)
            else:
                stage = ConvStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    act_cfg=act_cfg,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    out_channels=out_channels,
                    conv_expand_ratio=mbconv_expand_ratio)
            self.stages.append(stage)

            # add output norm
            if i in self.out_indices:
                norm_layer = build_norm_layer(norm_cfg, out_channels)[1]
                self.add_module(f'norm{i}', norm_layer)

    def set_layer_lr_decay(self, layer_lr_decay):
        # TODO: add layer_lr_decay
        pass

    def forward(self, x):
        outs = []
        x = self.patch_embed(x)

        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean(1)
                    outs.append(norm_layer(gap))
                else:
                    out = norm_layer(x)
                    # convert the (B, L, C) format into (B, C, H, W) format,
                    # which would be better for the downstream tasks.
                    B, L, C = out.shape
                    out = out.view(B, *stage.resolution, C)
                    outs.append(out.permute(0, 3, 1, 2))

        return tuple(outs)

    def _freeze_stages(self):
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(TinyViT, self).train(mode)
        self._freeze_stages()
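
# Usage sketch (illustrative): with the default `out_indices=-1` and
# `gap_before_final_norm=True`, the backbone returns a single globally pooled
# feature vector from the last stage (320 channels for the '5m' arch):
#
#   >>> model = TinyViT(arch='5m')
#   >>> feats = model(torch.randn(1, 3, 224, 224))
#   >>> feats[-1].shape
#   torch.Size([1, 320])
#
# Since the class is registered, it can also be built from a config dict via
# MODELS.build(dict(type='TinyViT', arch='5m')).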