Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
EfficientNet and related cleanup

* remove folded_bn support and corresponding untrainable tflite ported weights
* combine bn args into dict
* add inplace support to activations and use where possible for reduced mem on large models
This commit is contained in:
parent
c11973602d
commit
d6ac5bbc48
@@ -50,18 +50,12 @@ default_cfgs = {
     'mnasnet_100': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
         interpolation='bicubic'),
-    'tflite_mnasnet_100': _cfg(
-        url='https://www.dropbox.com/s/q55ir3tx8mpeyol/tflite_mnasnet_100-31639cdc.pth?dl=1',
-        interpolation='bicubic'),
     'mnasnet_140': _cfg(url=''),
     'semnasnet_050': _cfg(url=''),
     'semnasnet_075': _cfg(url=''),
     'semnasnet_100': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
         interpolation='bicubic'),
-    'tflite_semnasnet_100': _cfg(
-        url='https://www.dropbox.com/s/yiori47sr9dydev/tflite_semnasnet_100-7c780429.pth?dl=1',
-        interpolation='bicubic'),
     'semnasnet_140': _cfg(url=''),
     'mnasnet_small': _cfg(url=''),
     'mobilenetv1_100': _cfg(url=''),
@@ -118,6 +112,7 @@ _DEBUG = False
 # Default args for PyTorch BN impl
 _BN_MOMENTUM_PT_DEFAULT = 0.1
 _BN_EPS_PT_DEFAULT = 1e-5
+_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)

 # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
 # papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
@@ -126,23 +121,18 @@ _BN_EPS_PT_DEFAULT = 1e-5
 # .9997 (/w .999 in search space) for paper
 _BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
 _BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)


-def _resolve_bn_params(kwargs):
-    # NOTE kwargs passed as dict intentionally
-    bn_momentum_default = _BN_MOMENTUM_PT_DEFAULT
-    bn_eps_default = _BN_EPS_PT_DEFAULT
-    bn_tf = kwargs.pop('bn_tf', False)
-    if bn_tf:
-        bn_momentum_default = _BN_MOMENTUM_TF_DEFAULT
-        bn_eps_default = _BN_EPS_TF_DEFAULT
+def _resolve_bn_args(kwargs):
+    bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy()
     bn_momentum = kwargs.pop('bn_momentum', None)
+    if bn_momentum is not None:
+        bn_args['momentum'] = bn_momentum
     bn_eps = kwargs.pop('bn_eps', None)
-    if bn_momentum is None:
-        bn_momentum = bn_momentum_default
-    if bn_eps is None:
-        bn_eps = bn_eps_default
-    return bn_momentum, bn_eps
+    if bn_eps is not None:
+        bn_args['eps'] = bn_eps
+    return bn_args


 def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
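The new helper collapses the old (bn_momentum, bn_eps) pair into a single dict that is later splatted into nn.BatchNorm2d. A minimal sketch of the resolution order, with illustrative values that are not part of the diff:

    # bn_tf picks the default dict, then explicit kwargs override single keys
    kwargs = dict(bn_tf=True, bn_eps=1e-4, drop_rate=0.2)
    bn_args = _resolve_bn_args(kwargs)
    # bn_args == {'momentum': 0.01, 'eps': 1e-4}  (TF defaults, eps overridden)
    # kwargs  == {'drop_rate': 0.2}  (all bn_* keys consumed via pop)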
@@ -292,6 +282,31 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0):
     return arch_args


+def swish(x, inplace=False):
+    if inplace:
+        return x.mul_(x.sigmoid())
+    else:
+        return x * x.sigmoid()
+
+
+def sigmoid(x, inplace=False):
+    return x.sigmoid_() if inplace else x.sigmoid()
+
+
+def hard_swish(x, inplace=False):
+    if inplace:
+        return x.mul_(F.relu6(x + 3.) / 6.)
+    else:
+        return x * F.relu6(x + 3.) / 6.
+
+
+def hard_sigmoid(x, inplace=False):
+    if inplace:
+        return x.add_(3.).clamp_(0., 6.).div_(6.)
+    else:
+        return F.relu6(x + 3.) / 6.
+
+
 class _BlockBuilder:
     """ Build Trunk Blocks

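These wrappers follow the nn.ReLU(inplace=True) convention: when inplace is set, the underscore ops (mul_, sigmoid_, add_, clamp_, div_) overwrite the input buffer rather than allocating a fresh output, which is where the memory saving on large models comes from. A quick sketch of the aliasing behaviour, illustrative only and assuming x carries no autograd history:

    import torch

    x = torch.randn(8, 32, 56, 56)
    y = swish(x)                 # out-of-place: new tensor, x untouched
    z = swish(x, inplace=True)   # in-place: x is overwritten, z aliases x
    assert z.data_ptr() == x.data_ptr()

Note that x.mul_(x.sigmoid()) still allocates the sigmoid temporary; only the final product reuses the input storage, so the saving is one activation-sized buffer per call site, and in-place use is only safe where the input isn't needed again, hence the forward passes below apply it right after a conv or BN.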
@@ -303,9 +318,9 @@ class _BlockBuilder:
     """

     def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 drop_connect_rate=0., act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False, verbose=False):
+                 drop_connect_rate=0., act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
+                 bn_args=_BN_ARGS_PT, padding_same=False,
+                 verbose=False):
         self.channel_multiplier = channel_multiplier
         self.channel_divisor = channel_divisor
         self.channel_min = channel_min
@@ -313,9 +328,7 @@ class _BlockBuilder:
         self.act_fn = act_fn
         self.se_gate_fn = se_gate_fn
         self.se_reduce_mid = se_reduce_mid
-        self.bn_momentum = bn_momentum
-        self.bn_eps = bn_eps
-        self.folded_bn = folded_bn
+        self.bn_args = bn_args
         self.padding_same = padding_same
         self.verbose = verbose

@@ -331,9 +344,7 @@ class _BlockBuilder:
         bt = ba.pop('block_type')
         ba['in_chs'] = self.in_chs
         ba['out_chs'] = self._round_channels(ba['out_chs'])
-        ba['bn_momentum'] = self.bn_momentum
-        ba['bn_eps'] = self.bn_eps
-        ba['folded_bn'] = self.folded_bn
+        ba['bn_args'] = self.bn_args
         ba['padding_same'] = self.padding_same
         # block act fn overrides the model default
         ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
@@ -427,18 +438,6 @@ def _initialize_weight_default(m):
         nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')


-def swish(x):
-    return x * torch.sigmoid(x)
-
-
-def hard_swish(x):
-    return x * F.relu6(x + 3.) / 6.
-
-
-def hard_sigmoid(x):
-    return F.relu6(x + 3.) / 6.
-
-
 def drop_connect(inputs, training=False, drop_connect_rate=0.):
     """Apply drop connect."""
     if not training:
@@ -474,7 +473,7 @@ class ChannelShuffle(nn.Module):


 class SqueezeExcite(nn.Module):
-    def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=torch.sigmoid):
+    def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=sigmoid):
         super(SqueezeExcite, self).__init__()
         self.act_fn = act_fn
         self.gate_fn = gate_fn
@@ -486,17 +485,16 @@ class SqueezeExcite(nn.Module):
         # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance
         x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
         x_se = self.conv_reduce(x_se)
-        x_se = self.act_fn(x_se)
+        x_se = self.act_fn(x_se, inplace=True)
         x_se = self.conv_expand(x_se)
-        x = self.gate_fn(x_se) * x
+        x = x * self.gate_fn(x_se)
         return x


 class ConvBnAct(nn.Module):
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False):
+                 bn_args=_BN_ARGS_PT, padding_same=False):
         super(ConvBnAct, self).__init__()
         assert stride in [1, 2]
         self.act_fn = act_fn
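Two details in this hunk: act_fn is now invoked as act_fn(x_se, inplace=True), so any activation passed in must accept that keyword (F.relu does, as do the wrappers defined above), and the gate default moves from torch.sigmoid to the local sigmoid so every activation hook in the file shares the (x, inplace=False) signature; gate_fn itself is still called out-of-place with a single argument. A hypothetical construction with compatible callables:

    # Illustrative: swish and hard_sigmoid are the module-level functions above
    se = SqueezeExcite(64, reduce_chs=16, act_fn=swish, gate_fn=hard_sigmoid)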
@@ -504,14 +502,13 @@ class ConvBnAct(nn.Module):

         self.conv = sconv2d(
             in_chs, out_chs, kernel_size,
-            stride=stride, padding=padding, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+            stride=stride, padding=padding, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_chs, **bn_args)

     def forward(self, x):
         x = self.conv(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
+        x = self.bn1(x)
+        x = self.act_fn(x, inplace=True)
         return x


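With folded_bn gone, the conv is always bias-free and bn1 always exists, so the forward path loses its None check. The **bn_args splat is ordinary dict unpacking into nn.BatchNorm2d; the same pattern standalone, with illustrative values:

    import torch.nn as nn

    bn_args = dict(momentum=0.01, eps=1e-3)   # e.g. the TF-style defaults
    bn = nn.BatchNorm2d(32, **bn_args)        # same as momentum=0.01, eps=1e-3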
@@ -522,9 +519,8 @@ class DepthwiseSeparableConv(nn.Module):
     """
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu, noskip=False, pw_act=False,
-                 se_ratio=0., se_gate_fn=torch.sigmoid,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False, drop_connect_rate=0.):
+                 se_ratio=0., se_gate_fn=sigmoid,
+                 bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
         assert stride in [1, 2]
         self.has_se = se_ratio is not None and se_ratio > 0.
@@ -537,33 +533,31 @@ class DepthwiseSeparableConv(nn.Module):

         self.conv_dw = sconv2d(
             in_chs, in_chs, kernel_size,
-            stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
+            stride=stride, padding=dw_padding, groups=in_chs, bias=False)
+        self.bn1 = nn.BatchNorm2d(in_chs, **bn_args)

         # Squeeze-and-excitation
         if self.has_se:
             self.se = SqueezeExcite(
                 in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)

-        self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
-        self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)

     def forward(self, x):
         residual = x

         x = self.conv_dw(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
+        x = self.bn1(x)
+        x = self.act_fn(x, inplace=True)

         if self.has_se:
             x = self.se(x)

         x = self.conv_pw(x)
-        if self.bn2 is not None:
-            x = self.bn2(x)
+        x = self.bn2(x)
         if self.has_pw_act:
-            x = self.act_fn(x)
+            x = self.act_fn(x, inplace=True)

         if self.has_residual:
             if self.drop_connect_rate > 0.:
@@ -577,10 +571,9 @@ class InvertedResidual(nn.Module):

     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
-                 se_ratio=0., se_reduce_mid=False, se_gate_fn=torch.sigmoid,
+                 se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
                  shuffle_type=None, pw_group=1,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False, drop_connect_rate=0.):
+                 bn_args=_BN_ARGS_PT, padding_same=False, drop_connect_rate=0.):
         super(InvertedResidual, self).__init__()
         mid_chs = int(in_chs * exp_ratio)
         self.has_se = se_ratio is not None and se_ratio > 0.
@@ -591,8 +584,8 @@ class InvertedResidual(nn.Module):
         pw_padding = _padding_arg(0, padding_same)

         # Point-wise expansion
-        self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
+        self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)

         self.shuffle_type = shuffle_type
         if shuffle_type is not None:
@@ -600,8 +593,8 @@ class InvertedResidual(nn.Module):

         # Depth-wise convolution
         self.conv_dw = sconv2d(
-            mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=folded_bn)
-        self.bn2 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
+            mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=False)
+        self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args)

         # Squeeze-and-excitation
         if self.has_se:
@@ -610,17 +603,16 @@ class InvertedResidual(nn.Module):
                 mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)

         # Point-wise linear projection
-        self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
-        self.bn3 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=False)
+        self.bn3 = nn.BatchNorm2d(out_chs, **bn_args)

     def forward(self, x):
         residual = x

         # Point-wise expansion
         x = self.conv_pw(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
+        x = self.bn1(x)
+        x = self.act_fn(x, inplace=True)

         # FIXME haven't tried this yet
         # for channel shuffle when using groups with pointwise convs as per FBNet variants
@@ -629,9 +621,8 @@ class InvertedResidual(nn.Module):

         # Depth-wise convolution
         x = self.conv_dw(x)
-        if self.bn2 is not None:
-            x = self.bn2(x)
-        x = self.act_fn(x)
+        x = self.bn2(x)
+        x = self.act_fn(x, inplace=True)

         # Squeeze-and-excitation
         if self.has_se:
@@ -639,8 +630,7 @@ class InvertedResidual(nn.Module):

         # Point-wise linear projection
         x = self.conv_pwl(x)
-        if self.bn3 is not None:
-            x = self.bn3(x)
+        x = self.bn3(x)

         if self.has_residual:
             if self.drop_connect_rate > 0.:
@@ -668,11 +658,9 @@ class GenEfficientNet(nn.Module):

     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                  channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  drop_rate=0., drop_connect_rate=0., act_fn=F.relu,
-                 se_gate_fn=torch.sigmoid, se_reduce_mid=False,
-                 global_pool='avg', head_conv='default', weight_init='goog',
-                 folded_bn=False, padding_same=False,):
+                 se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT,
+                 global_pool='avg', head_conv='default', weight_init='goog', padding_same=False):
         super(GenEfficientNet, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -682,14 +670,14 @@ class GenEfficientNet(nn.Module):
         stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
         self.conv_stem = sconv2d(
             in_chans, stem_size, 3,
-            padding=_padding_arg(1, padding_same), stride=2, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
+            padding=_padding_arg(1, padding_same), stride=2, bias=False)
+        self.bn1 = nn.BatchNorm2d(stem_size, **bn_args)
         in_chs = stem_size

         builder = _BlockBuilder(
             channel_multiplier, channel_divisor, channel_min,
             drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid,
-            bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
+            bn_args, padding_same, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(in_chs, block_args))
         in_chs = builder.in_chs

@@ -701,9 +689,8 @@ class GenEfficientNet(nn.Module):
         self.efficient_head = head_conv == 'efficient'
         self.conv_head = sconv2d(
             in_chs, self.num_features, 1,
-            padding=_padding_arg(0, padding_same), bias=folded_bn and not self.efficient_head)
-        self.bn2 = None if (folded_bn or self.efficient_head) else \
-            nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
+            padding=_padding_arg(0, padding_same), bias=False)
+        self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args)

         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
@@ -729,25 +716,23 @@ class GenEfficientNet(nn.Module):

     def forward_features(self, x, pool=True):
         x = self.conv_stem(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
+        x = self.bn1(x)
+        x = self.act_fn(x, inplace=True)
         x = self.blocks(x)
         if self.efficient_head:
             # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
             x = self.global_pool(x)  # always need to pool here regardless of flag
             x = self.conv_head(x)
             # no BN
-            x = self.act_fn(x)
+            x = self.act_fn(x, inplace=True)
             if pool:
                 # expect flattened output if pool is true, otherwise keep dim
                 x = x.view(x.size(0), -1)
         else:
             if self.conv_head is not None:
                 x = self.conv_head(x)
-            if self.bn2 is not None:
-                x = self.bn2(x)
-            x = self.act_fn(x)
+            x = self.bn2(x)
+            x = self.act_fn(x, inplace=True)
             if pool:
                 x = self.global_pool(x)
                 x = x.view(x.size(0), -1)
@@ -785,7 +770,6 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
         # stage 6, 7x7 in
         ['ir_r1_k3_s1_e6_c320'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -793,8 +777,7 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
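Every _gen_* builder below repeats this same substitution: the bn_momentum, bn_eps = _resolve_bn_params(kwargs) line disappears and the dict is built inline via bn_args=_resolve_bn_args(kwargs). The caller-facing kwargs are unchanged; a hypothetical construction with TF-style BN settings:

    # bn_tf / bn_momentum / bn_eps still arrive via **kwargs and are folded
    # into the bn_args dict before GenEfficientNet sees them
    model = _gen_mnasnet_a1(1.0, num_classes=1000, bn_tf=True)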
@@ -825,7 +808,6 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
         # stage 6, 7x7 in
         ['ir_r1_k3_s1_e6_c320_noskip']
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -833,8 +815,7 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -858,7 +839,6 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r3_k5_s2_e6_c88_se0.25'],
         ['ir_r1_k3_s1_e6_c144']
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -866,8 +846,7 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -885,7 +864,6 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
         ['dsa_r6_k3_s2_c512'],
         ['dsa_r2_k3_s2_c1024'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -894,8 +872,7 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         act_fn=F.relu6,
         head_conv='none',
         **kwargs
@@ -917,7 +894,6 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r3_k3_s2_e6_c160'],
         ['ir_r1_k3_s1_e6_c320'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -925,8 +901,7 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         act_fn=F.relu6,
         **kwargs
     )
@@ -958,7 +933,6 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
         # stage 6, 7x7 in
         ['cn_r1_k1_s1_c960'],  # hard-swish
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -966,8 +940,7 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         act_fn=hard_swish,
         se_gate_fn=hard_sigmoid,
         se_reduce_mid=True,
@@ -994,7 +967,6 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r4_k3_s2_e7_c152'],
         ['ir_r1_k3_s1_e10_c104'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -1003,8 +975,7 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -1027,7 +998,6 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r6_k3_s2_e2_c152'],
         ['ir_r1_k3_s1_e6_c112'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -1036,8 +1006,7 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -1061,7 +1030,6 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r4_k5_s2_e6_c184'],
         ['ir_r1_k3_s1_e6_c352'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -1070,8 +1038,7 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -1101,7 +1068,6 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
         # stage 6, 7x7 in
         ['ir_r1_k3_s1_e6_c320_noskip']
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
@@ -1109,8 +1075,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         **kwargs
     )
     return model
@@ -1147,7 +1112,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
         ['ir_r4_k5_s2_e6_c192_se0.25'],
         ['ir_r1_k3_s1_e6_c320_se0.25'],
     ]
-    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     # NOTE: other models in the family didn't scale the feature count
     num_features = _round_channels(1280, channel_multiplier, 8, None)
     model = GenEfficientNet(
@@ -1158,8 +1122,7 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
         channel_divisor=8,
         channel_min=None,
         num_features=num_features,
-        bn_momentum=bn_momentum,
-        bn_eps=bn_eps,
+        bn_args=_resolve_bn_args(kwargs),
         act_fn=swish,
         **kwargs
     )
@@ -1205,20 +1168,6 @@ def mnasnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return mnasnet_100(pretrained, num_classes, in_chans, **kwargs)


-@register_model
-def tflite_mnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """ MNASNet B1, depth multiplier of 1.0. """
-    default_cfg = default_cfgs['tflite_mnasnet_100']
-    # these two args are for compat with tflite pretrained weights
-    kwargs['folded_bn'] = True
-    kwargs['padding_same'] = True
-    model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
-
-
 @register_model
 def mnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """ MNASNet B1, depth multiplier of 1.4 """
@@ -1269,20 +1218,6 @@ def mnasnet_a1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return semnasnet_100(pretrained, num_classes, in_chans, **kwargs)


-@register_model
-def tflite_semnasnet_100(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """ MNASNet A1, depth multiplier of 1.0. """
-    default_cfg = default_cfgs['tflite_semnasnet_100']
-    # these two args are for compat with tflite pretrained weights
-    kwargs['folded_bn'] = True
-    kwargs['padding_same'] = True
-    model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
-
-
 @register_model
 def semnasnet_140(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """