556 lines
21 KiB
Python
556 lines
21 KiB
Python
""" EfficientViT (by MSRA)

Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention`
    - https://arxiv.org/abs/2305.07027

Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT
"""
|
|
|
|
# Public API: only the model class is exported by ``from <this module> import *``.
__all__ = ['EfficientViTMSRA']
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
from timm.models.vision_transformer import trunc_normal_
|
|
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d
|
|
from ._registry import register_model, generate_default_cfgs
|
|
from ._builder import build_model_with_cfg
|
|
from ._manipulate import checkpoint_seq
|
|
import itertools
|
|
from collections import OrderedDict
|
|
|
|
|
|
class ConvBN(torch.nn.Sequential):
    """Bias-free Conv2d followed by BatchNorm2d, foldable into one Conv2d.

    Args:
        a (int): Number of input channels.
        b (int): Number of output channels.
        ks (int): Convolution kernel size.
        stride (int): Convolution stride.
        pad (int): Convolution padding.
        dilation (int): Convolution dilation.
        groups (int): Number of convolution groups.
        bn_weight_init (float): Initial BatchNorm scale (gamma). With 0 (and the
            zero-initialized BN bias) the block outputs zeros at initialization.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('conv', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Return a single Conv2d equivalent to this conv + BN using BN running stats.

        The BN affine parameters and running statistics are folded into the conv
        weight and a new bias term, so the result matches this module in eval mode.

        Returns:
            torch.nn.Conv2d: The fused convolution (with bias).
        """
        c, bn = self._modules.values()
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * scale[:, None, None, None]
        b = bn.bias - bn.running_mean * scale
        # Read hyper-parameters from the local conv `c` (registered as 'conv');
        # this module has no `c` attribute, so `self.c` would raise AttributeError.
        m = torch.nn.Conv2d(
            w.size(1) * c.groups, w.size(0), w.shape[2:],
            stride=c.stride, padding=c.padding, dilation=c.dilation, groups=c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
|
|
|
|
|
|
class BNLinear(torch.nn.Sequential):
    """BatchNorm1d followed by a Linear layer, foldable into one Linear.

    Args:
        a (int): Number of input features.
        b (int): Number of output features.
        bias (bool): Whether the linear layer has a bias (zero-initialized).
        std (float): Std of the truncated-normal init of the linear weight.
    """

    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        self.add_module('linear', torch.nn.Linear(a, b, bias=bias))
        trunc_normal_(self.linear.weight, std=std)
        if bias:
            torch.nn.init.constant_(self.linear.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Fold the BN (running statistics) into an equivalent single Linear layer."""
        bn_layer, fc = self._modules.values()
        denom = (bn_layer.running_var + bn_layer.eps) ** 0.5
        gain = bn_layer.weight / denom
        shift = bn_layer.bias - bn_layer.running_mean * bn_layer.weight / denom
        fused_weight = fc.weight * gain[None, :]
        # Pushing the BN shift through the linear weight yields the fused bias.
        if fc.bias is not None:
            fused_bias = (fc.weight @ shift[:, None]).view(-1) + fc.bias
        else:
            fused_bias = shift @ fc.weight.T
        fused = torch.nn.Linear(fused_weight.size(1), fused_weight.size(0))
        fused.weight.data.copy_(fused_weight)
        fused.bias.data.copy_(fused_bias)
        return fused
|
|
|
|
|
|
class PatchMerging(torch.nn.Module):
    """Downsampling block: 1x1 expand (4x) -> 3x3 depthwise stride-2 -> SE -> 1x1 project.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of output channels.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        hidden = int(dim * 4)
        self.conv1 = ConvBN(dim, hidden, 1, 1, 0)
        self.act = torch.nn.ReLU()
        self.conv2 = ConvBN(hidden, hidden, 3, 2, 1, groups=hidden)
        self.se = SqueezeExcite(hidden, .25)
        self.conv3 = ConvBN(hidden, out_dim, 1, 1, 0)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.se(x)
        return self.conv3(x)
|
|
|
|
|
|
class ResidualDrop(torch.nn.Module):
    """Residual wrapper computing ``x + m(x)`` with per-sample stochastic depth.

    Args:
        m (nn.Module): Branch whose output is added to the input.
        drop (float): While training, probability of zeroing the branch for each
            sample; surviving samples are rescaled by ``1 / (1 - drop)``.
    """

    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        branch = self.m(x)
        if self.training and self.drop > 0:
            # One Bernoulli keep-decision per sample, broadcast over C, H, W.
            keep = torch.rand(x.size(0), 1, 1, 1, device=x.device)
            mask = keep.ge_(self.drop).div(1 - self.drop).detach()
            branch = branch * mask
        return x + branch
|
|
|
|
|
|
class FFN(torch.nn.Module):
    """Pointwise feed-forward: 1x1 ConvBN -> ReLU -> 1x1 ConvBN.

    Args:
        ed (int): Number of input/output channels.
        h (int): Number of hidden channels.
    """

    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = ConvBN(ed, h)
        self.act = torch.nn.ReLU()
        # BN scale of the projection starts at 0, so the block outputs zeros at init.
        self.pw2 = ConvBN(h, ed, bn_weight_init=0)

    def forward(self, x):
        hidden = self.act(self.pw1(x))
        return self.pw2(hidden)
|
|
|
|
|
|
class CascadedGroupAttention(torch.nn.Module):
    r""" Cascaded Group Attention.

    The input channels are split into ``num_heads`` chunks. Head ``i`` attends over its
    own chunk plus the output of head ``i - 1`` (the cascade), and the concatenated head
    outputs are projected back to ``dim`` channels. Each head adds a learned relative
    position bias indexed by the absolute (dy, dx) offset between two positions.

    Args:
        dim (int): Number of input channels.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier for the query dim for value dimension.
        resolution (int): Input resolution, correspond to the window size.
        kernels (Sequence[int]): The kernel size of the dw conv on query (one per head).
    """

    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14,
                 kernels=(5, 5, 5, 5),):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.d = int(attn_ratio * key_dim)  # value dim per head
        self.attn_ratio = attn_ratio

        qkvs = []
        dws = []
        for i in range(num_heads):
            qkvs.append(ConvBN(dim // (num_heads), self.key_dim * 2 + self.d))
            dws.append(ConvBN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim))
        self.qkvs = torch.nn.ModuleList(qkvs)
        self.dws = torch.nn.ModuleList(dws)
        self.proj = torch.nn.Sequential(
            torch.nn.ReLU(),
            ConvBN(self.d * num_heads, dim, bn_weight_init=0)
        )

        # One learnable bias per head for every distinct (|dy|, |dx|) offset between two
        # positions of the resolution x resolution grid; idxs maps each (p1, p2) pair to it.
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))
        # Eval-mode cache of the gathered (num_heads, N, N) bias. Keyed by the parameter's
        # device and dtype so moving/casting the module never reuses a stale tensor.
        self.attention_bias_cache = {}

    def train(self, mode=True):
        """Set train/eval mode and drop any cached eval-mode attention bias.

        The cache is cleared on every mode switch, so calling ``eval()`` after
        changing ``attention_biases`` (e.g. ``load_state_dict``) rebuilds it from
        the current values.

        Returns:
            CascadedGroupAttention: ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.attention_bias_cache = {}
        return self

    def get_attention_biases(self):
        """Return the per-head relative position bias of shape (num_heads, N, N).

        In training mode (or while tracing) it is gathered on every call so gradients
        flow into ``attention_biases``. In eval mode it is gathered once without
        autograd history and cached per device/dtype.
        """
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        key = f'{self.attention_biases.device}:{self.attention_biases.dtype}'
        cached = self.attention_bias_cache.get(key)
        if cached is None:
            with torch.no_grad():
                cached = self.attention_biases[:, self.attention_bias_idxs]
            self.attention_bias_cache[key] = cached
        return cached

    def forward(self, x):
        B, C, H, W = x.shape
        attn_bias = self.get_attention_biases()
        feats_in = x.chunk(len(self.qkvs), dim=1)
        feats_out = []
        feat = feats_in[0]
        for i, (qkv, dw) in enumerate(zip(self.qkvs, self.dws)):
            if i > 0:
                # cascade: head i sees its own chunk plus the previous head's output
                feat = feat + feats_in[i]
            feat = qkv(feat)
            q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)
            q = dw(q)
            q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
            # (B, HW, HW) attention logits; the bias requires H * W == N
            attn = (q.transpose(-2, -1) @ k) * self.scale + attn_bias[i]
            attn = attn.softmax(dim=-1)
            feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W)
            feats_out.append(feat)
        x = self.proj(torch.cat(feats_out, 1))
        return x
|
|
|
|
|
|
class LocalWindowAttention(torch.nn.Module):
    r""" Local Window Attention.

    Runs CascadedGroupAttention over the whole feature map when it fits in one window.
    Otherwise the map is zero-padded on the bottom/right up to a multiple of the window
    size, attention is applied to each non-overlapping window independently, and the
    padding is cropped away again.

    Args:
        dim (int): Number of input channels.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier for the query dim for value dimension.
        resolution (int): Input resolution (forward asserts the input has exactly this size).
        window_resolution (int): Local window resolution.
        kernels (List[int]): The kernel size of the dw conv on query.
    """

    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5],):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.resolution = resolution
        assert window_resolution > 0, 'window_size must be greater than 0'
        self.window_resolution = window_resolution
        # The inner attention is sized for one window, but never larger than the map.
        self.attn = CascadedGroupAttention(
            dim,
            key_dim,
            num_heads,
            attn_ratio=attn_ratio,
            resolution=min(window_resolution, resolution),
            kernels=kernels,
        )

    def forward(self, x):
        H = W = self.resolution
        B, C, H_, W_ = x.shape
        # Only check this for classification models
        assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_))

        win = self.window_resolution
        if H <= win and W <= win:
            # whole map fits in a single window: no partitioning needed
            return self.attn(x)

        # bottom / right padding needed to reach a multiple of the window size
        pad_b = (win - H % win) % win
        pad_r = (win - W % win) % win
        needs_pad = pad_b > 0 or pad_r > 0

        x = x.permute(0, 2, 3, 1)  # B, C, H, W -> B, H, W, C
        if needs_pad:
            # pads from the last dim backwards: C by (0, 0), W by (0, pad_r), H by (0, pad_b)
            x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        pH = H + pad_b
        pW = W + pad_r
        nH = pH // win
        nW = pW // win

        # partition: B(nH win)(nW win)C -> B nH nW win win C -> (B nH nW) C win win
        windows = x.view(B, nH, win, nW, win, C).transpose(2, 3)
        windows = windows.reshape(B * nH * nW, win, win, C).permute(0, 3, 1, 2)
        windows = self.attn(windows)

        # reverse: (B nH nW) C win win -> B nH nW win win C -> B(nH win)(nW win)C
        x = windows.permute(0, 2, 3, 1).view(B, nH, nW, win, win, C)
        x = x.transpose(2, 3).reshape(B, pH, pW, C)
        if needs_pad:
            x = x[:, :H, :W].contiguous()
        return x.permute(0, 3, 1, 2)
|
|
|
|
|
|
class EfficientViTBlock(torch.nn.Module):
    """ A basic EfficientViT building block.

    Sandwich layout, each layer wrapped as a residual:
    depthwise conv -> FFN -> local window attention -> depthwise conv -> FFN.

    Args:
        ed (int): Number of input channels.
        kd (int): Dimension for query and key in the token mixer.
        nh (int): Number of attention heads.
        ar (int): Multiplier for the query dim for value dimension.
        resolution (int): Input resolution.
        window_resolution (int): Local window resolution.
        kernels (List[int]): The kernel size of the dw conv on query.
    """

    def __init__(self, ed, kd, nh=8,
                 ar=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5],):
        super().__init__()
        hidden = int(ed * 2)

        self.dw0 = ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.))
        self.ffn0 = ResidualDrop(FFN(ed, hidden))

        token_mixer = LocalWindowAttention(
            ed, kd, nh,
            attn_ratio=ar,
            resolution=resolution,
            window_resolution=window_resolution,
            kernels=kernels,
        )
        self.mixer = ResidualDrop(token_mixer)

        self.dw1 = ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.))
        self.ffn1 = ResidualDrop(FFN(ed, hidden))

    def forward(self, x):
        x = self.dw0(x)
        x = self.ffn0(x)
        x = self.mixer(x)
        x = self.dw1(x)
        return self.ffn1(x)
|
|
|
|
|
|
class EfficientViTStage(torch.nn.Module):
    """One stage: an optional downsample path followed by ``depth`` EfficientViTBlocks.

    Args:
        do (Sequence): Downsample spec; when ``do[0] == 'subsample'`` a downsample path
            (residual dw conv + FFN, PatchMerging, residual dw conv + FFN) is built and
            ``do[1]`` is used to compute the output resolution.
        pre_ed (int): Channels entering the stage (used only by the downsample path).
        ed (int): Channels of the stage's blocks.
        kd (int): Query/key dimension per head.
        nh (int): Number of attention heads.
        ar (float): Value-dim multiplier relative to ``kd``.
        resolution (int): Input resolution of the stage.
        window_resolution (int): Local window resolution.
        kernels (List[int]): Depthwise kernel sizes applied to queries.
        depth (int): Number of EfficientViTBlocks.
    """

    def __init__(self, do, pre_ed, ed, kd, nh=8,
                 ar=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5],
                 depth=1):
        super().__init__()
        if do[0] != 'subsample':
            self.downsample = nn.Identity()
            self.resolution = resolution
        else:
            self.resolution = (resolution - 1) // do[1] + 1
            self.downsample = nn.Sequential(OrderedDict([
                ('res1', torch.nn.Sequential(
                    ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)),
                    ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),
                )),
                ('patchmerge', PatchMerging(pre_ed, ed)),
                ('res2', torch.nn.Sequential(
                    ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)),
                    ResidualDrop(FFN(ed, int(ed * 2))),
                )),
            ]))

        self.blocks = nn.Sequential(*[
            EfficientViTBlock(ed, kd, nh, ar, self.resolution, window_resolution, kernels)
            for _ in range(depth)
        ])

    def forward(self, x):
        return self.blocks(self.downsample(x))
|
|
|
|
|
|
class PatchEmbedding(torch.nn.Sequential):
    """Convolutional stem of four 3x3 stride-2 ConvBN layers (16x downsampling).

    Channel widths grow in_chans -> dim/8 -> dim/4 -> dim/2 -> dim, with a ReLU
    after every conv except the last. Submodules are named conv1..conv4 / relu1..relu3.
    """

    def __init__(self, in_chans, dim):
        super().__init__()
        widths = [in_chans, dim // 8, dim // 4, dim // 2, dim]
        last = len(widths) - 1
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            self.add_module(f'conv{idx}', ConvBN(c_in, c_out, 3, 2, 1))
            if idx < last:
                self.add_module(f'relu{idx}', torch.nn.ReLU())
        self.patch_size = 16
|
|
|
|
|
|
class EfficientViTMSRA(nn.Module):
    """EfficientViT (MSRA) image classification model.

    Args:
        img_size (int): Square input image size. Attention layers assert that the
            feature-map size derived from it matches the actual input.
        in_chans (int): Number of input image channels.
        num_classes (int): Number of classifier outputs; 0 replaces the head with Identity.
        embed_dim (Sequence[int]): Channels of each stage.
        key_dim (Sequence[int]): Query/key dimension per head for each stage.
        depth (Sequence[int]): Number of EfficientViTBlocks per stage.
        num_heads (Sequence[int]): Attention heads per stage.
        window_size (Sequence[int]): Local attention window size per stage.
        kernels (Sequence[int]): Depthwise kernel sizes applied to queries, one per head.
        down_ops (Sequence[Sequence]): Per-stage downsample spec, e.g. ``('subsample', 2)``
            or ``('',)`` for none.
        global_pool (str): Pool type passed to SelectAdaptivePool2d.
    """

    def __init__(
            self,
            img_size=224,
            in_chans=3,
            num_classes=1000,
            embed_dim=(64, 128, 192),
            key_dim=(16, 16, 16),
            depth=(1, 2, 3),
            num_heads=(4, 4, 4),
            window_size=(7, 7, 7),
            kernels=(5, 5, 5, 5),
            down_ops=(('',), ('subsample', 2), ('subsample', 2)),
            global_pool='avg',
    ):
        super().__init__()
        self.grad_checkpointing = False
        self.num_classes = num_classes

        # Patch embedding (stem)
        self.patch_embed = PatchEmbedding(in_chans, embed_dim[0])
        stride = self.patch_embed.patch_size
        resolution = img_size // self.patch_embed.patch_size
        # attn_ratio * key_dim == embed_dim / num_heads, i.e. the per-head value dim
        attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]

        # Build EfficientViT stages
        self.feature_info = []
        stages = []
        for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
                zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
            # Input width for the stage's downsample path. For i == 0 this wraps to the
            # last entry; it is only used if stage 0 subsamples (not the case by default).
            pre_ed = embed_dim[i - 1]
            stage = EfficientViTStage(do, pre_ed, ed, kd, nh, ar, resolution, wd, kernels, dpth)
            if do[0] == 'subsample' and i != 0:
                # PatchMerging downsamples with a stride-2 depthwise conv
                stride *= 2
            resolution = stage.resolution
            stages.append(stage)
            self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
        self.num_features = self.out_dims = embed_dim[-1]
        self.head = BNLinear(self.out_dims, num_classes) if num_classes > 0 else torch.nn.Identity()

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex groups of parameter names: the stem, then one group per stage."""
        matcher = dict(
            stem=r'^patch_embed',
            blocks=[(r'^stages\.(\d+)', None)]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable/disable activation checkpointing over the stages."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classifier head (and optionally the global pool)."""
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
        self.head = BNLinear(self.out_dims, num_classes) if num_classes > 0 else torch.nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
    """Remap an official (microsoft/Cream) EfficientViT checkpoint to this model's keys.

    Steps:
      * unwrap a ``{'model': ...}`` container;
      * rename the parent module ``c`` -> ``conv`` and ``l`` -> ``linear``;
      * rename ``patch_embed.<i>`` to ``patch_embed.conv<i // 2 + 1>``;
      * map ``blocks*`` keys onto ``stages.*`` keys. Keys are matched on their suffix
        (everything after the first two components). The n-th matching source key, in
        sorted order, is paired with the n-th matching target key, in
        ``model.state_dict()`` order.

    Args:
        state_dict (dict): Checkpoint, or a dict holding it under ``'model'``.
        model: Target model; only the keys of ``model.state_dict()`` are read.

    Returns:
        dict: The remapped state dict.
    """
    if 'model' in state_dict:
        state_dict = state_dict['model']

    tmp_dict = {}
    for k, v in state_dict.items():
        parts = k.split('.')
        # Keys without a parent module (no '.') have nothing to rename.
        if len(parts) >= 2:
            if parts[-2] == 'c':
                parts[-2] = 'conv'
            elif parts[-2] == 'l':
                parts[-2] = 'linear'
        tmp_dict['.'.join(parts)] = v

    target_keys = [k for k in model.state_dict().keys() if k.startswith('stages.')]
    # tmp_dict is not modified below, so sort its keys once instead of per key.
    sorted_src_keys = sorted(tmp_dict)

    out_dict = {}
    for k, v in tmp_dict.items():
        if k.startswith('patch_embed'):
            parts = k.split('.')
            parts[1] = 'conv' + str(int(parts[1]) // 2 + 1)
            k = '.'.join(parts)
        elif k.startswith('blocks'):
            kw = '.'.join(k.split('.')[2:])
            idx = [a for a in sorted_src_keys if kw in a].index(k)
            k = [a for a in target_keys if kw in a][idx]
        out_dict[k] = v
    return out_dict
|
|
|
|
|
|
def _cfg(url='', **kwargs):
    """Build the pretrained config dict for an EfficientViT (MSRA) variant.

    ``fixed_input_size`` is set because LocalWindowAttention asserts that the
    feature-map size equals the resolution the model was built for, so other
    input sizes cannot be used with a given model instance.

    Args:
        url (str): Download URL of the pretrained weights.
        **kwargs: Extra/overriding config entries.
    """
    return {
        'url': url,
        'num_classes': 1000,
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.conv1.conv',
        'classifier': 'head.linear',
        'fixed_input_size': True,
        **kwargs,
    }
|
|
|
|
|
|
# Pretrained configs for efficientvit_m0 .. efficientvit_m5, tagged ``r224_in1k``.
default_cfgs = generate_default_cfgs({
    f'efficientvit_m{idx}.r224_in1k': _cfg(
        url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/'
            f'efficientvit_m{idx}.pth'
    )
    for idx in range(6)
})
|
|
|
|
|
|
def _create_efficientvit_msra(variant, pretrained=False, **kwargs):
    """Build an EfficientViTMSRA variant through timm's model builder.

    ``out_indices`` (default ``(0, 1, 2)``) is removed from kwargs and forwarded
    as part of the feature-extraction config; everything else goes to the model.
    """
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', (0, 1, 2)),
    )
    return build_model_with_cfg(
        EfficientViTMSRA,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
|
|
|
|
|
|
@register_model
def efficientvit_m0(pretrained=False, **kwargs):
    """EfficientViT-M0; caller kwargs override the defaults below."""
    model_args = dict(
        img_size=224,
        embed_dim=[64, 128, 192],
        depth=[1, 2, 3],
        num_heads=[4, 4, 4],
        window_size=[7, 7, 7],
        kernels=[5, 5, 5, 5],
    )
    model_args.update(kwargs)
    return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **model_args)
|
|
|
|
@register_model
def efficientvit_m1(pretrained=False, **kwargs):
    """EfficientViT-M1; caller kwargs override the defaults below."""
    model_args = dict(
        img_size=224,
        embed_dim=[128, 144, 192],
        depth=[1, 2, 3],
        num_heads=[2, 3, 3],
        window_size=[7, 7, 7],
        kernels=[7, 5, 3, 3],
    )
    model_args.update(kwargs)
    return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **model_args)
