Merge pull request #1894 from seefun/master
add two different EfficientViT modelsyehuitang-Add-GhostNetV2
commit
b8011565bd
|
@ -41,7 +41,7 @@ NON_STD_FILTERS = [
|
|||
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
|
||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
||||
'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
|
||||
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*'
|
||||
]
|
||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@ from .edgenext import *
|
|||
from .efficientformer import *
|
||||
from .efficientformer_v2 import *
|
||||
from .efficientnet import *
|
||||
from .efficientvit_mit import *
|
||||
from .efficientvit_msra import *
|
||||
from .eva import *
|
||||
from .focalnet import *
|
||||
from .gcvit import *
|
||||
|
|
|
@ -0,0 +1,677 @@
|
|||
""" EfficientViT (by MIT Song Han's Lab)
|
||||
|
||||
Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition`
|
||||
- https://arxiv.org/abs/2205.14756
|
||||
|
||||
Adapted from official impl at https://github.com/mit-han-lab/efficientvit
|
||||
"""
|
||||
|
||||
__all__ = ['EfficientVit']
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from timm.layers import SelectAdaptivePool2d, create_conv2d
|
||||
|
||||
|
||||
def val2list(x: list or tuple or any, repeat_time=1):
|
||||
if isinstance(x, (list, tuple)):
|
||||
return list(x)
|
||||
return [x for _ in range(repeat_time)]
|
||||
|
||||
|
||||
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1):
|
||||
# repeat elements if necessary
|
||||
x = val2list(x)
|
||||
if len(x) > 0:
|
||||
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
|
||||
|
||||
return tuple(x)
|
||||
|
||||
|
||||
def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
|
||||
if isinstance(kernel_size, tuple):
|
||||
return tuple([get_same_padding(ks) for ks in kernel_size])
|
||||
else:
|
||||
assert kernel_size % 2 > 0, "kernel size should be odd number"
|
||||
return kernel_size // 2
|
||||
|
||||
|
||||
class ConvNormAct(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
dropout=0.,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
act_layer=nn.ReLU,
|
||||
):
|
||||
super(ConvNormAct, self).__init__()
|
||||
self.dropout = nn.Dropout(dropout, inplace=False)
|
||||
self.conv = create_conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
|
||||
self.act = act_layer(inplace=True) if act_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.conv(x)
|
||||
x = self.norm(x)
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class DSConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
use_bias=False,
|
||||
norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
|
||||
act_layer=(nn.ReLU6, None),
|
||||
):
|
||||
super(DSConv, self).__init__()
|
||||
use_bias = val2tuple(use_bias, 2)
|
||||
norm_layer = val2tuple(norm_layer, 2)
|
||||
act_layer = val2tuple(act_layer, 2)
|
||||
|
||||
self.depth_conv = ConvNormAct(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
groups=in_channels,
|
||||
norm_layer=norm_layer[0],
|
||||
act_layer=act_layer[0],
|
||||
bias=use_bias[0],
|
||||
)
|
||||
self.point_conv = ConvNormAct(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
norm_layer=norm_layer[1],
|
||||
act_layer=act_layer[1],
|
||||
bias=use_bias[1],
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.depth_conv(x)
|
||||
x = self.point_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
mid_channels=None,
|
||||
expand_ratio=6,
|
||||
use_bias=False,
|
||||
norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
|
||||
act_layer=(nn.ReLU6, nn.ReLU6, None),
|
||||
):
|
||||
super(MBConv, self).__init__()
|
||||
use_bias = val2tuple(use_bias, 3)
|
||||
norm_layer = val2tuple(norm_layer, 3)
|
||||
act_layer = val2tuple(act_layer, 3)
|
||||
mid_channels = mid_channels or round(in_channels * expand_ratio)
|
||||
|
||||
self.inverted_conv = ConvNormAct(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
1,
|
||||
stride=1,
|
||||
norm_layer=norm_layer[0],
|
||||
act_layer=act_layer[0],
|
||||
bias=use_bias[0],
|
||||
)
|
||||
self.depth_conv = ConvNormAct(
|
||||
mid_channels,
|
||||
mid_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
groups=mid_channels,
|
||||
norm_layer=norm_layer[1],
|
||||
act_layer=act_layer[1],
|
||||
bias=use_bias[1],
|
||||
)
|
||||
self.point_conv = ConvNormAct(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
1,
|
||||
norm_layer=norm_layer[2],
|
||||
act_layer=act_layer[2],
|
||||
bias=use_bias[2],
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.inverted_conv(x)
|
||||
x = self.depth_conv(x)
|
||||
x = self.point_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class LiteMSA(nn.Module):
|
||||
"""Lightweight multi-scale attention"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
heads: int or None = None,
|
||||
heads_ratio: float = 1.0,
|
||||
dim=8,
|
||||
use_bias=False,
|
||||
norm_layer=(None, nn.BatchNorm2d),
|
||||
act_layer=(None, None),
|
||||
kernel_func=nn.ReLU,
|
||||
scales=(5,),
|
||||
eps=1e-5,
|
||||
):
|
||||
super(LiteMSA, self).__init__()
|
||||
self.eps = eps
|
||||
heads = heads or int(in_channels // dim * heads_ratio)
|
||||
total_dim = heads * dim
|
||||
use_bias = val2tuple(use_bias, 2)
|
||||
norm_layer = val2tuple(norm_layer, 2)
|
||||
act_layer = val2tuple(act_layer, 2)
|
||||
|
||||
self.dim = dim
|
||||
self.qkv = ConvNormAct(
|
||||
in_channels,
|
||||
3 * total_dim,
|
||||
1,
|
||||
bias=use_bias[0],
|
||||
norm_layer=norm_layer[0],
|
||||
act_layer=act_layer[0],
|
||||
)
|
||||
self.aggreg = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
nn.Conv2d(
|
||||
3 * total_dim,
|
||||
3 * total_dim,
|
||||
scale,
|
||||
padding=get_same_padding(scale),
|
||||
groups=3 * total_dim,
|
||||
bias=use_bias[0],
|
||||
),
|
||||
nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
|
||||
)
|
||||
for scale in scales
|
||||
])
|
||||
self.kernel_func = kernel_func(inplace=False)
|
||||
|
||||
self.proj = ConvNormAct(
|
||||
total_dim * (1 + len(scales)),
|
||||
out_channels,
|
||||
1,
|
||||
bias=use_bias[1],
|
||||
norm_layer=norm_layer[1],
|
||||
act_layer=act_layer[1],
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, _, H, W = x.shape
|
||||
|
||||
# generate multi-scale q, k, v
|
||||
qkv = self.qkv(x)
|
||||
multi_scale_qkv = [qkv]
|
||||
for op in self.aggreg:
|
||||
multi_scale_qkv.append(op(qkv))
|
||||
multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
|
||||
multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
|
||||
q, k, v = multi_scale_qkv.tensor_split(3, dim=-1)
|
||||
|
||||
# lightweight global attention
|
||||
q = self.kernel_func(q)
|
||||
k = self.kernel_func(k)
|
||||
v = F.pad(v, (0, 1), mode="constant", value=1.)
|
||||
|
||||
kv = k.transpose(-1, -2) @ v
|
||||
out = q @ kv
|
||||
out = out[..., :-1] / (out[..., -1:] + self.eps)
|
||||
|
||||
# final projection
|
||||
out = out.transpose(-1, -2).reshape(B, -1, H, W)
|
||||
out = self.proj(out)
|
||||
return out
|
||||
|
||||
|
||||
class EfficientVitBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
heads_ratio=1.0,
|
||||
head_dim=32,
|
||||
expand_ratio=4,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
act_layer=nn.Hardswish,
|
||||
):
|
||||
super(EfficientVitBlock, self).__init__()
|
||||
self.context_module = ResidualBlock(
|
||||
LiteMSA(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
heads_ratio=heads_ratio,
|
||||
dim=head_dim,
|
||||
norm_layer=(None, norm_layer),
|
||||
),
|
||||
nn.Identity(),
|
||||
)
|
||||
self.local_module = ResidualBlock(
|
||||
MBConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
expand_ratio=expand_ratio,
|
||||
use_bias=(True, True, False),
|
||||
norm_layer=(None, None, norm_layer),
|
||||
act_layer=(act_layer, act_layer, None),
|
||||
),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.context_module(x)
|
||||
x = self.local_module(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
main: Optional[nn.Module],
|
||||
shortcut: Optional[nn.Module] = None,
|
||||
pre_norm: Optional[nn.Module] = None,
|
||||
):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
|
||||
self.main = main
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, x):
|
||||
res = self.main(self.pre_norm(x))
|
||||
if self.shortcut is not None:
|
||||
res = res + self.shortcut(x)
|
||||
return res
|
||||
|
||||
|
||||
def build_local_block(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int,
|
||||
expand_ratio: float,
|
||||
norm_layer: str,
|
||||
act_layer: str,
|
||||
fewer_norm: bool = False,
|
||||
):
|
||||
if expand_ratio == 1:
|
||||
block = DSConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=stride,
|
||||
use_bias=(True, False) if fewer_norm else False,
|
||||
norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
|
||||
act_layer=(act_layer, None),
|
||||
)
|
||||
else:
|
||||
block = MBConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=stride,
|
||||
expand_ratio=expand_ratio,
|
||||
use_bias=(True, True, False) if fewer_norm else False,
|
||||
norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
|
||||
act_layer=(act_layer, act_layer, None),
|
||||
)
|
||||
return block
|
||||
|
||||
|
||||
class Stem(nn.Sequential):
|
||||
def __init__(self, in_chs, out_chs, depth, norm_layer, act_layer):
|
||||
super().__init__()
|
||||
self.stride = 2
|
||||
|
||||
self.add_module(
|
||||
'in_conv',
|
||||
ConvNormAct(
|
||||
in_chs, out_chs,
|
||||
kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer,
|
||||
)
|
||||
)
|
||||
stem_block = 0
|
||||
for _ in range(depth):
|
||||
self.add_module(f'res{stem_block}', ResidualBlock(
|
||||
build_local_block(
|
||||
in_channels=out_chs,
|
||||
out_channels=out_chs,
|
||||
stride=1,
|
||||
expand_ratio=1,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
),
|
||||
nn.Identity(),
|
||||
))
|
||||
stem_block += 1
|
||||
|
||||
|
||||
class EfficientVitStage(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chs,
|
||||
out_chs,
|
||||
depth,
|
||||
norm_layer,
|
||||
act_layer,
|
||||
expand_ratio,
|
||||
head_dim,
|
||||
vit_stage=False,
|
||||
):
|
||||
super(EfficientVitStage, self).__init__()
|
||||
blocks = [ResidualBlock(
|
||||
build_local_block(
|
||||
in_channels=in_chs,
|
||||
out_channels=out_chs,
|
||||
stride=2,
|
||||
expand_ratio=expand_ratio,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
fewer_norm=vit_stage,
|
||||
),
|
||||
None,
|
||||
)]
|
||||
in_chs = out_chs
|
||||
|
||||
if vit_stage:
|
||||
# for stage 3, 4
|
||||
for _ in range(depth):
|
||||
blocks.append(
|
||||
EfficientVitBlock(
|
||||
in_channels=in_chs,
|
||||
head_dim=head_dim,
|
||||
expand_ratio=expand_ratio,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# for stage 1, 2
|
||||
for i in range(1, depth):
|
||||
blocks.append(ResidualBlock(
|
||||
build_local_block(
|
||||
in_channels=in_chs,
|
||||
out_channels=out_chs,
|
||||
stride=1,
|
||||
expand_ratio=expand_ratio,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer
|
||||
),
|
||||
nn.Identity(),
|
||||
))
|
||||
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
||||
def forward(self, x):
|
||||
return self.blocks(x)
|
||||
|
||||
|
||||
class ClassifierHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
widths,
|
||||
n_classes=1000,
|
||||
dropout=0.,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
act_layer=nn.Hardswish,
|
||||
global_pool='avg',
|
||||
):
|
||||
super(ClassifierHead, self).__init__()
|
||||
self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer)
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(widths[0], widths[1], bias=False),
|
||||
nn.LayerNorm(widths[1]),
|
||||
act_layer(inplace=True),
|
||||
nn.Dropout(dropout, inplace=False),
|
||||
nn.Linear(widths[1], n_classes, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, pre_logits: bool = False):
|
||||
x = self.in_conv(x)
|
||||
x = self.global_pool(x)
|
||||
if pre_logits:
|
||||
return x
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
class EfficientVit(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans=3,
|
||||
widths=(),
|
||||
depths=(),
|
||||
head_dim=32,
|
||||
expand_ratio=4,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
act_layer=nn.Hardswish,
|
||||
global_pool='avg',
|
||||
head_widths=(),
|
||||
drop_rate=0.0,
|
||||
num_classes=1000,
|
||||
):
|
||||
super(EfficientVit, self).__init__()
|
||||
self.grad_checkpointing = False
|
||||
self.global_pool = global_pool
|
||||
self.num_classes = num_classes
|
||||
|
||||
# input stem
|
||||
self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer)
|
||||
stride = self.stem.stride
|
||||
|
||||
# stages
|
||||
self.feature_info = []
|
||||
stages = []
|
||||
stage_idx = 0
|
||||
in_channels = widths[0]
|
||||
for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
|
||||
stages.append(EfficientVitStage(
|
||||
in_channels,
|
||||
w,
|
||||
depth=d,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
expand_ratio=expand_ratio,
|
||||
head_dim=head_dim,
|
||||
vit_stage=i >= 2,
|
||||
))
|
||||
stride *= 2
|
||||
in_channels = w
|
||||
self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
|
||||
stage_idx += 1
|
||||
|
||||
self.stages = nn.Sequential(*stages)
|
||||
self.num_features = in_channels
|
||||
self.head_widths = head_widths
|
||||
self.head_dropout = drop_rate
|
||||
if num_classes > 0:
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
self.head_widths,
|
||||
n_classes=num_classes,
|
||||
dropout=self.head_dropout,
|
||||
global_pool=self.global_pool,
|
||||
)
|
||||
else:
|
||||
if self.global_pool == 'avg':
|
||||
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||
else:
|
||||
self.head = nn.Identity()
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
matcher = dict(
|
||||
stem=r'^stem', # stem and embed
|
||||
blocks=[(r'^stages\.(\d+)', None)]
|
||||
)
|
||||
return matcher
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head.classifier[-1]
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
if num_classes > 0:
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
self.head_widths,
|
||||
n_classes=num_classes,
|
||||
dropout=self.head_dropout,
|
||||
global_pool=self.global_pool,
|
||||
)
|
||||
else:
|
||||
if self.global_pool == 'avg':
|
||||
self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
|
||||
else:
|
||||
self.head = nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.stages, x)
|
||||
else:
|
||||
x = self.stages(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000,
|
||||
'mean': IMAGENET_DEFAULT_MEAN,
|
||||
'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.in_conv.conv',
|
||||
'classifier': 'head.classifier.4',
|
||||
'crop_pct': 0.95,
|
||||
'input_size': (3, 224, 224),
|
||||
'pool_size': (7, 7),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'efficientvit_b0.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
),
|
||||
'efficientvit_b1.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
),
|
||||
'efficientvit_b1.r256_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b1.r288_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b2.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
),
|
||||
'efficientvit_b2.r256_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b2.r288_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b3.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
),
|
||||
'efficientvit_b3.r256_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b3.r288_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
def _create_efficientvit(variant, pretrained=False, **kwargs):
|
||||
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
|
||||
model = build_model_with_cfg(
|
||||
EfficientVit,
|
||||
variant,
|
||||
pretrained,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_b0(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
widths=(8, 16, 32, 64, 128), depths=(1, 2, 2, 2, 2), head_dim=16, head_widths=(1024, 1280))
|
||||
return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_b1(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
widths=(16, 32, 64, 128, 256), depths=(1, 2, 3, 3, 4), head_dim=16, head_widths=(1536, 1600))
|
||||
return _create_efficientvit('efficientvit_b1', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_b2(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
widths=(24, 48, 96, 192, 384), depths=(1, 3, 4, 4, 6), head_dim=32, head_widths=(2304, 2560))
|
||||
return _create_efficientvit('efficientvit_b2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_b3(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
widths=(32, 64, 128, 256, 512), depths=(1, 4, 6, 6, 9), head_dim=32, head_widths=(2304, 2560))
|
||||
return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs))
|
|
@ -0,0 +1,652 @@
|
|||
""" EfficientViT (by MSRA)
|
||||
|
||||
Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention`
|
||||
- https://arxiv.org/abs/2305.07027
|
||||
|
||||
Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT
|
||||
"""
|
||||
|
||||
__all__ = ['EfficientVitMsra']
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
|
||||
class ConvNorm(torch.nn.Sequential):
|
||||
def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False)
|
||||
self.bn = nn.BatchNorm2d(out_chs)
|
||||
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
|
||||
torch.nn.init.constant_(self.bn.bias, 0)
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
c, bn = self.conv, self.bn
|
||||
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
||||
w = c.weight * w[:, None, None, None]
|
||||
b = bn.bias - bn.running_mean * bn.weight / \
|
||||
(bn.running_var + bn.eps)**0.5
|
||||
m = torch.nn.Conv2d(
|
||||
w.size(1) * self.c.groups, w.size(0), w.shape[2:],
|
||||
stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
|
||||
class NormLinear(torch.nn.Sequential):
|
||||
def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.):
|
||||
super().__init__()
|
||||
self.bn = nn.BatchNorm1d(in_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.linear = nn.Linear(in_features, out_features, bias=bias)
|
||||
|
||||
trunc_normal_(self.linear.weight, std=std)
|
||||
if self.linear.bias is not None:
|
||||
nn.init.constant_(self.linear.bias, 0)
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
bn, linear = self.bn, self.linear
|
||||
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
||||
b = bn.bias - self.bn.running_mean * \
|
||||
self.bn.weight / (bn.running_var + bn.eps)**0.5
|
||||
w = linear.weight * w[None, :]
|
||||
if linear.bias is None:
|
||||
b = b @ self.linear.weight.T
|
||||
else:
|
||||
b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias
|
||||
m = torch.nn.Linear(w.size(1), w.size(0))
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
|
||||
class PatchMerging(torch.nn.Module):
|
||||
def __init__(self, dim, out_dim):
|
||||
super().__init__()
|
||||
hid_dim = int(dim * 4)
|
||||
self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0)
|
||||
self.act = torch.nn.ReLU()
|
||||
self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
|
||||
self.se = SqueezeExcite(hid_dim, .25)
|
||||
self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
|
||||
return x
|
||||
|
||||
|
||||
class ResidualDrop(torch.nn.Module):
|
||||
def __init__(self, m, drop=0.):
|
||||
super().__init__()
|
||||
self.m = m
|
||||
self.drop = drop
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.drop > 0:
|
||||
return x + self.m(x) * torch.rand(
|
||||
x.size(0), 1, 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
|
||||
else:
|
||||
return x + self.m(x)
|
||||
|
||||
|
||||
class ConvMlp(torch.nn.Module):
|
||||
def __init__(self, ed, h):
|
||||
super().__init__()
|
||||
self.pw1 = ConvNorm(ed, h)
|
||||
self.act = torch.nn.ReLU()
|
||||
self.pw2 = ConvNorm(h, ed, bn_weight_init=0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pw2(self.act(self.pw1(x)))
|
||||
return x
|
||||
|
||||
|
||||
class CascadedGroupAttention(torch.nn.Module):
|
||||
attention_bias_cache: Dict[str, torch.Tensor]
|
||||
|
||||
r""" Cascaded Group Attention.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
key_dim (int): The dimension for query and key.
|
||||
num_heads (int): Number of attention heads.
|
||||
attn_ratio (int): Multiplier for the query dim for value dimension.
|
||||
resolution (int): Input resolution, correspond to the window size.
|
||||
kernels (List[int]): The kernel size of the dw conv on query.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
resolution=14,
|
||||
kernels=(5, 5, 5, 5),
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim ** -0.5
|
||||
self.key_dim = key_dim
|
||||
self.val_dim = int(attn_ratio * key_dim)
|
||||
self.attn_ratio = attn_ratio
|
||||
|
||||
qkvs = []
|
||||
dws = []
|
||||
for i in range(num_heads):
|
||||
qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim))
|
||||
dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim))
|
||||
self.qkvs = torch.nn.ModuleList(qkvs)
|
||||
self.dws = torch.nn.ModuleList(dws)
|
||||
self.proj = torch.nn.Sequential(
|
||||
torch.nn.ReLU(),
|
||||
ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0)
|
||||
)
|
||||
|
||||
points = list(itertools.product(range(resolution), range(resolution)))
|
||||
N = len(points)
|
||||
attention_offsets = {}
|
||||
idxs = []
|
||||
for p1 in points:
|
||||
for p2 in points:
|
||||
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
||||
if offset not in attention_offsets:
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
||||
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
|
||||
self.attention_bias_cache = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and self.attention_bias_cache:
|
||||
self.attention_bias_cache = {} # clear ab cache
|
||||
|
||||
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
|
||||
if torch.jit.is_tracing() or self.training:
|
||||
return self.attention_biases[:, self.attention_bias_idxs]
|
||||
else:
|
||||
device_key = str(device)
|
||||
if device_key not in self.attention_bias_cache:
|
||||
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
|
||||
return self.attention_bias_cache[device_key]
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
feats_in = x.chunk(len(self.qkvs), dim=1)
|
||||
feats_out = []
|
||||
feat = feats_in[0]
|
||||
attn_bias = self.get_attention_biases(x.device)
|
||||
for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
|
||||
if head_idx > 0:
|
||||
feat = feat + feats_in[head_idx]
|
||||
feat = qkv(feat)
|
||||
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
|
||||
q = dws(q)
|
||||
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
|
||||
q = q * self.scale
|
||||
attn = q.transpose(-2, -1) @ k
|
||||
attn = attn + attn_bias[head_idx]
|
||||
attn = attn.softmax(dim=-1)
|
||||
feat = v @ attn.transpose(-2, -1)
|
||||
feat = feat.view(B, self.val_dim, H, W)
|
||||
feats_out.append(feat)
|
||||
x = self.proj(torch.cat(feats_out, 1))
|
||||
return x
|
||||
|
||||
|
||||
class LocalWindowAttention(torch.nn.Module):
|
||||
r""" Local Window Attention.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
key_dim (int): The dimension for query and key.
|
||||
num_heads (int): Number of attention heads.
|
||||
attn_ratio (int): Multiplier for the query dim for value dimension.
|
||||
resolution (int): Input resolution.
|
||||
window_resolution (int): Local window resolution.
|
||||
kernels (List[int]): The kernel size of the dw conv on query.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
resolution=14,
|
||||
window_resolution=7,
|
||||
kernels=(5, 5, 5, 5),
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.resolution = resolution
|
||||
assert window_resolution > 0, 'window_size must be greater than 0'
|
||||
self.window_resolution = window_resolution
|
||||
window_resolution = min(window_resolution, resolution)
|
||||
self.attn = CascadedGroupAttention(
|
||||
dim, key_dim, num_heads,
|
||||
attn_ratio=attn_ratio,
|
||||
resolution=window_resolution,
|
||||
kernels=kernels,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
H = W = self.resolution
|
||||
B, C, H_, W_ = x.shape
|
||||
# Only check this for classifcation models
|
||||
_assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
|
||||
_assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
|
||||
if H <= self.window_resolution and W <= self.window_resolution:
|
||||
x = self.attn(x)
|
||||
else:
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
pad_b = (self.window_resolution - H % self.window_resolution) % self.window_resolution
|
||||
pad_r = (self.window_resolution - W % self.window_resolution) % self.window_resolution
|
||||
x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b))
|
||||
|
||||
pH, pW = H + pad_b, W + pad_r
|
||||
nH = pH // self.window_resolution
|
||||
nW = pW // self.window_resolution
|
||||
# window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
|
||||
x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3)
|
||||
x = x.reshape(B * nH * nW, self.window_resolution, self.window_resolution, C).permute(0, 3, 1, 2)
|
||||
x = self.attn(x)
|
||||
# window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
|
||||
x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C)
|
||||
x = x.transpose(2, 3).reshape(B, pH, pW, C)
|
||||
x = x[:, :H, :W].contiguous()
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class EfficientVitBlock(torch.nn.Module):
|
||||
""" A basic EfficientVit building block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
key_dim (int): Dimension for query and key in the token mixer.
|
||||
num_heads (int): Number of attention heads.
|
||||
attn_ratio (int): Multiplier for the query dim for value dimension.
|
||||
resolution (int): Input resolution.
|
||||
window_resolution (int): Local window resolution.
|
||||
kernels (List[int]): The kernel size of the dw conv on query.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
resolution=14,
|
||||
window_resolution=7,
|
||||
kernels=[5, 5, 5, 5],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.))
|
||||
self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2)))
|
||||
|
||||
self.mixer = ResidualDrop(
|
||||
LocalWindowAttention(
|
||||
dim, key_dim, num_heads,
|
||||
attn_ratio=attn_ratio,
|
||||
resolution=resolution,
|
||||
window_resolution=window_resolution,
|
||||
kernels=kernels,
|
||||
)
|
||||
)
|
||||
|
||||
self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.))
|
||||
self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2)))
|
||||
|
||||
def forward(self, x):
|
||||
return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
|
||||
|
||||
|
||||
class EfficientVitStage(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
key_dim,
|
||||
downsample=('', 1),
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
resolution=14,
|
||||
window_resolution=7,
|
||||
kernels=[5, 5, 5, 5],
|
||||
depth=1,
|
||||
):
|
||||
super().__init__()
|
||||
if downsample[0] == 'subsample':
|
||||
self.resolution = (resolution - 1) // downsample[1] + 1
|
||||
down_blocks = []
|
||||
down_blocks.append((
|
||||
'res1',
|
||||
torch.nn.Sequential(
|
||||
ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim)),
|
||||
ResidualDrop(ConvMlp(in_dim, int(in_dim * 2))),
|
||||
)
|
||||
))
|
||||
down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim)))
|
||||
down_blocks.append((
|
||||
'res2',
|
||||
torch.nn.Sequential(
|
||||
ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim)),
|
||||
ResidualDrop(ConvMlp(out_dim, int(out_dim * 2))),
|
||||
)
|
||||
))
|
||||
self.downsample = nn.Sequential(OrderedDict(down_blocks))
|
||||
else:
|
||||
assert in_dim == out_dim
|
||||
self.downsample = nn.Identity()
|
||||
self.resolution = resolution
|
||||
|
||||
blocks = []
|
||||
for d in range(depth):
|
||||
blocks.append(EfficientVitBlock(out_dim, key_dim, num_heads, attn_ratio, self.resolution, window_resolution, kernels))
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downsample(x)
|
||||
x = self.blocks(x)
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbedding(torch.nn.Sequential):
|
||||
def __init__(self, in_chans, dim):
|
||||
super().__init__()
|
||||
self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1))
|
||||
self.add_module('relu1', torch.nn.ReLU())
|
||||
self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1))
|
||||
self.add_module('relu2', torch.nn.ReLU())
|
||||
self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1))
|
||||
self.add_module('relu3', torch.nn.ReLU())
|
||||
self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1))
|
||||
self.patch_size = 16
|
||||
|
||||
|
||||
class EfficientVitMsra(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dim=(64, 128, 192),
|
||||
key_dim=(16, 16, 16),
|
||||
depth=(1, 2, 3),
|
||||
num_heads=(4, 4, 4),
|
||||
window_size=(7, 7, 7),
|
||||
kernels=(5, 5, 5, 5),
|
||||
down_ops=(('', 1), ('subsample', 2), ('subsample', 2)),
|
||||
global_pool='avg',
|
||||
drop_rate=0.,
|
||||
):
|
||||
super(EfficientVitMsra, self).__init__()
|
||||
self.grad_checkpointing = False
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
# Patch embedding
|
||||
self.patch_embed = PatchEmbedding(in_chans, embed_dim[0])
|
||||
stride = self.patch_embed.patch_size
|
||||
resolution = img_size // self.patch_embed.patch_size
|
||||
attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
|
||||
|
||||
# Build EfficientVit blocks
|
||||
self.feature_info = []
|
||||
stages = []
|
||||
pre_ed = embed_dim[0]
|
||||
for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
|
||||
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
|
||||
stage = EfficientVitStage(
|
||||
in_dim=pre_ed,
|
||||
out_dim=ed,
|
||||
key_dim=kd,
|
||||
downsample=do,
|
||||
num_heads=nh,
|
||||
attn_ratio=ar,
|
||||
resolution=resolution,
|
||||
window_resolution=wd,
|
||||
kernels=kernels,
|
||||
depth=dpth,
|
||||
)
|
||||
pre_ed = ed
|
||||
if do[0] == 'subsample' and i != 0:
|
||||
stride *= do[1]
|
||||
resolution = stage.resolution
|
||||
stages.append(stage)
|
||||
self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')]
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
if global_pool == 'avg':
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||
else:
|
||||
assert num_classes == 0
|
||||
self.global_pool = nn.Identity()
|
||||
self.num_features = embed_dim[-1]
|
||||
self.head = NormLinear(
|
||||
self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
matcher = dict(
|
||||
stem=r'^patch_embed',
|
||||
blocks=[(r'^stages\.(\d+)', None)]
|
||||
)
|
||||
return matcher
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
if global_pool == 'avg':
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||
else:
|
||||
assert num_classes == 0
|
||||
self.global_pool = nn.Identity()
|
||||
self.head = NormLinear(
|
||||
self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.stages, x)
|
||||
else:
|
||||
x = self.stages(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
x = self.global_pool(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
|
||||
# def checkpoint_filter_fn(state_dict, model):
|
||||
# if 'model' in state_dict.keys():
|
||||
# state_dict = state_dict['model']
|
||||
# tmp_dict = {}
|
||||
# out_dict = {}
|
||||
# target_keys = model.state_dict().keys()
|
||||
# target_keys = [k for k in target_keys if k.startswith('stages.')]
|
||||
#
|
||||
# for k, v in state_dict.items():
|
||||
# if 'attention_bias_idxs' in k:
|
||||
# continue
|
||||
# k = k.split('.')
|
||||
# if k[-2] == 'c':
|
||||
# k[-2] = 'conv'
|
||||
# if k[-2] == 'l':
|
||||
# k[-2] = 'linear'
|
||||
# k = '.'.join(k)
|
||||
# tmp_dict[k] = v
|
||||
#
|
||||
# for k, v in tmp_dict.items():
|
||||
# if k.startswith('patch_embed'):
|
||||
# k = k.split('.')
|
||||
# k[1] = 'conv' + str(int(k[1]) // 2 + 1)
|
||||
# k = '.'.join(k)
|
||||
# elif k.startswith('blocks'):
|
||||
# kw = '.'.join(k.split('.')[2:])
|
||||
# find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a]
|
||||
# idx = find_kw.index(k)
|
||||
# k = [a for a in target_keys if kw in a][idx]
|
||||
# out_dict[k] = v
|
||||
#
|
||||
# return out_dict
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000,
|
||||
'mean': IMAGENET_DEFAULT_MEAN,
|
||||
'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.conv1.conv',
|
||||
'classifier': 'head.linear',
|
||||
'fixed_input_size': True,
|
||||
'pool_size': (4, 4),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'efficientvit_m0.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
|
||||
),
|
||||
'efficientvit_m1.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
|
||||
),
|
||||
'efficientvit_m2.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
|
||||
),
|
||||
'efficientvit_m3.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
|
||||
),
|
||||
'efficientvit_m4.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
|
||||
),
|
||||
'efficientvit_m5.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
def _create_efficientvit_msra(variant, pretrained=False, **kwargs):
|
||||
out_indices = kwargs.pop('out_indices', (0, 1, 2))
|
||||
model = build_model_with_cfg(
|
||||
EfficientVitMsra,
|
||||
variant,
|
||||
pretrained,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_m0(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_size=224,
|
||||
embed_dim=[64, 128, 192],
|
||||
depth=[1, 2, 3],
|
||||
num_heads=[4, 4, 4],
|
||||
window_size=[7, 7, 7],
|
||||
kernels=[5, 5, 5, 5]
|
||||
)
|
||||
return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_m1(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_size=224,
|
||||
embed_dim=[128, 144, 192],
|
||||
depth=[1, 2, 3],
|
||||
num_heads=[2, 3, 3],
|
||||
window_size=[7, 7, 7],
|
||||
kernels=[7, 5, 3, 3]
|
||||
)
|
||||
return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_m2(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_size=224,
|
||||
embed_dim=[128, 192, 224],
|
||||
depth=[1, 2, 3],
|
||||
num_heads=[4, 3, 2],
|
||||
window_size=[7, 7, 7],
|
||||
kernels=[7, 5, 3, 3]
|
||||
)
|
||||
return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_m3(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_size=224,
|
||||
embed_dim=[128, 240, 320],
|
||||
depth=[1, 2, 3],
|
||||
num_heads=[4, 3, 4],
|
||||
window_size=[7, 7, 7],
|
||||
kernels=[5, 5, 5, 5]
|
||||
)
|
||||
return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_m4(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_size=224,
|
||||
embed_dim=[128, 256, 384],
|
||||
depth=[1, 2, 3],
|
||||
num_heads=[4, 4, 4],
|
||||
window_size=[7, 7, 7],
|
||||
kernels=[7, 5, 3, 3]
|
||||
)
|
||||
return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientvit_m5(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_size=224,
|
||||
embed_dim=[192, 288, 384],
|
||||
depth=[1, 3, 4],
|
||||
num_heads=[3, 3, 4],
|
||||
window_size=[7, 7, 7],
|
||||
kernels=[7, 5, 3, 3]
|
||||
)
|
||||
return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **dict(model_args, **kwargs))
|
Loading…
Reference in New Issue