mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
EfficientNet-Lite model added w/ converted checkpoints, validation in progress...
This commit is contained in:
parent
7deacf5477
commit
bd05258f7b
@ -52,12 +52,14 @@ default_cfgs = {
|
||||
'mnasnet_100': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'),
|
||||
'mnasnet_140': _cfg(url=''),
|
||||
|
||||
'semnasnet_050': _cfg(url=''),
|
||||
'semnasnet_075': _cfg(url=''),
|
||||
'semnasnet_100': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'),
|
||||
'semnasnet_140': _cfg(url=''),
|
||||
'mnasnet_small': _cfg(url=''),
|
||||
|
||||
'mobilenetv2_100': _cfg(url=''),
|
||||
'fbnetc_100': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
|
||||
@ -65,6 +67,7 @@ default_cfgs = {
|
||||
'spnasnet_100': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
|
||||
interpolation='bilinear'),
|
||||
|
||||
'efficientnet_b0': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'),
|
||||
'efficientnet_b1': _cfg(
|
||||
@ -94,15 +97,32 @@ default_cfgs = {
|
||||
url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
'efficientnet_l2': _cfg(
|
||||
url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
|
||||
|
||||
'efficientnet_es': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
|
||||
'efficientnet_em': _cfg(
|
||||
url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
'efficientnet_el': _cfg(
|
||||
url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
|
||||
'efficientnet_cc_b0_4e': _cfg(url=''),
|
||||
'efficientnet_cc_b0_8e': _cfg(url=''),
|
||||
'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
|
||||
'efficientnet_lite0': _cfg(
|
||||
url=''),
|
||||
'efficientnet_lite1': _cfg(
|
||||
url='',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
'efficientnet_lite2': _cfg(
|
||||
url='',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
|
||||
'efficientnet_lite3': _cfg(
|
||||
url='',
|
||||
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
'efficientnet_lite4': _cfg(
|
||||
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
|
||||
'tf_efficientnet_b0': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
||||
input_size=(3, 224, 224)),
|
||||
@ -130,6 +150,7 @@ default_cfgs = {
|
||||
'tf_efficientnet_b8': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
|
||||
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
|
||||
'tf_efficientnet_b0_ap': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)),
|
||||
@ -165,6 +186,7 @@ default_cfgs = {
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
|
||||
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
|
||||
'tf_efficientnet_b0_ns': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
|
||||
input_size=(3, 224, 224)),
|
||||
@ -195,6 +217,7 @@ default_cfgs = {
|
||||
'tf_efficientnet_l2_ns': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
|
||||
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
|
||||
|
||||
'tf_efficientnet_es': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
@ -207,6 +230,7 @@ default_cfgs = {
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
|
||||
'tf_efficientnet_cc_b0_4e': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
@ -217,6 +241,33 @@ default_cfgs = {
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
|
||||
'tf_efficientnet_lite0': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
|
||||
),
|
||||
'tf_efficientnet_lite1': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882,
|
||||
interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
|
||||
),
|
||||
'tf_efficientnet_lite2': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890,
|
||||
interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
|
||||
),
|
||||
'tf_efficientnet_lite3': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'),
|
||||
'tf_efficientnet_lite4': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922, interpolation='bilinear'),
|
||||
|
||||
'mixnet_s': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
|
||||
'mixnet_m': _cfg(
|
||||
@ -226,6 +277,7 @@ default_cfgs = {
|
||||
'mixnet_xl': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth'),
|
||||
'mixnet_xxl': _cfg(),
|
||||
|
||||
'tf_mixnet_s': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
|
||||
'tf_mixnet_m': _cfg(
|
||||
@ -253,7 +305,7 @@ class EfficientNet(nn.Module):
|
||||
|
||||
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
|
||||
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
||||
output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
||||
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
|
||||
super(EfficientNet, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
@ -264,7 +316,8 @@ class EfficientNet(nn.Module):
|
||||
self._in_chs = in_chans
|
||||
|
||||
# Stem
|
||||
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
||||
if not fix_stem:
|
||||
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
||||
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
|
||||
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
@ -333,7 +386,7 @@ class EfficientNetFeatures(nn.Module):
|
||||
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
|
||||
in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
||||
output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
||||
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
super(EfficientNetFeatures, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
@ -346,7 +399,8 @@ class EfficientNetFeatures(nn.Module):
|
||||
self._in_chs = in_chans
|
||||
|
||||
# Stem
|
||||
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
||||
if not fix_stem:
|
||||
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
||||
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
|
||||
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
@ -707,6 +761,47 @@ def _gen_efficientnet_condconv(
|
||||
return model
|
||||
|
||||
|
||||
def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates an EfficientNet-Lite model.
|
||||
|
||||
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
|
||||
Paper: https://arxiv.org/abs/1905.11946
|
||||
|
||||
EfficientNet params
|
||||
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
|
||||
'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
|
||||
'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
|
||||
'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
|
||||
'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
|
||||
'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer
|
||||
depth_multiplier: multiplier to number of repeats per stage
|
||||
"""
|
||||
arch_def = [
|
||||
['ds_r1_k3_s1_e1_c16'],
|
||||
['ir_r2_k3_s2_e6_c24'],
|
||||
['ir_r2_k5_s2_e6_c40'],
|
||||
['ir_r3_k3_s2_e6_c80'],
|
||||
['ir_r3_k5_s1_e6_c112'],
|
||||
['ir_r4_k5_s2_e6_c192'],
|
||||
['ir_r1_k3_s1_e6_c320'],
|
||||
]
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
|
||||
num_features=1280,
|
||||
stem_size=32,
|
||||
fix_stem=True,
|
||||
channel_multiplier=channel_multiplier,
|
||||
act_layer=nn.ReLU6,
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates a MixNet Small model.
|
||||
|
||||
@ -1032,6 +1127,51 @@ def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_lite0(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite0 """
|
||||
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
|
||||
model = _gen_efficientnet_lite(
|
||||
'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_lite1(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite1 """
|
||||
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
|
||||
model = _gen_efficientnet_lite(
|
||||
'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_lite2(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite2 """
|
||||
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
|
||||
model = _gen_efficientnet_lite(
|
||||
'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_lite3(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite3 """
|
||||
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
|
||||
model = _gen_efficientnet_lite(
|
||||
'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_lite4(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite4 """
|
||||
# NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
|
||||
model = _gen_efficientnet_lite(
|
||||
'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_b0(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B0. Tensorflow compatible variant """
|
||||
@ -1386,6 +1526,61 @@ def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_lite0(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite0 """
|
||||
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_efficientnet_lite(
|
||||
'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_lite1(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite1 """
|
||||
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_efficientnet_lite(
|
||||
'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_lite2(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite2 """
|
||||
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_efficientnet_lite(
|
||||
'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_lite3(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite3 """
|
||||
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_efficientnet_lite(
|
||||
'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_lite4(pretrained=False, **kwargs):
|
||||
""" EfficientNet-Lite4 """
|
||||
# NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_efficientnet_lite(
|
||||
'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixnet_s(pretrained=False, **kwargs):
|
||||
"""Creates a MixNet Small model.
|
||||
|
@ -174,7 +174,7 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c
|
||||
return sa_scaled
|
||||
|
||||
|
||||
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1):
|
||||
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
|
||||
arch_args = []
|
||||
for stack_idx, block_strings in enumerate(arch_def):
|
||||
assert isinstance(block_strings, list)
|
||||
@ -187,7 +187,10 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_
|
||||
ba['num_experts'] *= experts_multiplier
|
||||
stack_args.append(ba)
|
||||
repeats.append(rep)
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
|
||||
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
|
||||
else:
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
|
||||
return arch_args
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user