|
|
|
|
@register_model
def efficientvit_m2(pretrained=False, **kwargs):
    """EfficientViT-M2; caller kwargs override the defaults below."""
    model_args = dict(
        img_size=224,
        embed_dim=[128, 192, 224],
        depth=[1, 2, 3],
        num_heads=[4, 3, 2],
        window_size=[7, 7, 7],
        kernels=[7, 5, 3, 3],
    )
    model_args.update(kwargs)
    return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **model_args)
|
|
|
|
@register_model
def efficientvit_m3(pretrained=False, **kwargs):
    """EfficientViT-M3; caller kwargs override the defaults below."""
    model_args = dict(
        img_size=224,
        embed_dim=[128, 240, 320],
        depth=[1, 2, 3],
        num_heads=[4, 3, 4],
        window_size=[7, 7, 7],
        kernels=[5, 5, 5, 5],
    )
    model_args.update(kwargs)
    return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **model_args)
|
|
|
|
@register_model
def efficientvit_m4(pretrained=False, **kwargs):
    """EfficientViT-M4; caller kwargs override the defaults below."""
    model_args = dict(
        img_size=224,
        embed_dim=[128, 256, 384],
        depth=[1, 2, 3],
        num_heads=[4, 4, 4],
        window_size=[7, 7, 7],
        kernels=[7, 5, 3, 3],
    )
    model_args.update(kwargs)
    return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **model_args)
|
|
|
|
@register_model
def efficientvit_m5(pretrained=False, **kwargs):
    """EfficientViT-M5; caller kwargs override the defaults below."""
    model_args = dict(
        img_size=224,
        embed_dim=[192, 288, 384],
        depth=[1, 3, 4],
        num_heads=[3, 3, 4],
        window_size=[7, 7, 7],
        kernels=[7, 5, 3, 3],
    )
    model_args.update(kwargs)
    return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **model_args)
|