Initial MobileNetV4 pass
parent 23f09af08e
commit 6a8bb03330
@@ -1,6 +1,7 @@
from .activations import *
from .adaptive_avgmax_pool import \
    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d

@@ -2,15 +2,19 @@

Hacked together by / Copyright 2019, Ross Wightman
"""
+from typing import Optional
+
import torch
import torch.nn as nn
from torch.nn import functional as F

-from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
+from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, to_2tuple,\
+    get_norm_act_layer, MultiQueryAttention2d, MultiQueryAttentionV2, Attention2d

__all__ = [
-    'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']
+    'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
+    'UniversalInvertedResidual', 'MobileAttention'
+]


def num_groups(group_size, channels):
@@ -85,7 +89,8 @@ class ConvBnAct(nn.Module):
        self.has_skip = skip and stride == 1 and in_chs == out_chs

        self.conv = create_conv2d(
-            in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
+            in_chs, out_chs, kernel_size,
+            stride=stride, dilation=dilation, groups=groups, padding=pad_type)
        self.bn1 = norm_act_layer(out_chs, inplace=True)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

@@ -105,7 +110,7 @@ class ConvBnAct(nn.Module):


class DepthwiseSeparableConv(nn.Module):
-    """ DepthwiseSeparable block
+    """ Depthwise-separable block
    Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
    """
@@ -139,16 +144,19 @@ class DepthwiseSeparableConv(nn.Module):
            self.conv_s2d = create_conv2d(
                in_chs, sd_chs, kernel_size=2, stride=2, padding=0) #'same')
            self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
+            dw_kernel_size = (dw_kernel_size + 1) // 2
+            dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
            in_chs = sd_chs
        else:
            self.conv_s2d = None
            self.bn_s2d = None
+            dw_pad_type = pad_type

        groups = num_groups(group_size, in_chs)

-        dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
        self.conv_dw = create_conv2d(
-            in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=dw_pad_type, groups=groups)
+            in_chs, in_chs, dw_kernel_size,
+            stride=stride, dilation=dilation, padding=dw_pad_type, groups=groups)
        self.bn1 = norm_act_layer(in_chs, inplace=True)

        # Squeeze-and-excitation
@@ -222,10 +230,13 @@ class InvertedResidual(nn.Module):
            sd_chs = int(in_chs * 4)
            self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding=pad_type)
            self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
+            dw_kernel_size = (dw_kernel_size + 1) // 2
+            dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
            in_chs = sd_chs
        else:
            self.conv_s2d = None
            self.bn_s2d = None
+            dw_pad_type = pad_type

        mid_chs = make_divisible(in_chs * exp_ratio)
        groups = num_groups(group_size, mid_chs)
@@ -236,8 +247,8 @@ class InvertedResidual(nn.Module):

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
-            mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
-            groups=groups, padding=pad_type, **conv_kwargs)
+            mid_chs, mid_chs, dw_kernel_size,
+            stride=stride, dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs)
        self.bn2 = norm_act_layer(mid_chs, inplace=True)

        # Squeeze-and-excitation
@@ -271,6 +282,267 @@ class InvertedResidual(nn.Module):
        return x


+class LayerScale2d(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        gamma = self.gamma.view(1, -1, 1, 1)
+        return x.mul_(gamma) if self.inplace else x * gamma
+
+
+class UniversalInvertedResidual(nn.Module):
+    """ Universal Inverted Residual Block
+
+    For MobileNetV4 - https://arxiv.org/abs/
+    """
+
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            dw_kernel_size_start: int = 0,
+            dw_kernel_size_mid: int = 3,
+            dw_kernel_size_end: int = 0,
+            stride=1,
+            dilation=1,
+            group_size=1,
+            pad_type='',
+            noskip=False,
+            exp_ratio=1.0,
+            act_layer=nn.ReLU,
+            dw_act_layer=None,
+            norm_layer=nn.BatchNorm2d,
+            se_layer=None,
+            conv_kwargs=None,
+            drop_path_rate=0.,
+            layer_scale_init_value: Optional[float] = 1e-5,
+    ):
+        super(UniversalInvertedResidual, self).__init__()
+        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+        dw_act_layer = dw_act_layer or act_layer
+        dw_norm_act_layer = get_norm_act_layer(norm_layer, dw_act_layer)
+        conv_kwargs = conv_kwargs or {}
+        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
+
+        # FIXME dilation isn't right w/ extra ks > 1 convs
+        if dw_kernel_size_start:
+            self.conv_dw_start = create_conv2d(
+                in_chs, in_chs, dw_kernel_size_start,
+                dilation=dilation,  # FIXME
+                depthwise=True,
+                padding=pad_type,
+                **conv_kwargs,
+            )
+            self.norm_dw_start = dw_norm_act_layer(in_chs, apply_act=False)
+        else:
+            self.conv_dw_start = nn.Identity()
+            self.norm_dw_start = nn.Identity()
+
+        # Point-wise expansion
+        mid_chs = make_divisible(in_chs * exp_ratio)
+        self.conv_pw = create_conv2d(in_chs, mid_chs, 1, padding=pad_type, **conv_kwargs)
+        self.norm_pw = norm_act_layer(mid_chs, inplace=True)
+
+        # Depth-wise convolution
+        if dw_kernel_size_mid:
+            groups = num_groups(group_size, mid_chs)
+            self.conv_dw_mid = create_conv2d(
+                mid_chs, mid_chs, dw_kernel_size_mid,
+                stride=stride,
+                dilation=dilation,  # FIXME
+                groups=groups,
+                padding=pad_type,
+                **conv_kwargs,
+            )
+            self.norm_dw_mid = dw_norm_act_layer(mid_chs, inplace=True)
+        else:
+            self.conv_dw_mid = nn.Identity()
+            self.norm_dw_mid = nn.Identity()
+
+        # Squeeze-and-excitation
+        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
+
+        # Point-wise linear projection
+        self.conv_pwl = create_conv2d(mid_chs, out_chs, 1, padding=pad_type, **conv_kwargs)
+        self.norm_pwl = norm_act_layer(out_chs, apply_act=False)
+
+        if dw_kernel_size_end:
+            self.conv_dw_end = create_conv2d(
+                out_chs, out_chs, dw_kernel_size_end,
+                dilation=dilation,
+                depthwise=True,
+                padding=pad_type,
+                **conv_kwargs,
+            )
+            self.norm_dw_end = dw_norm_act_layer(out_chs, apply_act=False)
+        else:
+            # dw_end rarely used so keeping it out of repr by using None instead of nn.Identity()
+            self.conv_dw_end = None
+            self.norm_dw_end = None
+
+        if layer_scale_init_value is not None:
+            self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
+        else:
+            self.layer_scale = nn.Identity()
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
+
+    def feature_info(self, location):
+        if location == 'expansion':  # after SE, input to PWL
+            return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+        else:  # location == 'bottleneck', block output
+            return dict(module='', num_chs=self.conv_pwl.out_channels)
+
+    def forward(self, x):
+        shortcut = x
+        x = self.conv_dw_start(x)
+        x = self.norm_dw_start(x)
+        x = self.conv_pw(x)
+        x = self.norm_pw(x)
+        x = self.conv_dw_mid(x)
+        x = self.norm_dw_mid(x)
+        x = self.se(x)
+        x = self.conv_pwl(x)
+        x = self.norm_pwl(x)
+        if self.conv_dw_end is not None:
+            x = self.conv_dw_end(x)
+            x = self.norm_dw_end(x)
+        x = self.layer_scale(x)
+        if self.has_skip:
+            x = self.drop_path(x) + shortcut
+        return x
+
+
+class MobileAttention(nn.Module):
+    """ Mobile Attention Block
+
+    For MobileNetV4 - https://arxiv.org/abs/
+    """
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dw_kernel_size=3,
+            dilation=1,
+            group_size=1,
+            pad_type='',
+            num_heads: int = 8,
+            key_dim: int = 64,
+            value_dim: int = 64,
+            use_multi_query: bool = False,
+            query_strides: int = (1, 1),
+            kv_stride: int = 1,
+            cpe_dw_kernel_size=3,
+            noskip=False,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            drop_path_rate=0.,
+            attn_drop=0.0,
+            proj_drop=0.0,
+            layer_scale_init_value: Optional[float] = 1e-5,
+            use_bias=False,
+            use_cpe=False,
+    ):
+        super(MobileAttention, self).__init__()
+        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+        self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
+        self.query_strides = to_2tuple(query_strides)
+        self.kv_stride = kv_stride
+        self.has_query_stride = any([s > 1 for s in self.query_strides])
+
+        # This CPE is different than the one suggested in the original paper.
+        # https://arxiv.org/abs/2102.10882
+        # 1. Rather than adding one CPE before the attention blocks, we add a CPE
+        # into every attention block.
+        # 2. We replace the expensive Conv2D by a Separable DW Conv.
+        if use_cpe:
+            self.conv_cpe_dw = create_conv2d(
+                in_chs, in_chs,
+                kernel_size=cpe_dw_kernel_size,
+                dilation=dilation,
+                depthwise=True,
+                bias=True,
+            )
+        else:
+            self.conv_cpe_dw = None
+
+        self.norm = norm_act_layer(in_chs, apply_act=False)
+
+        if num_heads is None:
+            assert in_chs % key_dim == 0
+            num_heads = in_chs // key_dim
+
+        if use_multi_query:
+            #if self.has_query_stride or self.kv_stride > 1:
+            self.attn = (
+                MultiQueryAttention2d(
+                    in_chs,
+                    dim_out=out_chs,
+                    num_heads=num_heads,
+                    key_dim=key_dim,
+                    value_dim=value_dim,
+                    query_strides=query_strides,
+                    kv_stride=kv_stride,
+                    dilation=dilation,
+                    padding=pad_type,
+                    dw_kernel_size=dw_kernel_size,
+                    attn_drop=attn_drop,
+                    proj_drop=proj_drop,
+                    #bias=use_bias, # why not here if used w/ mhsa?
+                )
+            )
+            # else:
+            #     self.attn = MultiQueryAttentionV2(
+            #         in_chs,
+            #         dim_out=out_chs,
+            #         num_heads=num_heads,
+            #         key_dim=key_dim,
+            #         value_dim=value_dim,
+            #         attn_drop=attn_drop,
+            #         proj_drop=proj_drop,
+            #     )
+        else:
+            self.attn = Attention2d(
+                in_chs,
+                dim_out=out_chs,
+                num_heads=num_heads,
+                attn_drop=attn_drop,
+                proj_drop=proj_drop,
+                bias=use_bias,
+            )
+
+        if layer_scale_init_value is not None:
+            self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
+        else:
+            self.layer_scale = nn.Identity()
+
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
+
+    def feature_info(self, location):
+        if location == 'expansion':  # after SE, input to PW
+            return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
+        else:  # location == 'bottleneck', block output
+            return dict(module='', num_chs=self.conv_pw.out_channels)
+
+    def forward(self, x):
+        if self.conv_cpe_dw is not None:
+            x_cpe = self.conv_cpe_dw(x)
+            x = x + x_cpe
+
+        shortcut = x
+        x = self.norm(x)
+        x = self.attn(x)
+        x = self.layer_scale(x)
+        if self.has_skip:
+            x = self.drop_path(x) + shortcut
+
+        return x


class CondConvResidual(InvertedResidual):
    """ Inverted residual block w/ CondConv routing"""
@@ -296,13 +568,24 @@ class CondConvResidual(InvertedResidual):

        self.num_experts = num_experts
        conv_kwargs = dict(num_experts=self.num_experts)

        super(CondConvResidual, self).__init__(
-            in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size,
-            pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
-            drop_path_rate=drop_path_rate)
+            in_chs,
+            out_chs,
+            dw_kernel_size=dw_kernel_size,
+            stride=stride,
+            dilation=dilation,
+            group_size=group_size,
+            pad_type=pad_type,
+            act_layer=act_layer,
+            noskip=noskip,
+            exp_ratio=exp_ratio,
+            exp_kernel_size=exp_kernel_size,
+            pw_kernel_size=pw_kernel_size,
+            se_layer=se_layer,
+            norm_layer=norm_layer,
+            conv_kwargs=conv_kwargs,
+            drop_path_rate=drop_path_rate,
+        )
        self.routing_fn = nn.Linear(in_chs, self.num_experts)

    def forward(self, x):
@@ -362,7 +645,8 @@ class EdgeResidual(nn.Module):

        # Expansion convolution
        self.conv_exp = create_conv2d(
-            in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
+            in_chs, mid_chs, exp_kernel_size,
+            stride=stride, dilation=dilation, groups=groups, padding=pad_type)
        self.bn1 = norm_act_layer(mid_chs, inplace=True)

        # Squeeze-and-excitation

@@ -139,11 +139,10 @@ def _decode_block_str(block_str):

    # if act_layer is None, the model default (passed to model init) will be used
    act_layer = options['n'] if 'n' in options else None
-    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
-    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
+    start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
+    end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
    num_repeat = int(options['r'])
-    s2d = int(options['d']) if 'd' in options else 0

    # each type of block has different valid arguments, fill accordingly
    block_args = dict(
@@ -155,31 +154,31 @@ def _decode_block_str(block_str):
    if block_type == 'ir':
        block_args.update(dict(
            dw_kernel_size=_parse_ksize(options['k']),
-            exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size,
+            exp_kernel_size=start_kernel_size,
+            pw_kernel_size=end_kernel_size,
            exp_ratio=float(options['e']),
-            se_ratio=float(options['se']) if 'se' in options else 0.,
+            se_ratio=float(options.get('se', 0.)),
            noskip=skip is False,
-            s2d=s2d > 0,
+            s2d=int(options.get('d', 0)) > 0,
        ))
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type == 'ds' or block_type == 'dsa':
        block_args.update(dict(
            dw_kernel_size=_parse_ksize(options['k']),
-            pw_kernel_size=pw_kernel_size,
-            se_ratio=float(options['se']) if 'se' in options else 0.,
+            pw_kernel_size=end_kernel_size,
+            se_ratio=float(options.get('se', 0.)),
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or skip is False,
-            s2d=s2d > 0,
+            s2d=int(options.get('d', 0)) > 0,
        ))
    elif block_type == 'er':
        block_args.update(dict(
            exp_kernel_size=_parse_ksize(options['k']),
-            pw_kernel_size=pw_kernel_size,
+            pw_kernel_size=end_kernel_size,
            exp_ratio=float(options['e']),
            force_in_chs=force_in_chs,
-            se_ratio=float(options['se']) if 'se' in options else 0.,
+            se_ratio=float(options.get('se', 0.)),
            noskip=skip is False,
        ))
    elif block_type == 'cn':
@@ -187,6 +186,38 @@ def _decode_block_str(block_str):
            kernel_size=int(options['k']),
            skip=skip is True,
        ))
+    elif block_type == 'uir':
+        # override exp / proj kernels for start/end in uir block
+        start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 0
+        end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 0
+        block_args.update(dict(
+            dw_kernel_size_start=start_kernel_size,  # overload exp ks arg for dw start
+            dw_kernel_size_mid=_parse_ksize(options['k']),
+            dw_kernel_size_end=end_kernel_size,  # overload pw ks arg for dw end
+            exp_ratio=float(options['e']),
+            se_ratio=float(options.get('se', 0.)),
+            noskip=skip is False,
+        ))
+    elif block_type == 'mha':
+        kv_dim = int(options['d'])
+        block_args.update(dict(
+            dw_kernel_size=_parse_ksize(options['k']),
+            num_heads=int(options['h']),
+            key_dim=kv_dim,
+            value_dim=kv_dim,
+            kv_stride=int(options.get('v', 1)),
+            noskip=skip is False,
+        ))
+    elif block_type == 'mqa':
+        kv_dim = int(options['d'])
+        block_args.update(dict(
+            dw_kernel_size=_parse_ksize(options['k']),
+            num_heads=int(options['h']),
+            key_dim=kv_dim,
+            value_dim=kv_dim,
+            kv_stride=int(options.get('v', 1)),
+            noskip=skip is False,
+        ))
    else:
        assert False, 'Unknown block type (%s)' % block_type
    if 'gs' in options:
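
For orientation, a sketch of how the new block types decode through the option parsing above; the strings are made-up examples in the same syntax as the MobileNetV4 defs added below, and the import path is an assumption based on the file being edited here:

from timm.models._efficientnet_builder import decode_arch_def  # path assumed

arch_def = [
    ['uir_r2_a3_k5_s2_e4_c80'],       # 2x UniversalInvertedResidual: start dw k3, mid dw k5, stride 2, exp 4.0, 80 chs
    ['mqa_r1_k3_h4_s1_v2_d64_c160'],  # 1x MobileAttention w/ multi-query attn: 4 heads, key/value dim 64, kv_stride 2
]
stage_args = decode_arch_def(arch_def)
print(stage_args[0][0])  # dict w/ dw_kernel_size_start=3, dw_kernel_size_mid=5, exp_ratio=4.0, out_chs=80, ...
print(stage_args[1][0])  # dict w/ num_heads=4, key_dim=64, value_dim=64, kv_stride=2, ...
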
@@ -331,10 +362,9 @@ class EfficientNetBuilder:
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
        s2d = ba.get('s2d', 0)
-        if s2d:
+        if s2d > 0:
+            # adjust while space2depth active
            ba['out_chs'] *= 4
-            if s2d == 1:
-                ba['dw_kernel_size'] = (ba['dw_kernel_size'] + 1) // 2
        if 'force_in_chs' in ba and ba['force_in_chs']:
            # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
            ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
@@ -344,19 +374,19 @@ class EfficientNetBuilder:
        assert ba['act_layer'] is not None
        ba['norm_layer'] = self.norm_layer
        ba['drop_path_rate'] = drop_path_rate
-        if bt != 'cn':
-            se_ratio = ba.pop('se_ratio')
-            if se_ratio and self.se_layer is not None:
-                if not self.se_from_exp:
-                    # adjust se_ratio by expansion ratio if calculating se channels from block input
-                    se_ratio /= ba.get('exp_ratio', 1.0)
-                # adjust space2depth
-                if s2d == 1:
-                    se_ratio /= 4
-                if self.se_has_ratio:
-                    ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
-                else:
-                    ba['se_layer'] = self.se_layer
+
+        se_ratio = ba.pop('se_ratio', None)
+        if se_ratio and self.se_layer is not None:
+            if not self.se_from_exp:
+                # adjust se_ratio by expansion ratio if calculating se channels from block input
+                se_ratio /= ba.get('exp_ratio', 1.0)
+            if s2d == 1:
+                # adjust for start of space2depth
+                se_ratio /= 4
+            if self.se_has_ratio:
+                ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
+            else:
+                ba['se_layer'] = self.se_layer

        if bt == 'ir':
            _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
@@ -370,8 +400,17 @@ class EfficientNetBuilder:
        elif bt == 'cn':
            _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = ConvBnAct(**ba)
+        elif bt == 'uir':
+            _log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+            block = UniversalInvertedResidual(**ba)
+        elif bt == 'mqa':
+            _log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+            block = MobileAttention(**ba, use_multi_query=True)
+        elif bt == 'mha':
+            _log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+            block = MobileAttention(**ba)
        else:
-            assert False, 'Uknkown block type (%s) while building model.' % bt
+            assert False, 'Unknown block type (%s) while building model.' % bt

        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block
@@ -420,12 +459,10 @@ class EfficientNetBuilder:

                if space2depth > 0:
                    if space2depth == 2 and block_args['stride'] == 2:
-                        space2depth = 0
                        block_args['stride'] = 1
                        # to end s2d region, need to correct expansion and se ratio relative to input
                        # FIXME unify with _make_block logic? this is rather meh
                        block_args['exp_ratio'] /= 4
                        #block_args['se_ratio'] /= 4
+                        space2depth = 0
                    else:
                        block_args['s2d'] = space2depth

@@ -622,38 +622,215 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
    return model


-def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
-    """ LCNet
-    Essentially a MobileNet-V3 crossed with a MobileNet-V1
-
-    Paper: `PP-LCNet: A Lightweight CPU Convolutional Neural Network` - https://arxiv.org/abs/2109.15099
-
+def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
+    """Creates a MobileNet-V4 model.
+
+    Ref impl: ?
+    Paper: https://arxiv.org/abs/1905.02244
+
    Args:
      channel_multiplier: multiplier to number of channels per layer.
    """
-    arch_def = [
-        # stage 0, 112x112 in
-        ['dsa_r1_k3_s1_c32'],
-        # stage 1, 112x112 in
-        ['dsa_r2_k3_s2_c64'],
-        # stage 2, 56x56 in
-        ['dsa_r2_k3_s2_c128'],
-        # stage 3, 28x28 in
-        ['dsa_r1_k3_s2_c256', 'dsa_r1_k5_s1_c256'],
-        # stage 4, 14x14in
-        ['dsa_r4_k5_s1_c256'],
-        # stage 5, 14x14in
-        ['dsa_r2_k5_s2_c512_se0.25'],
-        # 7x7
-    ]
+    if 'hybrid' in variant:
+        if 'medium' in variant:
+            stem_size = 32
+            num_features = 1280
+            act_layer = resolve_act_layer(kwargs, 'relu')
+            arch_def = [
+                # stage 0, 112x112 in
+                ['er_r1_k3_s2_e4_c48'],
+                # stage 1, 56x56 in
+                ['uir_r1_a3_k5_s2_e4_c80', 'uir_r1_a3_k3_s1_e2_c80'],
+                # stage 2, 28x28 in
+                [
+                    'uir_r1_a3_k5_s2_e6_c160',
+                    'uir_r1_a0_k0_s1_e2_c160',
+                    'uir_r1_a3_k3_s1_e4_c160',
+                    'uir_r1_a3_k5_s1_e4_c160',
+                    'mqa_r1_k3_h4_s1_v2_d64_c160',
+                    'uir_r1_a3_k3_s1_e4_c160',
+                    'mqa_r1_k3_h4_s1_v2_d64_c160',
+                    'uir_r1_a3_k0_s1_e4_c160',  # convnext
+                    'mqa_r1_k3_h4_s1_v2_d64_c160',
+                    'uir_r1_a3_k3_s1_e4_c160',
+                    'mqa_r1_k3_h4_s1_v2_d64_c160',
+                    'uir_r1_a3_k0_s1_e4_c160',  # convnext
+                ],
+                # stage 3, 14x14in
+                [
+                    'uir_r1_a5_k5_s2_e6_c256',
+                    'uir_r1_a5_k5_s1_e4_c256',
+                    'uir_r2_a3_k5_s1_e4_c256',
+                    'uir_r1_a0_k0_s1_e2_c256',
+                    'uir_r1_a3_k5_s1_e2_c256',
+                    'uir_r1_a0_k0_s1_e2_c256',
+                    'uir_r1_a0_k0_s1_e4_c256',
+                    'mqa_r1_k3_h4_s1_d64_c256',
+                    'uir_r1_a3_k0_s1_e4_c256',  # convnext
+                    'mqa_r1_k3_h4_s1_d64_c256',
+                    'uir_r1_a5_k5_s1_e4_c256',
+                    'mqa_r1_k3_h4_s1_d64_c256',
+                    'uir_r1_a5_k0_s1_e4_c256',  # convnext4
+                    'mqa_r1_k3_h4_s1_d64_c256',
+                    'uir_r1_a5_k0_s1_e4_c256',  # convnext4
+                ],
+                # stage 4, 7x7 in
+                ['cn_r1_k1_s1_c960'],
+            ]
+        elif 'large' in variant:
+            stem_size = 24
+            num_features = 1280
+            act_layer = resolve_act_layer(kwargs, 'gelu')
+            arch_def = [
+                # stage 0, 112x112 in
+                ['er_r1_k3_s2_e4_c48'],
+                # stage 1, 56x56 in
+                ['uir_r1_a3_k5_s2_e4_c96', 'uir_r1_a3_k3_s1_e4_c96'],
+                # stage 2, 28x28 in
+                [
+                    'uir_r1_a3_k5_s2_e4_c192',
+                    'uir_r3_a3_k3_s1_e4_c192',
+                    'uir_r1_a3_k5_s1_e4_c192',
+                    'uir_r2_a5_k3_s1_e4_c192',
+                    'mqa_r1_k3_h8_s1_v2_d48_c192',
+                    'uir_r1_a5_k3_s1_e4_c192',
+                    'mqa_r1_k3_h8_s1_v2_d48_c192',
+                    'uir_r1_a5_k3_s1_e4_c192',
+                    'mqa_r1_k3_h8_s1_v2_d48_c192',
+                    'uir_r1_a5_k3_s1_e4_c192',
+                    'mqa_r1_k3_h8_s1_v2_d48_c192',
+                    'uir_r1_a3_k0_s1_e4_c192',  # convnext
+                ],
+                # stage 3, 14x14in
+                [
+                    'uir_r4_a5_k5_s2_e4_c512',
+                    'uir_r1_a5_k0_s1_e4_c512',  # convnext
+                    'uir_r1_a5_k3_s1_e4_c512',
+                    'uir_r2_a5_k0_s1_e4_c512',  # convnext
+                    'uir_r1_a5_k3_s1_e4_c512',
+                    'uir_r1_a5_k5_s1_e4_c512',
+                    'mqa_r1_k3_h8_s1_d64_c512',
+                    'uir_r3_a5_k0_s1_e4_c512',  # convnext
+                    'mqa_r1_k3_h8_s1_d64_c512',
+                    'uir_r3_a5_k0_s1_e4_c512',  # convnext
+                    'mqa_r1_k3_h8_s1_d64_c512',
+                    'uir_r3_a5_k0_s1_e4_c512',  # convnext
+                    'mqa_r1_k3_h8_s1_d64_c512',
+                    'uir_r3_a5_k0_s1_e4_c512',  # convnext
+                ],
+                # stage 4, 7x7 in
+                ['cn_r1_k1_s1_c960'],
+            ]
+        else:
+            assert False, f'Unknown variant {variant}.'
+    else:
+        if 'small' in variant:
+            stem_size = 32
+            num_features = 1280
+            act_layer = resolve_act_layer(kwargs, 'relu')
+            arch_def = [
+                # stage 0, 112x112 in
+                ['cn_r1_k3_s2_e1_c32', 'cn_r1_k1_s1_e1_c32'],
+                # stage 1, 56x56 in
+                ['cn_r1_k3_s2_e1_c96', 'cn_r1_k1_s1_e1_c64'],
+                # stage 2, 28x28 in
+                [
+                    'uir_r1_a5_k5_s2_e3_c96',  # start dw
+                    'uir_r4_a0_k3_s1_e2_c96',  # ir
+                    'uir_r1_a3_k0_s1_e4_c96',  # convnext
+                ],
+                # stage 3, 14x14 in
+                [
+                    'uir_r1_a3_k3_s2_e6_c128',  # start dw
+                    'uir_r1_a5_k5_s1_e4_c128',  # start dw
+                    'uir_r1_a0_k5_s1_e4_c128',  # ir
+                    'uir_r1_a0_k5_s1_e3_c128',  # ir
+                    'uir_r2_a0_k5_s1_e4_c128',  # ir
+                ],
+                # stage 4, 7x7 in
+                ['cn_r1_k1_s1_c960'],  # hard-swish
+            ]
+        elif 'medium' in variant:
+            stem_size = 32
+            num_features = 1280
+            act_layer = resolve_act_layer(kwargs, 'relu')
+            arch_def = [
+                # stage 0, 112x112 in
+                ['er_r1_k3_s2_e4_c48'],
+                # stage 1, 56x56 in
+                ['uir_r1_a3_k5_s2_e4_c80', 'uir_r1_a3_k3_s1_e2_c80'],
+                # stage 2, 28x28 in
+                [
+                    'uir_r1_a5_k3_s2_e6_c160',
+                    'uir_r2_a3_k3_s1_e4_c160',
+                    'uir_r1_a3_k3_s1_e4_c160',
+                    'uir_r1_a3_k3_s1_e4_c160',
+                    'uir_r1_a3_k0_s1_e4_c160',  # convnext
+                    'uir_r2_a0_k0_s1_e2_c160',
+                    'uir_r1_a3_k0_s1_e4_c160',  # convnext
+                ],
+                # stage 3, 14x14in
+                [
+                    'uir_r1_a5_k5_s2_e6_c256',
+                    'uir_r1_a5_k5_s1_e4_c256',
+                    'uir_r2_a3_k5_s1_e4_c256',
+                    'uir_r1_a0_k0_s1_e4_c256',
+                    'uir_r1_a3_k0_s1_e4_c256',  # convnext
+                    'uir_r1_a3_k0_s1_e4_c256',  # convnext
+                    'uir_r1_a3_k5_s1_e2_c256',
+                    'uir_r1_a5_k5_s1_e4_c256',
+                    'uir_r2_a0_k0_s1_e4_c256',
+                    'uir_r1_a5_k0_s1_e2_c256',  # convnext
+                ],
+                # stage 4, 7x7 in
+                ['cn_r1_k1_s1_c960'],
+            ]
+        elif 'large' in variant:
+            stem_size = 24
+            num_features = 1280
+            act_layer = resolve_act_layer(kwargs, 'relu')
+            arch_def = [
+                # stage 0, 112x112 in
+                ['er_r1_k3_s2_e4_c48'],
+                # stage 1, 56x56 in
+                ['uir_r1_a3_k5_s2_e4_c96', 'uir_r1_a3_k3_s1_e4_c96'],
+                # stage 2, 28x28 in
+                [
+                    'uir_r1_a3_k5_s2_e4_c192',
+                    'uir_r3_a3_k3_s1_e4_c192',
+                    'uir_r1_a3_k5_s1_e4_c192',
+                    'uir_r5_a5_k3_s1_e4_c192',
+                    'uir_r1_a3_k0_s1_e4_c192',  # convnext
+                ],
+                # stage 3, 14x14in
+                [
+                    'uir_r4_a5_k5_s2_e4_c512',
+                    'uir_r1_a5_k0_s1_e4_c512',  # convnext
+                    'uir_r1_a5_k3_s1_e4_c512',
+                    'uir_r2_a5_k0_s1_e4_c512',  # convnext
+                    'uir_r1_a5_k3_s1_e4_c512',
+                    'uir_r1_a5_k5_s1_e4_c512',
+                    'uir_r3_a5_k0_s1_e4_c512',  # convnext
+                ],
+                # stage 4, 7x7 in
+                ['cn_r1_k1_s1_c960'],
+            ]
+        else:
+            assert False, f'Unknown variant {variant}.'

+    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
-        stem_size=16,
+        num_features=num_features,
+        stem_size=stem_size,
        fix_stem=channel_multiplier < 0.75,
        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
-        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
-        se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU),
-        num_features=1280,
+        act_layer=act_layer,
+        se_layer=se_layer,
        **kwargs,
    )
    model = _create_mnv3(variant, pretrained, **model_kwargs)

@@ -688,6 +865,9 @@ default_cfgs = generate_default_cfgs({
        origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
        paper_ids='arXiv:2104.10972v4',
        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
+    'mobilenetv3_large_150.untrained': _cfg(
+        interpolation='bicubic'),
+

    'mobilenetv3_small_050.lamb_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
@@ -762,6 +942,32 @@ default_cfgs = generate_default_cfgs({
        interpolation='bicubic',
    ),
    "lcnet_150.untrained": _cfg(),

+    'mobilenetv4_conv_small': _cfg(
+        # hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv4_conv_medium': _cfg(
+        #hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv4_conv_large': _cfg(
+        # hf_hub_id='timm/',
+        interpolation='bicubic'),
+
+    'mobilenetv4_hybrid_small': _cfg(
+        # hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv4_hybrid_medium': _cfg(
+        # hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv4_hybrid_large': _cfg(
+        # hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv4_hybrid_medium_075': _cfg(
+        # hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv4_hybrid_medium_150': _cfg(
+        # hf_hub_id='timm/',
+        interpolation='bicubic'),
})

@@ -779,6 +985,13 @@ def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
    return model


+@register_model
+def mobilenetv3_large_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
+    """ MobileNet V3 """
+    model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.5, pretrained=pretrained, **kwargs)
+    return model
+
+
@register_model
def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V3 """
@@ -918,6 +1131,56 @@ def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
    return model


+@register_model
+def mobilenetv4_conv_small(pretrained: bool = False, **kwargs) -> MobileNetV3:
+    """ MobileNet V4 """
+    model = _gen_mobilenet_v4('mobilenetv4_conv_small', 1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def mobilenetv4_conv_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
+    """ MobileNet V4 """
+    model = _gen_mobilenet_v4('mobilenetv4_conv_medium', 1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def mobilenetv4_conv_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
+    """ MobileNet V4 """
+    model = _gen_mobilenet_v4('mobilenetv4_conv_large', 1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
+    """ MobileNet V4 Hybrid """
+    model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
+    """ MobileNet V4 Hybrid """
+    model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium', 1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def mobilenetv4_hybrid_medium_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
+    """ MobileNet V4 Hybrid """
+    model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_150', 1.5, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
+    """ MobileNet V4 Hybrid"""
+    model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
register_model_deprecations(__name__, {
    'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
    'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k',
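
With the registrations above, the new variants go through the standard timm factory; a minimal usage sketch, assuming an install that includes this commit (the default_cfgs added above carry no weight URLs yet, so only pretrained=False is expected to work at this point):

import torch
import timm

model = timm.create_model('mobilenetv4_conv_small', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])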