655 lines
20 KiB
Python
655 lines
20 KiB
Python
|
""" EfficientViT
|
||
|
|
||
|
Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition`
|
||
|
- https://arxiv.org/abs/2205.14756
|
||
|
|
||
|
Adapted from official impl at https://github.com/mit-han-lab/efficientvit
|
||
|
"""
|
||
|
|
||
|
__all__ = ['EfficientViT']
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||
|
from ._registry import register_model, generate_default_cfgs
|
||
|
from ._builder import build_model_with_cfg
|
||
|
from ._manipulate import checkpoint_seq
|
||
|
from functools import partial
|
||
|
from timm.layers import SelectAdaptivePool2d
|
||
|
from collections import OrderedDict
|
||
|
|
||
|
|
||
|
def val2list(x: list or tuple or any, repeat_time=1):
|
||
|
if isinstance(x, (list, tuple)):
|
||
|
return list(x)
|
||
|
return [x for _ in range(repeat_time)]
|
||
|
|
||
|
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1):
|
||
|
# repeat elements if necessary
|
||
|
x = val2list(x)
|
||
|
if len(x) > 0:
|
||
|
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
|
||
|
|
||
|
return tuple(x)
|
||
|
|
||
|
def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
|
||
|
if isinstance(kernel_size, tuple):
|
||
|
return tuple([get_same_padding(ks) for ks in kernel_size])
|
||
|
else:
|
||
|
assert kernel_size % 2 > 0, "kernel size should be odd number"
|
||
|
return kernel_size // 2
|
||
|
|
||
|
class ConvLayer(nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
in_channels: int,
|
||
|
out_channels: int,
|
||
|
kernel_size=3,
|
||
|
stride=1,
|
||
|
dilation=1,
|
||
|
groups=1,
|
||
|
use_bias=False,
|
||
|
dropout=0,
|
||
|
norm=nn.BatchNorm2d,
|
||
|
act_func=nn.ReLU,
|
||
|
):
|
||
|
super(ConvLayer, self).__init__()
|
||
|
|
||
|
padding = get_same_padding(kernel_size)
|
||
|
padding *= dilation
|
||
|
|
||
|
self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None
|
||
|
self.conv = nn.Conv2d(
|
||
|
in_channels,
|
||
|
out_channels,
|
||
|
kernel_size=(kernel_size, kernel_size),
|
||
|
stride=(stride, stride),
|
||
|
padding=padding,
|
||
|
dilation=(dilation, dilation),
|
||
|
groups=groups,
|
||
|
bias=use_bias,
|
||
|
)
|
||
|
self.norm = norm(num_features=out_channels) if norm else None
|
||
|
self.act = act_func(inplace=True) if act_func else None
|
||
|
|
||
|
def forward(self, x):
|
||
|
if self.dropout is not None:
|
||
|
x = self.dropout(x)
|
||
|
x = self.conv(x)
|
||
|
if self.norm:
|
||
|
x = self.norm(x)
|
||
|
if self.act:
|
||
|
x = self.act(x)
|
||
|
return x
|
||
|
|
||
|
|
||
|
class DSConv(nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
in_channels: int,
|
||
|
out_channels: int,
|
||
|
kernel_size=3,
|
||
|
stride=1,
|
||
|
use_bias=False,
|
||
|
norm=(nn.BatchNorm2d, nn.BatchNorm2d),
|
||
|
act_func=(nn.ReLU6, None),
|
||
|
):
|
||
|
super(DSConv, self).__init__()
|
||
|
|
||
|
use_bias = val2tuple(use_bias, 2)
|
||
|
norm = val2tuple(norm, 2)
|
||
|
act_func = val2tuple(act_func, 2)
|
||
|
|
||
|
self.depth_conv = ConvLayer(
|
||
|
in_channels,
|
||
|
in_channels,
|
||
|
kernel_size,
|
||
|
stride,
|
||
|
groups=in_channels,
|
||
|
norm=norm[0],
|
||
|
act_func=act_func[0],
|
||
|
use_bias=use_bias[0],
|
||
|
)
|
||
|
self.point_conv = ConvLayer(
|
||
|
in_channels,
|
||
|
out_channels,
|
||
|
1,
|
||
|
norm=norm[1],
|
||
|
act_func=act_func[1],
|
||
|
use_bias=use_bias[1],
|
||
|
)
|
||
|
|
||
|
def forward(self, x):
|
||
|
x = self.depth_conv(x)
|
||
|
x = self.point_conv(x)
|
||
|
return x
|
||
|
|
||
|
|
||
|
class MBConv(nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
in_channels: int,
|
||
|
out_channels: int,
|
||
|
kernel_size=3,
|
||
|
stride=1,
|
||
|
mid_channels=None,
|
||
|
expand_ratio=6,
|
||
|
use_bias=False,
|
||
|
norm=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
|
||
|
act_func=(nn.ReLU6, nn.ReLU6, None),
|
||
|
):
|
||
|
super(MBConv, self).__init__()
|
||
|
|
||
|
use_bias = val2tuple(use_bias, 3)
|
||
|
norm = val2tuple(norm, 3)
|
||
|
act_func = val2tuple(act_func, 3)
|
||
|
mid_channels = mid_channels or round(in_channels * expand_ratio)
|
||
|
|
||
|
self.inverted_conv = ConvLayer(
|
||
|
in_channels,
|
||
|
mid_channels,
|
||
|
1,
|
||
|
stride=1,
|
||
|
norm=norm[0],
|
||
|
act_func=act_func[0],
|
||
|
use_bias=use_bias[0],
|
||
|
)
|
||
|
self.depth_conv = ConvLayer(
|
||
|
mid_channels,
|
||
|
mid_channels,
|
||
|
kernel_size,
|
||
|
stride=stride,
|
||
|
groups=mid_channels,
|
||
|
norm=norm[1],
|
||
|
act_func=act_func[1],
|
||
|
use_bias=use_bias[1],
|
||
|
)
|
||
|
self.point_conv = ConvLayer(
|
||
|
mid_channels,
|
||
|
out_channels,
|
||
|
1,
|
||
|
norm=norm[2],
|
||
|
act_func=act_func[2],
|
||
|
use_bias=use_bias[2],
|
||
|
)
|
||
|
|
||
|
def forward(self, x):
|
||
|
x = self.inverted_conv(x)
|
||
|
x = self.depth_conv(x)
|
||
|
x = self.point_conv(x)
|
||
|
return x
|
||
|
|
||
|
|
||
|
class LiteMSA(nn.Module):
|
||
|
"""Lightweight multi-scale attention"""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
in_channels: int,
|
||
|
out_channels: int,
|
||
|
heads: int or None = None,
|
||
|
heads_ratio: float = 1.0,
|
||
|
dim=8,
|
||
|
use_bias=False,
|
||
|
norm=(None, nn.BatchNorm2d),
|
||
|
act_func=(None, None),
|
||
|
kernel_func=nn.ReLU,
|
||
|
scales=(5,),
|
||
|
):
|
||
|
super(LiteMSA, self).__init__()
|
||
|
heads = heads or int(in_channels // dim * heads_ratio)
|
||
|
|
||
|
total_dim = heads * dim
|
||
|
|
||
|
use_bias = val2tuple(use_bias, 2)
|
||
|
norm = val2tuple(norm, 2)
|
||
|
act_func = val2tuple(act_func, 2)
|
||
|
|
||
|
self.dim = dim
|
||
|
self.qkv = ConvLayer(
|
||
|
in_channels,
|
||
|
3 * total_dim,
|
||
|
1,
|
||
|
use_bias=use_bias[0],
|
||
|
norm=norm[0],
|
||
|
act_func=act_func[0],
|
||
|
)
|
||
|
self.aggreg = nn.ModuleList(
|
||
|
[
|
||
|
nn.Sequential(
|
||
|
nn.Conv2d(
|
||
|
3 * total_dim,
|
||
|
3 * total_dim,
|
||
|
scale,
|
||
|
padding=get_same_padding(scale),
|
||
|
groups=3 * total_dim,
|
||
|
bias=use_bias[0],
|
||
|
),
|
||
|
nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
|
||
|
)
|
||
|
for scale in scales
|
||
|
]
|
||
|
)
|
||
|
self.kernel_func = kernel_func(inplace=False)
|
||
|
|
||
|
self.proj = ConvLayer(
|
||
|
total_dim * (1 + len(scales)),
|
||
|
out_channels,
|
||
|
1,
|
||
|
use_bias=use_bias[1],
|
||
|
norm=norm[1],
|
||
|
act_func=act_func[1],
|
||
|
)
|
||
|
|
||
|
def forward(self, x):
|
||
|
B, _, H, W = list(x.size())
|
||
|
|
||
|
# generate multi-scale q, k, v
|
||
|
qkv = self.qkv(x)
|
||
|
multi_scale_qkv = [qkv]
|
||
|
for op in self.aggreg:
|
||
|
multi_scale_qkv.append(op(qkv))
|
||
|
multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
|
||
|
|
||
|
multi_scale_qkv = torch.reshape(
|
||
|
multi_scale_qkv,
|
||
|
(
|
||
|
B,
|
||
|
-1,
|
||
|
3 * self.dim,
|
||
|
H * W,
|
||
|
),
|
||
|
)
|
||
|
multi_scale_qkv = torch.transpose(multi_scale_qkv, -1, -2)
|
||
|
q, k, v = (
|
||
|
multi_scale_qkv[..., 0 : self.dim],
|
||
|
multi_scale_qkv[..., self.dim : 2 * self.dim],
|
||
|
multi_scale_qkv[..., 2 * self.dim :],
|
||
|
)
|
||
|
|
||
|
# lightweight global attention
|
||
|
q = self.kernel_func(q)
|
||
|
k = self.kernel_func(k)
|
||
|
|
||
|
trans_k = k.transpose(-1, -2)
|
||
|
|
||
|
v = F.pad(v, (0, 1), mode="constant", value=1)
|
||
|
kv = torch.matmul(trans_k, v)
|
||
|
out = torch.matmul(q, kv)
|
||
|
out = out[..., :-1] / (out[..., -1:] + 1e-15)
|
||
|
|
||
|
# final projecttion
|
||
|
out = torch.transpose(out, -1, -2)
|
||
|
out = torch.reshape(out, (B, -1, H, W))
|
||
|
out = self.proj(out)
|
||
|
|
||
|
return out
|
||
|
|
||
|
class EfficientViTBlock(nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
in_channels,
|
||
|
heads_ratio=1.0,
|
||
|
dim=32,
|
||
|
expand_ratio=4,
|
||
|
norm=nn.BatchNorm2d,
|
||
|
act_func=nn.Hardswish,
|
||
|
):
|
||
|
super(EfficientViTBlock, self).__init__()
|
||
|
self.context_module = ResidualBlock(
|
||
|
LiteMSA(
|
||
|
in_channels=in_channels,
|
||
|
out_channels=in_channels,
|
||
|
heads_ratio=heads_ratio,
|
||
|
dim=dim,
|
||
|
norm=(None, norm),
|
||
|
),
|
||
|
nn.Identity(),
|
||
|
)
|
||
|
local_module = MBConv(
|
||
|
in_channels=in_channels,
|
||
|
out_channels=in_channels,
|
||
|
expand_ratio=expand_ratio,
|
||
|
use_bias=(True, True, False),
|
||
|
norm=(None, None, norm),
|
||
|
act_func=(act_func, act_func, None),
|
||
|
)
|
||
|
self.local_module = ResidualBlock(local_module, nn.Identity())
|
||
|
|
||
|
def forward(self, x):
|
||
|
x = self.context_module(x)
|
||
|
x = self.local_module(x)
|
||
|
return x
|
||
|
|
||
|
|
||
|
class ResidualBlock(nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
main: nn.Module or None,
|
||
|
shortcut: nn.Module or None,
|
||
|
post_act=None,
|
||
|
pre_norm: nn.Module or None = None,
|
||
|
):
|
||
|
super(ResidualBlock, self).__init__()
|
||
|
|
||
|
self.pre_norm = pre_norm
|
||
|
self.main = main
|
||
|
self.shortcut = shortcut
|
||
|
self.post_act = post_act(inplace=True) if post_act else nn.Identity()
|
||
|
|
||
|
def forward_main(self, x):
|
||
|
if self.pre_norm is None:
|
||
|
return self.main(x)
|
||
|
else:
|
||
|
return self.main(self.pre_norm(x))
|
||
|
|
||
|
def forward(self, x):
|
||
|
if self.main is None:
|
||
|
res = x
|
||
|
elif self.shortcut is None:
|
||
|
res = self.forward_main(x)
|
||
|
else:
|
||
|
res = self.forward_main(x) + self.shortcut(x)
|
||
|
if self.post_act:
|
||
|
res = self.post_act(res)
|
||
|
return res
|
||
|
|
||
|
class ClsHead(nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
in_channels,
|
||
|
width_list,
|
||
|
n_classes=1000,
|
||
|
dropout=0,
|
||
|
norm=nn.BatchNorm2d,
|
||
|
act_func=nn.Hardswish,
|
||
|
global_pool='avg',
|
||
|
):
|
||
|
super(ClsHead, self).__init__()
|
||
|
self.ops = nn.Sequential(
|
||
|
ConvLayer(in_channels, width_list[0], 1, norm=norm, act_func=act_func),
|
||
|
SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW'),
|
||
|
nn.Linear(width_list[0], width_list[1], bias=False),
|
||
|
nn.LayerNorm(width_list[1]),
|
||
|
act_func(inplace=True),
|
||
|
nn.Dropout(dropout, inplace=False) if dropout else nn.Identity(),
|
||
|
nn.Linear(width_list[1], n_classes, bias=True),
|
||
|
)
|
||
|
|
||
|
def forward(self, x):
|
||
|
x = self.ops(x)
|
||
|
return x
|
||
|
|
||
|
|
||
|
class EfficientViT(nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
in_chans=3,
|
||
|
width_list=[],
|
||
|
depth_list=[],
|
||
|
dim=32,
|
||
|
expand_ratio=4,
|
||
|
norm=nn.BatchNorm2d,
|
||
|
act_func=nn.Hardswish,
|
||
|
global_pool='avg',
|
||
|
head_width_list=[],
|
||
|
head_dropout=0.0,
|
||
|
num_classes=1000,
|
||
|
):
|
||
|
super(EfficientViT, self).__init__()
|
||
|
self.grad_checkpointing = False
|
||
|
self.global_pool = global_pool
|
||
|
# input stem
|
||
|
input_stem = [
|
||
|
('in_conv', ConvLayer(
|
||
|
in_channels=3,
|
||
|
out_channels=width_list[0],
|
||
|
kernel_size=3,
|
||
|
stride=2,
|
||
|
norm=norm,
|
||
|
act_func=act_func,
|
||
|
))
|
||
|
]
|
||
|
stem_block = 0
|
||
|
for _ in range(depth_list[0]):
|
||
|
block = self.build_local_block(
|
||
|
in_channels=width_list[0],
|
||
|
out_channels=width_list[0],
|
||
|
stride=1,
|
||
|
expand_ratio=1,
|
||
|
norm=norm,
|
||
|
act_func=act_func,
|
||
|
)
|
||
|
input_stem.append((f'res{stem_block}', ResidualBlock(block, nn.Identity())))
|
||
|
stem_block += 1
|
||
|
in_channels = width_list[0]
|
||
|
self.stem = nn.Sequential(OrderedDict(input_stem))
|
||
|
|
||
|
self.feature_info = []
|
||
|
stages = []
|
||
|
stage_idx = 0
|
||
|
for w, d in zip(width_list[1:3], depth_list[1:3]):
|
||
|
stage = []
|
||
|
for i in range(d):
|
||
|
stride = 2 if i == 0 else 1
|
||
|
block = self.build_local_block(
|
||
|
in_channels=in_channels,
|
||
|
out_channels=w,
|
||
|
stride=stride,
|
||
|
expand_ratio=expand_ratio,
|
||
|
norm=norm,
|
||
|
act_func=act_func,
|
||
|
)
|
||
|
block = ResidualBlock(block, nn.Identity() if stride == 1 else None)
|
||
|
stage.append(block)
|
||
|
in_channels = w
|
||
|
stages.append(nn.Sequential(*stage))
|
||
|
self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
|
||
|
stage_idx += 1
|
||
|
|
||
|
for w, d in zip(width_list[3:], depth_list[3:]):
|
||
|
stage = []
|
||
|
block = self.build_local_block(
|
||
|
in_channels=in_channels,
|
||
|
out_channels=w,
|
||
|
stride=2,
|
||
|
expand_ratio=expand_ratio,
|
||
|
norm=norm,
|
||
|
act_func=act_func,
|
||
|
fewer_norm=True,
|
||
|
)
|
||
|
stage.append(ResidualBlock(block, None))
|
||
|
in_channels = w
|
||
|
|
||
|
for _ in range(d):
|
||
|
stage.append(
|
||
|
EfficientViTBlock(
|
||
|
in_channels=in_channels,
|
||
|
dim=dim,
|
||
|
expand_ratio=expand_ratio,
|
||
|
norm=norm,
|
||
|
act_func=act_func,
|
||
|
)
|
||
|
)
|
||
|
stages.append(nn.Sequential(*stage))
|
||
|
self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
|
||
|
stage_idx += 1
|
||
|
|
||
|
self.stages = nn.Sequential(*stages)
|
||
|
self.num_features = in_channels
|
||
|
self.head_width_list = head_width_list
|
||
|
self.head_dropout = head_dropout
|
||
|
if num_classes > 0:
|
||
|
self.head = ClsHead(self.num_features, self.head_width_list, n_classes=num_classes, dropout=self.head_dropout, global_pool=self.global_pool)
|
||
|
else:
|
||
|
if global_pool is not None:
|
||
|
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
|
||
|
else:
|
||
|
self.head = nn.Identity()
|
||
|
|
||
|
@staticmethod
|
||
|
def build_local_block(
|
||
|
in_channels: int,
|
||
|
out_channels: int,
|
||
|
stride: int,
|
||
|
expand_ratio: float,
|
||
|
norm: str,
|
||
|
act_func: str,
|
||
|
fewer_norm: bool = False,
|
||
|
):
|
||
|
if expand_ratio == 1:
|
||
|
block = DSConv(
|
||
|
in_channels=in_channels,
|
||
|
out_channels=out_channels,
|
||
|
stride=stride,
|
||
|
use_bias=(True, False) if fewer_norm else False,
|
||
|
norm=(None, norm) if fewer_norm else norm,
|
||
|
act_func=(act_func, None),
|
||
|
)
|
||
|
else:
|
||
|
block = MBConv(
|
||
|
in_channels=in_channels,
|
||
|
out_channels=out_channels,
|
||
|
stride=stride,
|
||
|
expand_ratio=expand_ratio,
|
||
|
use_bias=(True, True, False) if fewer_norm else False,
|
||
|
norm=(None, None, norm) if fewer_norm else norm,
|
||
|
act_func=(act_func, act_func, None),
|
||
|
)
|
||
|
return block
|
||
|
|
||
|
@torch.jit.ignore
|
||
|
def group_matcher(self, coarse=False):
|
||
|
matcher = dict(
|
||
|
stem=r'^stem', # stem and embed
|
||
|
blocks=[(r'^stages\.(\d+)', None)]
|
||
|
)
|
||
|
return matcher
|
||
|
|
||
|
@torch.jit.ignore
|
||
|
def set_grad_checkpointing(self, enable=True):
|
||
|
self.grad_checkpointing = enable
|
||
|
|
||
|
@torch.jit.ignore
|
||
|
def get_classifier(self):
|
||
|
return self.head
|
||
|
|
||
|
def reset_classifier(self, num_classes, global_pool=None, dropout=0):
|
||
|
self.num_classes = num_classes
|
||
|
if global_pool is not None:
|
||
|
self.global_pool = global_pool
|
||
|
if num_classes > 0:
|
||
|
self.head = ClsHead(self.num_features, self.head_width_list, n_classes=num_classes, dropout=self.head_dropout, global_pool=global_pool)
|
||
|
else:
|
||
|
if global_pool is not None:
|
||
|
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
|
||
|
else:
|
||
|
self.head = nn.Identity()
|
||
|
|
||
|
def forward_features(self, x):
|
||
|
x = self.stem(x)
|
||
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||
|
x = checkpoint_seq(self.stages, x)
|
||
|
else:
|
||
|
x = self.stages(x)
|
||
|
return x
|
||
|
|
||
|
def forward_head(self, x, pre_logits: bool = False):
|
||
|
return x if pre_logits else self.head(x)
|
||
|
|
||
|
def forward(self, x):
|
||
|
x = self.forward_features(x)
|
||
|
x = self.forward_head(x)
|
||
|
return x
|
||
|
|
||
|
|
||
|
def checkpoint_filter_fn(state_dict, model):
|
||
|
target_keys = list(model.state_dict().keys())
|
||
|
if 'state_dict' in state_dict.keys():
|
||
|
state_dict = state_dict['state_dict']
|
||
|
out_dict = {}
|
||
|
for i, (k, v) in enumerate(state_dict.items()):
|
||
|
out_dict[target_keys[i]] = v
|
||
|
return out_dict
|
||
|
|
||
|
|
||
|
def _cfg(url='', **kwargs):
|
||
|
return {
|
||
|
'url': url,
|
||
|
'num_classes': 1000,
|
||
|
'mean': IMAGENET_DEFAULT_MEAN,
|
||
|
'std': IMAGENET_DEFAULT_STD,
|
||
|
'first_conv': 'stem.in_conv.conv',
|
||
|
'classifier': 'head',
|
||
|
**kwargs,
|
||
|
}
|
||
|
|
||
|
|
||
|
default_cfgs = generate_default_cfgs(
|
||
|
{
|
||
|
'efficientvit_b0.r224_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/1ganFBZmmvCTpgUwiLb8ePD6NBNxRyZDk/view?usp=drive_link'
|
||
|
),
|
||
|
'efficientvit_b1.r224_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/1hKN_hvLG4nmRzbfzKY7GlqwpR5uKpOOk/view?usp=share_link'
|
||
|
),
|
||
|
'efficientvit_b1.r256_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/1hXcG_jB0ODMOESsSkzVye-58B4F3Cahs/view?usp=share_link'
|
||
|
),
|
||
|
'efficientvit_b1.r288_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/1sE_Suz9gOOUO7o5r9eeAT4nKK8Hrbhsu/view?usp=share_link'
|
||
|
),
|
||
|
'efficientvit_b2.r224_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/1DiM-iqVGTrq4te8mefHl3e1c12u4qR7d/view?usp=share_link'
|
||
|
),
|
||
|
'efficientvit_b2.r256_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/192OOk4ISitwlyW979M-FSJ_fYMMW9HQz/view?usp=share_link'
|
||
|
),
|
||
|
'efficientvit_b2.r288_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/1aodcepOyne667hvBAGpf9nDwmd5g0NpU/view?usp=share_link'
|
||
|
),
|
||
|
'efficientvit_b3.r224_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/18RZDGLiY8KsyJ7LGic4mg1JHwd-a_ky6/view?usp=share_link'
|
||
|
),
|
||
|
'efficientvit_b3.r256_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/1y1rnir4I0XiId-oTCcHhs7jqnrHGFi-g/view?usp=share_link'
|
||
|
),
|
||
|
'efficientvit_b3.r288_in1k': _cfg(
|
||
|
# url='https://drive.google.com/file/d/1KfwbGtlyFgslNr4LIHERv6aCfkItEvRk/view?usp=share_link'
|
||
|
),
|
||
|
}
|
||
|
)
|
||
|
|
||
|
|
||
|
def _create_efficientvit(variant, pretrained=False, **kwargs):
|
||
|
model = build_model_with_cfg(
|
||
|
EfficientViT,
|
||
|
variant,
|
||
|
pretrained,
|
||
|
pretrained_filter_fn=checkpoint_filter_fn,
|
||
|
**kwargs
|
||
|
)
|
||
|
return model
|
||
|
|
||
|
|
||
|
@register_model
|
||
|
def efficientvit_b0(pretrained=False, **kwargs):
|
||
|
model_args = dict(width_list=[8, 16, 32, 64, 128], depth_list=[1, 2, 2, 2, 2], dim=16, head_width_list=[1024, 1280])
|
||
|
return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
|
||
|
|
||
|
@register_model
|
||
|
def efficientvit_b1(pretrained=False, **kwargs):
|
||
|
model_args = dict(width_list=[16, 32, 64, 128, 256], depth_list=[1, 2, 3, 3, 4], dim=16, head_width_list=[1536, 1600])
|
||
|
return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
|
||
|
|
||
|
@register_model
|
||
|
def efficientvit_b2(pretrained=False, **kwargs):
|
||
|
model_args = dict(width_list=[24, 48, 96, 192, 384], depth_list=[1, 3, 4, 4, 6], dim=32, head_width_list=[2304, 2560])
|
||
|
return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
|
||
|
|
||
|
@register_model
|
||
|
def efficientvit_b3(pretrained=False, **kwargs):
|
||
|
model_args = dict(width_list=[32, 64, 128, 256, 512], depth_list=[1, 4, 6, 6, 9], dim=32, head_width_list=[2304, 2560])
|
||
|
return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
|