mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add new auto-augmentation TensorFlow EfficientNet weights, including the B6 and B7 models. Validation scores are still pending, but results look good so far.
This commit is contained in:
parent
857f33015a
commit
77e2e0c4e3
@ -84,24 +84,34 @@ default_cfgs = {
|
||||
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
'efficientnet_b5': _cfg(
|
||||
url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
|
||||
'efficientnet_b6': _cfg(
|
||||
url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
|
||||
'efficientnet_b7': _cfg(
|
||||
url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
|
||||
'tf_efficientnet_b0': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
||||
input_size=(3, 224, 224)),
|
||||
'tf_efficientnet_b1': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
'tf_efficientnet_b2': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
|
||||
'tf_efficientnet_b3': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
|
||||
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
'tf_efficientnet_b4': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
|
||||
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
'tf_efficientnet_b5': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_aa-99018a74.pth',
|
||||
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
|
||||
'tf_efficientnet_b6': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
|
||||
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
|
||||
'tf_efficientnet_b7': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth',
|
||||
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
|
||||
'mixnet_s': _cfg(url=''),
|
||||
'mixnet_m': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
|
||||
@ -763,8 +773,6 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
|
||||
num_classes=num_classes,
|
||||
stem_size=32,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
**kwargs
|
||||
)
|
||||
@ -801,8 +809,6 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
|
||||
num_classes=num_classes,
|
||||
stem_size=32,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
**kwargs
|
||||
)
|
||||
@ -832,8 +838,6 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
|
||||
num_classes=num_classes,
|
||||
stem_size=8,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
**kwargs
|
||||
)
|
||||
@ -858,8 +862,6 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
|
||||
stem_size=32,
|
||||
num_features=1024,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
act_fn=F.relu6,
|
||||
head_conv='none',
|
||||
@ -887,8 +889,6 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
|
||||
num_classes=num_classes,
|
||||
stem_size=32,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
act_fn=F.relu6,
|
||||
**kwargs
|
||||
@ -926,8 +926,6 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
|
||||
num_classes=num_classes,
|
||||
stem_size=16,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
act_fn=hard_swish,
|
||||
se_gate_fn=hard_sigmoid,
|
||||
@ -961,8 +959,6 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
|
||||
stem_size=32,
|
||||
num_features=1280, # no idea what this is? try mobile/mnasnet default?
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
**kwargs
|
||||
)
|
||||
@ -992,8 +988,6 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
|
||||
stem_size=32,
|
||||
num_features=1280, # no idea what this is? try mobile/mnasnet default?
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
**kwargs
|
||||
)
|
||||
@ -1024,8 +1018,6 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
|
||||
stem_size=16,
|
||||
num_features=1984, # paper suggests this, but is not 100% clear
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
**kwargs
|
||||
)
|
||||
@ -1061,8 +1053,6 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
|
||||
num_classes=num_classes,
|
||||
stem_size=32,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
**kwargs
|
||||
)
|
||||
@ -1107,8 +1097,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
|
||||
num_classes=num_classes,
|
||||
stem_size=32,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
num_features=num_features,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
act_fn=swish,
|
||||
@ -1144,8 +1132,6 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
stem_size=16,
|
||||
num_features=1536,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
act_fn=F.relu,
|
||||
**kwargs
|
||||
@ -1180,8 +1166,6 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
stem_size=24,
|
||||
num_features=1536,
|
||||
channel_multiplier=channel_multiplier,
|
||||
channel_divisor=8,
|
||||
channel_min=None,
|
||||
bn_args=_resolve_bn_args(kwargs),
|
||||
act_fn=F.relu,
|
||||
**kwargs
|
||||
@ -1495,6 +1479,37 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@register_model
def efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Create an EfficientNet-B6 model.

    The network is generated with a channel multiplier of 1.8 and a depth
    multiplier of 2.6. The 'efficientnet_b6' entry of ``default_cfgs`` is
    attached to the returned model as ``default_cfg``.

    Args:
        pretrained: when true, ``load_pretrained`` is called on the new model
            with the default config.
        num_classes: forwarded to the generator and to ``load_pretrained``.
        in_chans: forwarded to the generator and to ``load_pretrained``.
        **kwargs: passed through unchanged to ``_gen_efficientnet``.

    Returns:
        The constructed model.
    """
    # NOTE for train, drop_rate should be 0.5
    # kwargs['drop_connect_rate'] = 0.2  # set when training, TODO add as cmd arg
    cfg = default_cfgs['efficientnet_b6']
    net = _gen_efficientnet(
        channel_multiplier=1.8,
        depth_multiplier=2.6,
        num_classes=num_classes,
        in_chans=in_chans,
        **kwargs)
    net.default_cfg = cfg
    if pretrained:
        load_pretrained(net, cfg, num_classes, in_chans)
    return net
|
||||
|
||||
|
||||
@register_model
def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Create an EfficientNet-B7 model.

    The network is generated with a channel multiplier of 2.0 and a depth
    multiplier of 3.1. The 'efficientnet_b7' entry of ``default_cfgs`` is
    attached to the returned model as ``default_cfg``.

    Args:
        pretrained: when true, ``load_pretrained`` is called on the new model
            with the default config.
        num_classes: forwarded to the generator and to ``load_pretrained``.
        in_chans: forwarded to the generator and to ``load_pretrained``.
        **kwargs: passed through unchanged to ``_gen_efficientnet``.

    Returns:
        The constructed model.
    """
    # NOTE for train, drop_rate should be 0.5
    # kwargs['drop_connect_rate'] = 0.2  # set when training, TODO add as cmd arg
    cfg = default_cfgs['efficientnet_b7']
    net = _gen_efficientnet(
        channel_multiplier=2.0,
        depth_multiplier=3.1,
        num_classes=num_classes,
        in_chans=in_chans,
        **kwargs)
    net.default_cfg = cfg
    if pretrained:
        load_pretrained(net, cfg, num_classes, in_chans)
    return net
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" EfficientNet-B0. Tensorflow compatible variant """
|
||||
@ -1585,6 +1600,38 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
def tf_efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Create an EfficientNet-B6 model, TensorFlow compatible variant.

    Same multipliers as the plain B6 entry point (channel 1.8, depth 2.6),
    but ``bn_eps`` is forced to ``_BN_EPS_TF_DEFAULT`` and ``pad_type`` to
    'same' before the generator is called; any caller-supplied values for
    those two keys are overwritten.

    Args:
        pretrained: when true, ``load_pretrained`` is called on the new model
            with the 'tf_efficientnet_b6' default config.
        num_classes: forwarded to the generator and to ``load_pretrained``.
        in_chans: forwarded to the generator and to ``load_pretrained``.
        **kwargs: extra generator arguments (``bn_eps`` / ``pad_type`` are
            replaced as described above).

    Returns:
        The constructed model, with ``default_cfg`` attached.
    """
    # NOTE for train, drop_rate should be 0.5
    cfg = default_cfgs['tf_efficientnet_b6']
    # TF-compatibility overrides; these always win over caller-provided values.
    kwargs.update(bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    net = _gen_efficientnet(
        channel_multiplier=1.8,
        depth_multiplier=2.6,
        num_classes=num_classes,
        in_chans=in_chans,
        **kwargs)
    net.default_cfg = cfg
    if pretrained:
        load_pretrained(net, cfg, num_classes, in_chans)
    return net
|
||||
|
||||
|
||||
@register_model
def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Create an EfficientNet-B7 model, TensorFlow compatible variant.

    Same multipliers as the plain B7 entry point (channel 2.0, depth 3.1),
    but ``bn_eps`` is forced to ``_BN_EPS_TF_DEFAULT`` and ``pad_type`` to
    'same' before the generator is called; any caller-supplied values for
    those two keys are overwritten.

    Args:
        pretrained: when true, ``load_pretrained`` is called on the new model
            with the 'tf_efficientnet_b7' default config.
        num_classes: forwarded to the generator and to ``load_pretrained``.
        in_chans: forwarded to the generator and to ``load_pretrained``.
        **kwargs: extra generator arguments (``bn_eps`` / ``pad_type`` are
            replaced as described above).

    Returns:
        The constructed model, with ``default_cfg`` attached.
    """
    # NOTE for train, drop_rate should be 0.5
    cfg = default_cfgs['tf_efficientnet_b7']
    # TF-compatibility overrides; these always win over caller-provided values.
    kwargs.update(bn_eps=_BN_EPS_TF_DEFAULT, pad_type='same')
    net = _gen_efficientnet(
        channel_multiplier=2.0,
        depth_multiplier=3.1,
        num_classes=num_classes,
        in_chans=in_chans,
        **kwargs)
    net.default_cfg = cfg
    if pretrained:
        load_pretrained(net, cfg, num_classes, in_chans)
    return net
|
||||
|
||||
|
||||
@register_model
|
||||
def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Creates a MixNet Small model.
|
||||
|
Loading…
x
Reference in New Issue
Block a user