""" EfficientViT (by MIT Song Han's Lab)
Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition`
- https://arxiv.org/abs/2205.14756
Adapted from official impl at https://github.com/mit-han-lab/efficientvit
"""
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['EfficientViT']


def val2list(x: Union[list, tuple, Any], repeat_time: int = 1) -> list:
if isinstance(x, (list, tuple)):
return list(x)
return [x for _ in range(repeat_time)]


def val2tuple(x: Union[list, tuple, Any], min_len: int = 1, idx_repeat: int = -1) -> tuple:
# repeat elements if necessary
x = val2list(x)
if len(x) > 0:
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
return tuple(x)
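
# e.g. val2tuple(False, 3) -> (False, False, False) and
#      val2tuple((nn.ReLU6, None), 3) -> (nn.ReLU6, None, None); these helpers expand a
#      scalar or short per-layer setting (bias/norm/act) to the required length.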


def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
if isinstance(kernel_size, tuple):
return tuple([get_same_padding(ks) for ks in kernel_size])
else:
        assert kernel_size % 2 > 0, "kernel size should be an odd number"
return kernel_size // 2
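
# e.g. get_same_padding(3) == 1 and get_same_padding((3, 5)) == (1, 2): the padding that
# keeps spatial size unchanged for stride-1 convs with odd kernels.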


class ConvNormAct(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
use_bias=False,
dropout=0,
norm=nn.BatchNorm2d,
act_func=nn.ReLU,
):
        super(ConvNormAct, self).__init__()
padding = get_same_padding(kernel_size)
padding *= dilation
self.dropout = nn.Dropout(dropout, inplace=False)
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=padding,
dilation=(dilation, dilation),
groups=groups,
bias=use_bias,
)
self.norm = norm(num_features=out_channels) if norm else None
self.act = act_func(inplace=True) if act_func else None

    def forward(self, x):
x = self.dropout(x)
x = self.conv(x)
if self.norm:
x = self.norm(x)
if self.act:
x = self.act(x)
return x


class DSConv(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3,
stride=1,
use_bias=False,
norm=(nn.BatchNorm2d, nn.BatchNorm2d),
act_func=(nn.ReLU6, None),
):
super(DSConv, self).__init__()
use_bias = val2tuple(use_bias, 2)
norm = val2tuple(norm, 2)
act_func = val2tuple(act_func, 2)
        self.depth_conv = ConvNormAct(
in_channels,
in_channels,
kernel_size,
stride,
groups=in_channels,
norm=norm[0],
act_func=act_func[0],
use_bias=use_bias[0],
)
        self.point_conv = ConvNormAct(
in_channels,
out_channels,
1,
norm=norm[1],
act_func=act_func[1],
use_bias=use_bias[1],
)

    def forward(self, x):
x = self.depth_conv(x)
x = self.point_conv(x)
return x


class MBConv(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3,
stride=1,
mid_channels=None,
expand_ratio=6,
use_bias=False,
norm=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
act_func=(nn.ReLU6, nn.ReLU6, None),
):
super(MBConv, self).__init__()
use_bias = val2tuple(use_bias, 3)
norm = val2tuple(norm, 3)
act_func = val2tuple(act_func, 3)
mid_channels = mid_channels or round(in_channels * expand_ratio)
        self.inverted_conv = ConvNormAct(
in_channels,
mid_channels,
1,
stride=1,
norm=norm[0],
act_func=act_func[0],
use_bias=use_bias[0],
)
        self.depth_conv = ConvNormAct(
mid_channels,
mid_channels,
kernel_size,
stride=stride,
groups=mid_channels,
norm=norm[1],
act_func=act_func[1],
use_bias=use_bias[1],
)
        self.point_conv = ConvNormAct(
mid_channels,
out_channels,
1,
norm=norm[2],
act_func=act_func[2],
use_bias=use_bias[2],
)

    def forward(self, x):
x = self.inverted_conv(x)
x = self.depth_conv(x)
x = self.point_conv(x)
return x


class LiteMSA(nn.Module):
    """Lightweight multi-scale attention"""

    def __init__(
self,
in_channels: int,
out_channels: int,
        heads: Optional[int] = None,
heads_ratio: float = 1.0,
dim=8,
use_bias=False,
norm=(None, nn.BatchNorm2d),
act_func=(None, None),
kernel_func=nn.ReLU,
scales=(5,),
):
super(LiteMSA, self).__init__()
heads = heads or int(in_channels // dim * heads_ratio)
total_dim = heads * dim
use_bias = val2tuple(use_bias, 2)
norm = val2tuple(norm, 2)
act_func = val2tuple(act_func, 2)
self.dim = dim
        self.qkv = ConvNormAct(
in_channels,
3 * total_dim,
1,
use_bias=use_bias[0],
norm=norm[0],
act_func=act_func[0],
)
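        # Multi-scale aggregation: for each scale, a depthwise conv followed by a per-head
        # (groups=3 * heads) pointwise conv produces an extra set of q/k/v tokens.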
self.aggreg = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(
3 * total_dim,
3 * total_dim,
scale,
padding=get_same_padding(scale),
groups=3 * total_dim,
bias=use_bias[0],
),
nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
)
for scale in scales
]
)
self.kernel_func = kernel_func(inplace=False)
        self.proj = ConvNormAct(
total_dim * (1 + len(scales)),
out_channels,
1,
use_bias=use_bias[1],
norm=norm[1],
act_func=act_func[1],
)

    def forward(self, x):
        B, _, H, W = x.shape
# generate multi-scale q, k, v
qkv = self.qkv(x)
multi_scale_qkv = [qkv]
for op in self.aggreg:
multi_scale_qkv.append(op(qkv))
multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
multi_scale_qkv = torch.reshape(
multi_scale_qkv,
(
B,
-1,
3 * self.dim,
H * W,
),
)
multi_scale_qkv = torch.transpose(multi_scale_qkv, -1, -2)
        q, k, v = (
            multi_scale_qkv[..., 0: self.dim],
            multi_scale_qkv[..., self.dim: 2 * self.dim],
            multi_scale_qkv[..., 2 * self.dim:],
        )
# lightweight global attention
q = self.kernel_func(q)
k = self.kernel_func(k)
trans_k = k.transpose(-1, -2)
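        # Linear attention: out = phi(q) @ (phi(k)^T @ v) / (phi(q) @ (phi(k)^T @ 1)), with
        # phi = kernel_func (ReLU). Appending a ones column to v lets a single pair of matmuls
        # produce both the numerator and the normalizing denominator; the denominator is
        # recovered from the last channel below, with 1e-15 guarding against division by zero.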
v = F.pad(v, (0, 1), mode="constant", value=1)
kv = torch.matmul(trans_k, v)
out = torch.matmul(q, kv)
out = out[..., :-1] / (out[..., -1:] + 1e-15)
        # final projection
out = torch.transpose(out, -1, -2)
out = torch.reshape(out, (B, -1, H, W))
out = self.proj(out)
return out


class EfficientViTBlock(nn.Module):
def __init__(
self,
in_channels,
heads_ratio=1.0,
dim=32,
expand_ratio=4,
norm=nn.BatchNorm2d,
act_func=nn.Hardswish,
):
super(EfficientViTBlock, self).__init__()
self.context_module = ResidualBlock(
LiteMSA(
in_channels=in_channels,
out_channels=in_channels,
heads_ratio=heads_ratio,
dim=dim,
norm=(None, norm),
),
nn.Identity(),
)
local_module = MBConv(
in_channels=in_channels,
out_channels=in_channels,
expand_ratio=expand_ratio,
use_bias=(True, True, False),
norm=(None, None, norm),
act_func=(act_func, act_func, None),
)
self.local_module = ResidualBlock(local_module, nn.Identity())

    def forward(self, x):
x = self.context_module(x)
x = self.local_module(x)
return x


class ResidualBlock(nn.Module):
def __init__(
self,
        main: Optional[nn.Module],
        shortcut: Optional[nn.Module],
        post_act=None,
        pre_norm: Optional[nn.Module] = None,
):
super(ResidualBlock, self).__init__()
self.pre_norm = pre_norm
self.main = main
self.shortcut = shortcut
self.post_act = post_act(inplace=True) if post_act else nn.Identity()

    def forward_main(self, x):
if self.pre_norm is None:
return self.main(x)
else:
return self.main(self.pre_norm(x))

    def forward(self, x):
if self.main is None:
res = x
elif self.shortcut is None:
res = self.forward_main(x)
else:
res = self.forward_main(x) + self.shortcut(x)
if self.post_act:
res = self.post_act(res)
return res


def build_local_block(
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm,
        act_func,
        fewer_norm: bool = False,
):
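    # With fewer_norm (used for blocks in the attention stages), biases replace normalization
    # in the expansion convs and only the final pointwise conv keeps a norm layer.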
if expand_ratio == 1:
block = DSConv(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
use_bias=(True, False) if fewer_norm else False,
norm=(None, norm) if fewer_norm else norm,
act_func=(act_func, None),
)
else:
block = MBConv(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
expand_ratio=expand_ratio,
use_bias=(True, True, False) if fewer_norm else False,
norm=(None, None, norm) if fewer_norm else norm,
act_func=(act_func, act_func, None),
)
return block


class Stem(nn.Sequential):
    def __init__(self, in_chs, out_chs, depth, norm, act_func):
        super().__init__()
        self.stride = 2

        self.add_module(
            'in_conv',
            ConvNormAct(in_channels=in_chs, out_channels=out_chs, kernel_size=3, stride=2, norm=norm, act_func=act_func),
        )
stem_block = 0
for _ in range(depth):
block = build_local_block(
in_channels=out_chs,
out_channels=out_chs,
stride=1,
expand_ratio=1,
norm=norm,
act_func=act_func,
)
self.add_module(f'res{stem_block}', ResidualBlock(block, nn.Identity()))
stem_block += 1


class EfficientViTStage(nn.Module):
def __init__(self, in_chs, out_chs, depth, norm, act_func, expand_ratio, dim, conv_stage=False):
super(EfficientViTStage, self).__init__()
blocks = []
if conv_stage:
            # stages 1 and 2: convolutional (local) blocks only
for i in range(depth):
stage_stride = 2 if i == 0 else 1
block = build_local_block(
in_channels=in_chs,
out_channels=out_chs,
stride=stage_stride,
expand_ratio=expand_ratio,
norm=norm,
act_func=act_func,
)
block = ResidualBlock(block, nn.Identity() if stage_stride == 1 else None)
blocks.append(block)
in_chs = out_chs
else:
            # stages 3 and 4: a strided downsample block, then EfficientViT attention blocks
block = build_local_block(
in_channels=in_chs,
out_channels=out_chs,
stride=2,
expand_ratio=expand_ratio,
norm=norm,
act_func=act_func,
fewer_norm=True,
)
blocks.append(ResidualBlock(block, None))
in_chs = out_chs
for _ in range(depth):
blocks.append(
EfficientViTBlock(
in_channels=in_chs,
dim=dim,
expand_ratio=expand_ratio,
norm=norm,
act_func=act_func,
)
)
self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
return self.blocks(x)


class ClassifierHead(nn.Module):
def __init__(
self,
in_channels,
width_list,
n_classes=1000,
dropout=0,
norm=nn.BatchNorm2d,
act_func=nn.Hardswish,
global_pool='avg',
):
super(ClassifierHead, self).__init__()
self.in_conv = ConvNormAct(in_channels, width_list[0], 1, norm=norm, act_func=act_func)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
self.classifier = nn.Sequential(
nn.Linear(width_list[0], width_list[1], bias=False),
nn.LayerNorm(width_list[1]),
act_func(inplace=True),
            nn.Dropout(dropout, inplace=False),
nn.Linear(width_list[1], n_classes, bias=True),
)

    def forward(self, x):
x = self.in_conv(x)
x = self.global_pool(x)
x = self.classifier(x)
return x


class EfficientViT(nn.Module):
def __init__(
self,
in_chans=3,
            width_list=(),
            depth_list=(),
dim=32,
expand_ratio=4,
norm=nn.BatchNorm2d,
act_func=nn.Hardswish,
global_pool='avg',
            head_width_list=(),
head_dropout=0.0,
num_classes=1000,
):
        super(EfficientViT, self).__init__()
        self.num_classes = num_classes
        self.grad_checkpointing = False
        self.global_pool_name = global_pool
# input stem
self.stem = Stem(in_chans, width_list[0], depth_list[0], norm, act_func)
stride = self.stem.stride
# stages
self.feature_info = []
stages = []
stage_idx = 0
        in_channels = width_list[0]
        for w, d in zip(width_list[1:3], depth_list[1:3]):
            stages.append(EfficientViTStage(in_channels, w, d, norm, act_func, expand_ratio, dim, conv_stage=True))
            stride *= 2
            in_channels = w
            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
            stage_idx += 1
        for w, d in zip(width_list[3:], depth_list[3:]):
            stages.append(EfficientViTStage(in_channels, w, d, norm, act_func, expand_ratio, dim, conv_stage=False))
            stride *= 2
            in_channels = w
            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
            stage_idx += 1
self.stages = nn.Sequential(*stages)
self.num_features = in_channels
self.head_width_list = head_width_list
self.head_dropout = head_dropout
if num_classes > 0:
            self.head = ClassifierHead(
                self.num_features,
                self.head_width_list,
                n_classes=num_classes,
                dropout=self.head_dropout,
                global_pool=self.global_pool_name,
            )
else:
if global_pool is not None:
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
else:
self.head = nn.Identity()

    @torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem', # stem and embed
blocks=[(r'^stages\.(\d+)', None)]
)
return matcher

    @torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable

    @torch.jit.ignore
def get_classifier(self):
return self.head

    def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
            self.global_pool_name = global_pool
if num_classes > 0:
            self.head = ClassifierHead(
                self.num_features,
                self.head_width_list,
                n_classes=num_classes,
                dropout=self.head_dropout,
                global_pool=self.global_pool_name,
            )
else:
if global_pool is not None:
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
else:
self.head = nn.Identity()

    def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
return x

    def forward_head(self, x, pre_logits: bool = False):
return x if pre_logits else self.head(x)

    def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x


def checkpoint_filter_fn(state_dict, model):
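    # Remap pretrained weights to timm's module names purely by position; this assumes the
    # modules here were created in the same order as in the official checkpoints.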
    target_keys = list(model.state_dict().keys())
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
out_dict = {}
for i, (k, v) in enumerate(state_dict.items()):
out_dict[target_keys[i]] = v
return out_dict


def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'mean': IMAGENET_DEFAULT_MEAN,
'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.in_conv.conv',
        'classifier': 'head.classifier',
**kwargs,
}


default_cfgs = generate_default_cfgs(
{
'efficientvit_b0.r224_in1k': _cfg(
# url='https://drive.google.com/file/d/1ganFBZmmvCTpgUwiLb8ePD6NBNxRyZDk/view?usp=drive_link'
),
'efficientvit_b1.r224_in1k': _cfg(
# url='https://drive.google.com/file/d/1hKN_hvLG4nmRzbfzKY7GlqwpR5uKpOOk/view?usp=share_link'
),
'efficientvit_b1.r256_in1k': _cfg(
# url='https://drive.google.com/file/d/1hXcG_jB0ODMOESsSkzVye-58B4F3Cahs/view?usp=share_link'
),
'efficientvit_b1.r288_in1k': _cfg(
# url='https://drive.google.com/file/d/1sE_Suz9gOOUO7o5r9eeAT4nKK8Hrbhsu/view?usp=share_link'
),
'efficientvit_b2.r224_in1k': _cfg(
# url='https://drive.google.com/file/d/1DiM-iqVGTrq4te8mefHl3e1c12u4qR7d/view?usp=share_link'
),
'efficientvit_b2.r256_in1k': _cfg(
# url='https://drive.google.com/file/d/192OOk4ISitwlyW979M-FSJ_fYMMW9HQz/view?usp=share_link'
),
'efficientvit_b2.r288_in1k': _cfg(
# url='https://drive.google.com/file/d/1aodcepOyne667hvBAGpf9nDwmd5g0NpU/view?usp=share_link'
),
'efficientvit_b3.r224_in1k': _cfg(
# url='https://drive.google.com/file/d/18RZDGLiY8KsyJ7LGic4mg1JHwd-a_ky6/view?usp=share_link'
),
'efficientvit_b3.r256_in1k': _cfg(
# url='https://drive.google.com/file/d/1y1rnir4I0XiId-oTCcHhs7jqnrHGFi-g/view?usp=share_link'
),
'efficientvit_b3.r288_in1k': _cfg(
# url='https://drive.google.com/file/d/1KfwbGtlyFgslNr4LIHERv6aCfkItEvRk/view?usp=share_link'
),
}
)


def _create_efficientvit(variant, pretrained=False, **kwargs):
    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
model = build_model_with_cfg(
EfficientViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
return model


@register_model
def efficientvit_b0(pretrained=False, **kwargs):
model_args = dict(width_list=[8, 16, 32, 64, 128], depth_list=[1, 2, 2, 2, 2], dim=16, head_width_list=[1024, 1280])
return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def efficientvit_b1(pretrained=False, **kwargs):
model_args = dict(width_list=[16, 32, 64, 128, 256], depth_list=[1, 2, 3, 3, 4], dim=16, head_width_list=[1536, 1600])
    return _create_efficientvit('efficientvit_b1', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def efficientvit_b2(pretrained=False, **kwargs):
model_args = dict(width_list=[24, 48, 96, 192, 384], depth_list=[1, 3, 4, 4, 6], dim=32, head_width_list=[2304, 2560])
    return _create_efficientvit('efficientvit_b2', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def efficientvit_b3(pretrained=False, **kwargs):
model_args = dict(width_list=[32, 64, 128, 256, 512], depth_list=[1, 4, 6, 6, 9], dim=32, head_width_list=[2304, 2560])
    return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs))
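

# Example usage (illustrative sketch; assumes timm is installed so these registrations resolve):
#   import timm, torch
#   model = timm.create_model('efficientvit_b0', pretrained=False).eval()
#   with torch.no_grad():
#       out = model(torch.randn(1, 3, 224, 224))
#   out.shape  # -> torch.Size([1, 1000])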