mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Adding support to EfficientNet / MobileNetV3 to for different attention layers in .se position
This commit is contained in:
parent
27b3680d49
commit
67e759f710
@ -94,6 +94,14 @@ default_cfgs = {
|
||||
url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
'efficientnet_l2': _cfg(
|
||||
url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
|
||||
'efficientnet_eca_b0': _cfg(
|
||||
url=''),
|
||||
'efficientnet_eca_b1': _cfg(
|
||||
url='',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8)),
|
||||
'efficientnet_eca_b2': _cfg(
|
||||
url='',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9)),
|
||||
'efficientnet_es': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
|
||||
'efficientnet_em': _cfg(
|
||||
@ -254,7 +262,7 @@ class EfficientNet(nn.Module):
|
||||
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
|
||||
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
||||
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
|
||||
attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
|
||||
super(EfficientNet, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
|
||||
@ -272,8 +280,8 @@ class EfficientNet(nn.Module):
|
||||
|
||||
# Middle stages (IR/ER/DS Blocks)
|
||||
builder = EfficientNetBuilder(
|
||||
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
|
||||
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
|
||||
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer,
|
||||
attn_layer, attn_kwargs, norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self.feature_info = builder.features
|
||||
self._in_chs = builder.in_chs
|
||||
@ -334,7 +342,7 @@ class EfficientNetFeatures(nn.Module):
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
|
||||
in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
||||
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
super(EfficientNetFeatures, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
|
||||
@ -354,8 +362,8 @@ class EfficientNetFeatures(nn.Module):
|
||||
|
||||
# Middle stages (IR/ER/DS Blocks)
|
||||
builder = EfficientNetBuilder(
|
||||
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
|
||||
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, attn_layer,
|
||||
attn_kwargs, norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self.feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self._in_chs = builder.in_chs
|
||||
@ -627,13 +635,13 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
|
||||
|
||||
"""
|
||||
arch_def = [
|
||||
['ds_r1_k3_s1_e1_c16_se0.25'],
|
||||
['ir_r2_k3_s2_e6_c24_se0.25'],
|
||||
['ir_r2_k5_s2_e6_c40_se0.25'],
|
||||
['ir_r3_k3_s2_e6_c80_se0.25'],
|
||||
['ir_r3_k5_s1_e6_c112_se0.25'],
|
||||
['ir_r4_k5_s2_e6_c192_se0.25'],
|
||||
['ir_r1_k3_s1_e6_c320_se0.25'],
|
||||
['ds_r1_k3_s1_e1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r2_k5_s2_e6_c40'],
|
||||
['ir_r3_k3_s2_e6_c80'],
|
||||
['ir_r3_k5_s1_e6_c112'],
|
||||
['ir_r4_k5_s2_e6_c192'],
|
||||
['ir_r1_k3_s1_e6_c320'],
|
||||
]
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def, depth_multiplier),
|
||||
@ -641,6 +649,8 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
|
||||
stem_size=32,
|
||||
channel_multiplier=channel_multiplier,
|
||||
act_layer=Swish,
|
||||
attn_layer='sev2',
|
||||
attn_kwargs=dict(se_ratio=0.25),
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
**kwargs,
|
||||
)
|
||||
@ -707,6 +717,53 @@ def _gen_efficientnet_condconv(
|
||||
return model
|
||||
|
||||
|
||||
def _gen_efficientnet_eca(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates an EfficientNet model w/ ECA attention instead of SE.
|
||||
|
||||
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
||||
Paper: https://arxiv.org/abs/1905.11946
|
||||
|
||||
EfficientNet params
|
||||
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
|
||||
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
|
||||
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
|
||||
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
|
||||
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
|
||||
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
|
||||
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
|
||||
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
|
||||
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
|
||||
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
|
||||
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer
|
||||
depth_multiplier: multiplier to number of repeats per stage
|
||||
|
||||
"""
|
||||
arch_def = [
|
||||
['ds_r1_k3_s1_e1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r2_k5_s2_e6_c40'],
|
||||
['ir_r3_k3_s2_e6_c80'],
|
||||
['ir_r3_k5_s1_e6_c112'],
|
||||
['ir_r4_k5_s2_e6_c192'],
|
||||
['ir_r1_k3_s1_e6_c320'],
|
||||
]
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def, depth_multiplier),
|
||||
num_features=round_channels(1280, channel_multiplier, 8, None),
|
||||
stem_size=32,
|
||||
channel_multiplier=channel_multiplier,
|
||||
act_layer=Swish,
|
||||
attn_layer='eca',
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates a MixNet Small model.
|
||||
|
||||
@ -980,6 +1037,33 @@ def efficientnet_l2(pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_eca_b0(pretrained=False, **kwargs):
|
||||
""" EfficientNet-ECA-B0 """
|
||||
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
|
||||
model = _gen_efficientnet_eca(
|
||||
'efficientnet_eca_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_eca_b1(pretrained=False, **kwargs):
|
||||
""" EfficientNet-ECA-B1 """
|
||||
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
|
||||
model = _gen_efficientnet_eca(
|
||||
'efficientnet_eca_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_eca_b2(pretrained=False, **kwargs):
|
||||
""" EfficientNet-ECA-B2 """
|
||||
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
|
||||
model = _gen_efficientnet_eca(
|
||||
'efficientnet_eca_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_es(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Edge Small. """
|
||||
|
@ -1,8 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from .layers.activations import sigmoid
|
||||
from .layers import create_conv2d, drop_path
|
||||
from .layers import create_conv2d, create_attn, drop_path
|
||||
|
||||
|
||||
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
|
||||
@ -30,26 +29,21 @@ def resolve_bn_args(kwargs):
|
||||
return bn_args
|
||||
|
||||
|
||||
_SE_ARGS_DEFAULT = dict(
|
||||
gate_fn=sigmoid,
|
||||
act_layer=None,
|
||||
reduce_mid=False,
|
||||
divisor=1)
|
||||
|
||||
|
||||
def resolve_se_args(kwargs, in_chs, act_layer=None):
|
||||
se_kwargs = kwargs.copy() if kwargs is not None else {}
|
||||
# fill in args that aren't specified with the defaults
|
||||
for k, v in _SE_ARGS_DEFAULT.items():
|
||||
se_kwargs.setdefault(k, v)
|
||||
# some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
|
||||
if not se_kwargs.pop('reduce_mid'):
|
||||
se_kwargs['reduced_base_chs'] = in_chs
|
||||
# act_layer override, if it remains None, the containing block's act_layer will be used
|
||||
if se_kwargs['act_layer'] is None:
|
||||
assert act_layer is not None
|
||||
se_kwargs['act_layer'] = act_layer
|
||||
return se_kwargs
|
||||
def resolve_attn_args(layer, kwargs, in_chs, act_layer=None):
|
||||
attn_kwargs = kwargs.copy() if kwargs is not None else {}
|
||||
if isinstance(layer, nn.Module):
|
||||
is_se = 'SqueezeExciteV2' in layer.__name__
|
||||
else:
|
||||
is_se = layer == 'sev2'
|
||||
if is_se:
|
||||
# some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
|
||||
if not attn_kwargs.pop('reduce_mid', False):
|
||||
attn_kwargs['reduced_base_chs'] = in_chs
|
||||
# if act_layer it is not defined by attn kwargs, the containing block's act_layer will be used for attn
|
||||
if attn_kwargs.get('act_layer', None) is None:
|
||||
assert act_layer is not None
|
||||
attn_kwargs['act_layer'] = act_layer
|
||||
return attn_kwargs
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
@ -90,26 +84,6 @@ class ChannelShuffle(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class SqueezeExcite(nn.Module):
|
||||
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
|
||||
act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
|
||||
super(SqueezeExcite, self).__init__()
|
||||
self.gate_fn = gate_fn
|
||||
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x_se = self.avg_pool(x)
|
||||
x_se = self.conv_reduce(x_se)
|
||||
x_se = self.act1(x_se)
|
||||
x_se = self.conv_expand(x_se)
|
||||
x = x * self.gate_fn(x_se)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBnAct(nn.Module):
|
||||
def __init__(self, in_chs, out_chs, kernel_size,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
|
||||
@ -140,11 +114,10 @@ class DepthwiseSeparableConv(nn.Module):
|
||||
"""
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
|
||||
pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||
self.has_pw_act = pw_act # activation after point-wise conv
|
||||
self.drop_path_rate = drop_path_rate
|
||||
@ -154,10 +127,10 @@ class DepthwiseSeparableConv(nn.Module):
|
||||
self.bn1 = norm_layer(in_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if has_se:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
# Attention block (Squeeze-Excitation, ECA, etc)
|
||||
if attn_layer is not None:
|
||||
attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
|
||||
self.se = create_attn(attn_layer, in_chs, **attn_kwargs)
|
||||
else:
|
||||
self.se = None
|
||||
|
||||
@ -199,13 +172,12 @@ class InvertedResidual(nn.Module):
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
conv_kwargs=None, drop_path_rate=0.):
|
||||
super(InvertedResidual, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
conv_kwargs = conv_kwargs or {}
|
||||
mid_chs = make_divisible(in_chs * exp_ratio)
|
||||
has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.drop_path_rate = drop_path_rate
|
||||
|
||||
@ -221,10 +193,10 @@ class InvertedResidual(nn.Module):
|
||||
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if has_se:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
# Attention block (Squeeze-Excitation, ECA, etc)
|
||||
if attn_layer is not None:
|
||||
attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
|
||||
self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
|
||||
else:
|
||||
self.se = None
|
||||
|
||||
@ -256,7 +228,7 @@ class InvertedResidual(nn.Module):
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
# Attention
|
||||
if self.se is not None:
|
||||
x = self.se(x)
|
||||
|
||||
@ -278,7 +250,7 @@ class CondConvResidual(InvertedResidual):
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
num_experts=0, drop_path_rate=0.):
|
||||
|
||||
self.num_experts = num_experts
|
||||
@ -287,7 +259,7 @@ class CondConvResidual(InvertedResidual):
|
||||
super(CondConvResidual, self).__init__(
|
||||
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
|
||||
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
|
||||
pw_kernel_size=pw_kernel_size, attn_layer=attn_layer, attn_kwargs=attn_kwargs,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
|
||||
drop_path_rate=drop_path_rate)
|
||||
|
||||
@ -310,7 +282,7 @@ class CondConvResidual(InvertedResidual):
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
# Attention
|
||||
if self.se is not None:
|
||||
x = self.se(x)
|
||||
|
||||
@ -330,7 +302,7 @@ class EdgeResidual(nn.Module):
|
||||
|
||||
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
drop_path_rate=0.):
|
||||
super(EdgeResidual, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
@ -338,7 +310,6 @@ class EdgeResidual(nn.Module):
|
||||
mid_chs = make_divisible(fake_in_chs * exp_ratio)
|
||||
else:
|
||||
mid_chs = make_divisible(in_chs * exp_ratio)
|
||||
has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.drop_path_rate = drop_path_rate
|
||||
|
||||
@ -347,10 +318,10 @@ class EdgeResidual(nn.Module):
|
||||
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if has_se:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
# Attention block (Squeeze-Excitation, ECA, etc)
|
||||
if attn_layer is not None:
|
||||
attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
|
||||
self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
|
||||
else:
|
||||
self.se = None
|
||||
|
||||
@ -378,7 +349,7 @@ class EdgeResidual(nn.Module):
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
# Attention
|
||||
if self.se is not None:
|
||||
x = self.se(x)
|
||||
|
||||
|
@ -79,6 +79,13 @@ def _decode_block_str(block_str):
|
||||
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
|
||||
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
|
||||
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
|
||||
attn_layer = None
|
||||
attn_kwargs = None
|
||||
if 'se' in options:
|
||||
attn_layer = 'sev2'
|
||||
attn_kwargs = dict(se_ratio=float(options['se']))
|
||||
elif 'eca' in options:
|
||||
attn_layer = 'eca'
|
||||
|
||||
num_repeat = int(options['r'])
|
||||
# each type of block has different valid arguments, fill accordingly
|
||||
@ -90,7 +97,8 @@ def _decode_block_str(block_str):
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
attn_layer=attn_layer,
|
||||
attn_kwargs=attn_kwargs,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
noskip=noskip,
|
||||
@ -103,7 +111,8 @@ def _decode_block_str(block_str):
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
attn_layer=attn_layer,
|
||||
attn_kwargs=attn_kwargs,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
pw_act=block_type == 'dsa',
|
||||
@ -117,7 +126,8 @@ def _decode_block_str(block_str):
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
fake_in_chs=fake_in_chs,
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
attn_layer=attn_layer,
|
||||
attn_kwargs=attn_kwargs,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
noskip=noskip,
|
||||
@ -201,7 +211,7 @@ class EfficientNetBuilder:
|
||||
|
||||
"""
|
||||
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
|
||||
output_stride=32, pad_type='', act_layer=None, attn_layer=None, attn_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
|
||||
verbose=False):
|
||||
self.channel_multiplier = channel_multiplier
|
||||
@ -210,7 +220,8 @@ class EfficientNetBuilder:
|
||||
self.output_stride = output_stride
|
||||
self.pad_type = pad_type
|
||||
self.act_layer = act_layer
|
||||
self.se_kwargs = se_kwargs
|
||||
self.attn_layer = attn_layer
|
||||
self.attn_kwargs = attn_kwargs
|
||||
self.norm_layer = norm_layer
|
||||
self.norm_kwargs = norm_kwargs
|
||||
self.drop_path_rate = drop_path_rate
|
||||
@ -239,9 +250,19 @@ class EfficientNetBuilder:
|
||||
# block act fn overrides the model default
|
||||
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
|
||||
assert ba['act_layer'] is not None
|
||||
if 'attn_layer' in ba:
|
||||
assert'attn_kwargs' in ba # block args should have both or neither
|
||||
# per-block attn layer overrides model default
|
||||
ba['attn_layer'] = ba['attn_layer'] if ba['attn_layer'] is not None else self.attn_layer
|
||||
if self.attn_kwargs is not None:
|
||||
# merge per-block attn kwargs with model if both exist
|
||||
if ba['attn_kwargs'] is None:
|
||||
ba['attn_kwargs'] = self.attn_kwargs
|
||||
else:
|
||||
ba['attn_kwargs'].update(self.attn_kwargs)
|
||||
|
||||
if bt == 'ir':
|
||||
ba['drop_path_rate'] = drop_path_rate
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
if self.verbose:
|
||||
logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
|
||||
if ba.get('num_experts', 0) > 0:
|
||||
@ -250,13 +271,11 @@ class EfficientNetBuilder:
|
||||
block = InvertedResidual(**ba)
|
||||
elif bt == 'ds' or bt == 'dsa':
|
||||
ba['drop_path_rate'] = drop_path_rate
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
if self.verbose:
|
||||
logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
|
||||
block = DepthwiseSeparableConv(**ba)
|
||||
elif bt == 'er':
|
||||
ba['drop_path_rate'] = drop_path_rate
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
if self.verbose:
|
||||
logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
|
||||
block = EdgeResidual(**ba)
|
||||
|
@ -11,7 +11,7 @@ import torch.nn.functional as F
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SEModule
|
||||
from .layers import SqueezeExcite
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
from .resnet import ResNet, Bottleneck, BasicBlock
|
||||
@ -321,7 +321,7 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw
|
||||
default_cfg = default_cfgs['gluon_seresnext50_32x4d']
|
||||
model = ResNet(
|
||||
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SqueezeExcite), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
@ -335,7 +335,7 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k
|
||||
default_cfg = default_cfgs['gluon_seresnext101_32x4d']
|
||||
model = ResNet(
|
||||
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SqueezeExcite), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
@ -347,7 +347,7 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
|
||||
"""Constructs a SEResNeXt-101-64x4d model.
|
||||
"""
|
||||
default_cfg = default_cfgs['gluon_seresnext101_64x4d']
|
||||
block_args = dict(attn_layer=SEModule)
|
||||
block_args = dict(attn_layer=SqueezeExcite)
|
||||
model = ResNet(
|
||||
Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
|
||||
@ -362,7 +362,7 @@ def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs an SENet-154 model.
|
||||
"""
|
||||
default_cfg = default_cfgs['gluon_senet154']
|
||||
block_args = dict(attn_layer=SEModule)
|
||||
block_args = dict(attn_layer=SqueezeExcite)
|
||||
model = ResNet(
|
||||
Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3,
|
||||
block_reduce_first=2, num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
|
||||
|
@ -7,8 +7,8 @@ from .cond_conv2d import CondConv2d, get_condconv_initializer
|
||||
from .create_conv2d import create_conv2d
|
||||
from .create_attn import create_attn
|
||||
from .selective_kernel import SelectiveKernelConv
|
||||
from .se import SEModule
|
||||
from .eca import EcaModule, CecaModule
|
||||
from .se import SqueezeExcite, SqueezeExciteV2
|
||||
from .eca import EfficientChannelAttn, CircularEfficientChannelAttn
|
||||
from .activations import *
|
||||
from .adaptive_avgmax_pool import \
|
||||
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
|
||||
|
@ -75,9 +75,9 @@ class LightSpatialAttn(nn.Module):
|
||||
return x * x_attn.sigmoid()
|
||||
|
||||
|
||||
class CbamModule(nn.Module):
|
||||
class ConvBlockAttn(nn.Module):
|
||||
def __init__(self, channels, spatial_kernel_size=7):
|
||||
super(CbamModule, self).__init__()
|
||||
super(ConvBlockAttn, self).__init__()
|
||||
self.channel = ChannelAttn(channels)
|
||||
self.spatial = SpatialAttn(spatial_kernel_size)
|
||||
|
||||
@ -87,9 +87,9 @@ class CbamModule(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class LightCbamModule(nn.Module):
|
||||
class LightConvBlockAttn(nn.Module):
|
||||
def __init__(self, channels, spatial_kernel_size=7):
|
||||
super(LightCbamModule, self).__init__()
|
||||
super(LightConvBlockAttn, self).__init__()
|
||||
self.channel = LightChannelAttn(channels)
|
||||
self.spatial = LightSpatialAttn(spatial_kernel_size)
|
||||
|
||||
|
@ -3,9 +3,9 @@
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
from .se import SEModule
|
||||
from .eca import EcaModule, CecaModule
|
||||
from .cbam import CbamModule, LightCbamModule
|
||||
from .se import SqueezeExcite, SqueezeExciteV2
|
||||
from .eca import EfficientChannelAttn, CircularEfficientChannelAttn
|
||||
from .cbam import ConvBlockAttn, LightConvBlockAttn
|
||||
|
||||
|
||||
def create_attn(attn_type, channels, **kwargs):
|
||||
@ -14,20 +14,19 @@ def create_attn(attn_type, channels, **kwargs):
|
||||
if isinstance(attn_type, str):
|
||||
attn_type = attn_type.lower()
|
||||
if attn_type == 'se':
|
||||
module_cls = SEModule
|
||||
module_cls = SqueezeExcite
|
||||
elif attn_type == 'sev2':
|
||||
module_cls = SqueezeExciteV2
|
||||
elif attn_type == 'eca':
|
||||
module_cls = EcaModule
|
||||
elif attn_type == 'eca':
|
||||
module_cls = CecaModule
|
||||
module_cls = EfficientChannelAttn
|
||||
elif attn_type == 'ceca':
|
||||
module_cls = CircularEfficientChannelAttn
|
||||
elif attn_type == 'cbam':
|
||||
module_cls = CbamModule
|
||||
module_cls = ConvBlockAttn
|
||||
elif attn_type == 'lcbam':
|
||||
module_cls = LightCbamModule
|
||||
module_cls = LightConvBlockAttn
|
||||
else:
|
||||
assert False, "Invalid attn module (%s)" % attn_type
|
||||
elif isinstance(attn_type, bool):
|
||||
if attn_type:
|
||||
module_cls = SEModule
|
||||
else:
|
||||
module_cls = attn_type
|
||||
if module_cls is not None:
|
||||
|
@ -38,7 +38,7 @@ from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class EcaModule(nn.Module):
|
||||
class EfficientChannelAttn(nn.Module):
|
||||
"""Constructs an ECA module.
|
||||
|
||||
Args:
|
||||
@ -49,8 +49,8 @@ class EcaModule(nn.Module):
|
||||
(default=None. if channel size not given, use k_size given for kernel size.)
|
||||
kernel_size: Adaptive selection of kernel size (default=3)
|
||||
"""
|
||||
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
|
||||
super(EcaModule, self).__init__()
|
||||
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, gate_fn=None):
|
||||
super(EfficientChannelAttn, self).__init__()
|
||||
assert kernel_size % 2 == 1
|
||||
|
||||
if channels is not None:
|
||||
@ -59,20 +59,18 @@ class EcaModule(nn.Module):
|
||||
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
|
||||
self.gate_fn = gate_fn
|
||||
|
||||
def forward(self, x):
|
||||
# Feature descriptor on the global spatial information
|
||||
y = self.avg_pool(x)
|
||||
# Reshape for convolution
|
||||
y = y.view(x.shape[0], 1, -1)
|
||||
# Two different branches of ECA module
|
||||
y = self.avg_pool(x) # Feature descriptor on the global spatial information
|
||||
y = y.view(x.shape[0], 1, -1) # Reshape for convolution
|
||||
y = self.conv(y)
|
||||
# Multi-scale information fusion
|
||||
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
|
||||
y = y.view(x.shape[0], -1, 1, 1)
|
||||
y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class CecaModule(nn.Module):
|
||||
class CircularEfficientChannelAttn(nn.Module):
|
||||
"""Constructs a circular ECA module.
|
||||
|
||||
ECA module where the conv uses circular padding rather than zero padding.
|
||||
@ -92,13 +90,14 @@ class CecaModule(nn.Module):
|
||||
kernel_size: Adaptive selection of kernel size (default=3)
|
||||
"""
|
||||
|
||||
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
|
||||
super(CecaModule, self).__init__()
|
||||
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, gate_fn=None):
|
||||
super(CircularEfficientChannelAttn, self).__init__()
|
||||
assert kernel_size % 2 == 1
|
||||
|
||||
if channels is not None:
|
||||
t = int(abs(math.log(channels, 2) + beta) / gamma)
|
||||
kernel_size = max(t if t % 2 else t + 1, 3)
|
||||
self.padding = (kernel_size - 1) // 2
|
||||
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
#pytorch circular padding mode is buggy as of pytorch 1.4
|
||||
@ -106,19 +105,13 @@ class CecaModule(nn.Module):
|
||||
|
||||
#implement manual circular padding
|
||||
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
|
||||
self.padding = (kernel_size - 1) // 2
|
||||
self.gate_fn = gate_fn
|
||||
|
||||
def forward(self, x):
|
||||
# Feature descriptor on the global spatial information
|
||||
y = self.avg_pool(x)
|
||||
|
||||
y = self.avg_pool(x) # Feature descriptor on the global spatial information
|
||||
# Manually implement circular padding, F.pad does not seemed to be bugged
|
||||
y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
|
||||
|
||||
# Two different branches of ECA module
|
||||
y = self.conv(y)
|
||||
|
||||
# Multi-scale information fusion
|
||||
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
|
||||
|
||||
y = y.view(x.shape[0], -1, 1, 1)
|
||||
y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
|
||||
return x * y.expand_as(x)
|
||||
|
@ -21,7 +21,13 @@ tup_triple = _ntuple(3)
|
||||
tup_quadruple = _ntuple(4)
|
||||
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
min_value = min_value or divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
|
||||
|
@ -1,12 +1,22 @@
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from .helpers import make_divisible
|
||||
|
||||
class SEModule(nn.Module):
|
||||
|
||||
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
|
||||
super(SEModule, self).__init__()
|
||||
class SqueezeExcite(nn.Module):
|
||||
""" Squeeze-and-Excitation module as used in Pytorch SENet, SE-ResNeXt implementations
|
||||
|
||||
Args:
|
||||
channels (int): number of input and output channels
|
||||
reduction (int, float): divisor for attention (squeezed) channels
|
||||
act_layer (nn.Module): override the default ReLU activation
|
||||
"""
|
||||
|
||||
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisible_by=1):
|
||||
super(SqueezeExcite, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
reduction_channels = max(channels // reduction, 8)
|
||||
reduction_channels = make_divisible(channels // reduction, divisible_by)
|
||||
self.fc1 = nn.Conv2d(
|
||||
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
|
||||
self.act = act_layer(inplace=True)
|
||||
@ -19,3 +29,38 @@ class SEModule(nn.Module):
|
||||
x_se = self.act(x_se)
|
||||
x_se = self.fc2(x_se)
|
||||
return x * x_se.sigmoid()
|
||||
|
||||
|
||||
class SqueezeExciteV2(nn.Module):
|
||||
""" Squeeze-and-Excitation module as used in EfficientNet, MobileNetV3, related models
|
||||
|
||||
Differs from the original SqueezeExcite impl in that:
|
||||
* reduction is specified as a float multiplier instead of divisor (se_ratio)
|
||||
* gate function is changeable from sigmoid to alternate (ie hard_sigmoid)
|
||||
* layer names match those in weights for the EfficientNet/MobileNetV3 families
|
||||
|
||||
Args:
|
||||
channels (int): number of input and output channels
|
||||
se_ratio (float): multiplier for attention (squeezed) channels
|
||||
reduced_base_chs (int): specify alternate channel count to base the reduction channels on
|
||||
act_layer (nn.Module): override the default ReLU activation
|
||||
gate_fn (callable): override the default gate function
|
||||
"""
|
||||
|
||||
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
|
||||
act_layer=nn.ReLU, gate_fn=torch.sigmoid, divisible_by=1, **_):
|
||||
super(SqueezeExciteV2, self).__init__()
|
||||
self.gate_fn = gate_fn
|
||||
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisible_by)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x_se = self.avg_pool(x)
|
||||
x_se = self.conv_reduce(x_se)
|
||||
x_se = self.act1(x_se)
|
||||
x_se = self.conv_expand(x_se)
|
||||
x = x * self.gate_fn(x_se)
|
||||
return x
|
||||
|
@ -30,10 +30,11 @@ def _cfg(url='', **kwargs):
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'mobilenetv3_large_075': _cfg(url=''),
|
||||
'mobilenetv3_large_100': _cfg(url=''),
|
||||
'mobilenetv3_small_075': _cfg(url=''),
|
||||
'mobilenetv3_small_100': _cfg(url=''),
|
||||
'mobilenetv3_large_075': _cfg(url='', interoplation='bicubic'),
|
||||
'mobilenetv3_large_100': _cfg(url='', interoplation='bicubic'),
|
||||
'mobilenetv3_small_075': _cfg(url='', interoplation='bicubic'),
|
||||
'mobilenetv3_small_100': _cfg(url='', interoplation='bicubic'),
|
||||
'mobilenetv3_eca_large': _cfg(url='', interoplation='bicubic'),
|
||||
'mobilenetv3_rw': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
|
||||
interpolation='bicubic'),
|
||||
@ -72,7 +73,7 @@ class MobileNetV3(nn.Module):
|
||||
|
||||
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
|
||||
channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
||||
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
|
||||
attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
|
||||
super(MobileNetV3, self).__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
@ -89,7 +90,7 @@ class MobileNetV3(nn.Module):
|
||||
|
||||
# Middle stages (IR/ER/DS Blocks)
|
||||
builder = EfficientNetBuilder(
|
||||
channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
|
||||
channel_multiplier, 8, None, 32, pad_type, act_layer, attn_layer, attn_kwargs,
|
||||
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self.feature_info = builder.features
|
||||
@ -148,7 +149,7 @@ class MobileNetV3Features(nn.Module):
|
||||
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
|
||||
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
|
||||
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
|
||||
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., attn_layer=None, attn_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
super(MobileNetV3Features, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
@ -169,7 +170,7 @@ class MobileNetV3Features(nn.Module):
|
||||
|
||||
# Middle stages (IR/ER/DS Blocks)
|
||||
builder = EfficientNetBuilder(
|
||||
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
|
||||
channel_multiplier, 8, None, output_stride, pad_type, act_layer, attn_layer, attn_kwargs,
|
||||
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self.feature_info = builder.features # builder provides info about feature channels for each block
|
||||
@ -256,7 +257,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
|
||||
channel_multiplier=channel_multiplier,
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
act_layer=HardSwish,
|
||||
se_kwargs=dict(gate_fn=hard_sigmoid, reduce_mid=True, divisor=1),
|
||||
attn_kwargs=dict(gate_fn=hard_sigmoid, reduce_mid=True, divisor=1),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
||||
@ -352,7 +353,68 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
|
||||
channel_multiplier=channel_multiplier,
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
act_layer=act_layer,
|
||||
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
|
||||
attn_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisible_by=8),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates a MobileNet-V3 model.
|
||||
|
||||
Ref impl: ?
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer.
|
||||
"""
|
||||
if 'small' in variant:
|
||||
num_features = 1024
|
||||
act_layer = HardSwish
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s2_e1_c16_nre'], # relu
|
||||
# stage 1, 56x56 in
|
||||
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
|
||||
# stage 2, 28x28 in
|
||||
['ir_r1_k5_s2_e4_c40', 'ir_r2_k5_s1_e6_c40'], # hard-swish
|
||||
# stage 3, 14x14 in
|
||||
['ir_r2_k5_s1_e3_c48'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r3_k5_s2_e6_c96'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c576'], # hard-swish
|
||||
]
|
||||
else:
|
||||
num_features = 1280
|
||||
act_layer = HardSwish
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16_nre'], # relu
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k5_s2_e3_c40_nre'], # relu
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112'], # hard-swish
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k5_s2_e6_c160'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'], # hard-swish
|
||||
]
|
||||
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
num_features=num_features,
|
||||
stem_size=16,
|
||||
channel_multiplier=channel_multiplier,
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
act_layer=act_layer,
|
||||
attn_layer='eca',
|
||||
attn_kwargs=dict(gate_fn=hard_sigmoid),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
||||
@ -382,12 +444,18 @@ def mobilenetv3_small_075(pretrained=False, **kwargs):
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_small_100(pretrained=False, **kwargs):
|
||||
print(kwargs)
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_eca_large(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3_eca('mobilenetv3_eca_large', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_rw(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
|
@ -11,7 +11,7 @@ import torch.nn.functional as F
|
||||
from .resnet import ResNet
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SEModule
|
||||
from .layers import SqueezeExcite
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
__all__ = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user