mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #115 from rwightman/mobilenetv2-experiment
MobileNet-V2 experiments
This commit is contained in:
commit
c99a5abed4
11
README.md
11
README.md
@ -2,6 +2,13 @@
|
|||||||
|
|
||||||
## What's New
|
## What's New
|
||||||
|
|
||||||
|
### April 5, 2020
|
||||||
|
* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
|
||||||
|
* 3.5M param MobileNet-V2 100 @ 73%
|
||||||
|
* 4.5M param MobileNet-V2 110d @ 75%
|
||||||
|
* 6.1M param MobileNet-V2 140 @ 76.5%
|
||||||
|
* 5.8M param MobileNet-V2 120d @ 77.3%
|
||||||
|
|
||||||
### March 18, 2020
|
### March 18, 2020
|
||||||
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
|
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
|
||||||
* Add RandAugment trained ResNeXt-50 32x4d weights with 79.8 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
|
* Add RandAugment trained ResNeXt-50 32x4d weights with 79.8 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
|
||||||
@ -194,10 +201,12 @@ I've leveraged the training scripts in this repository to train a few of the mod
|
|||||||
| seresnext26tn_32x4d | 77.986 (22.014) | 93.746 (6.254) | 16.8M | bicubic | 224 |
|
| seresnext26tn_32x4d | 77.986 (22.014) | 93.746 (6.254) | 16.8M | bicubic | 224 |
|
||||||
| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.29M | bicubic | 224 |
|
| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.29M | bicubic | 224 |
|
||||||
| seresnext26d_32x4d | 77.602 (22.398) | 93.608 (6.392) | 16.8M | bicubic | 224 |
|
| seresnext26d_32x4d | 77.602 (22.398) | 93.608 (6.392) | 16.8M | bicubic | 224 |
|
||||||
|
| mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8M | bicubic | 224 |
|
||||||
| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01M | bicubic | 224 |
|
| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01M | bicubic | 224 |
|
||||||
| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | 224 |
|
| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | 224 |
|
||||||
| skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
|
| skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
|
||||||
| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
|
| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
|
||||||
|
| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1M | bicubic | 224 |
|
||||||
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
|
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
|
||||||
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |
|
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |
|
||||||
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
|
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
|
||||||
@ -205,10 +214,12 @@ I've leveraged the training scripts in this repository to train a few of the mod
|
|||||||
| resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic | 224 |
|
| resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic | 224 |
|
||||||
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear | 224 |
|
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear | 224 |
|
||||||
| resnet34 | 75.110 (24.890) | 92.284 (7.716) | 22M | bilinear | 224 |
|
| resnet34 | 75.110 (24.890) | 92.284 (7.716) | 22M | bilinear | 224 |
|
||||||
|
| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5M | bicubic | 224 |
|
||||||
| seresnet34 | 74.808 (25.192) | 92.124 (7.876) | 22M | bilinear | 224 |
|
| seresnet34 | 74.808 (25.192) | 92.124 (7.876) | 22M | bilinear | 224 |
|
||||||
| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.38M | bicubic | 224 |
|
| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.38M | bicubic | 224 |
|
||||||
| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.42M | bilinear | 224 |
|
| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.42M | bilinear | 224 |
|
||||||
| skresnet18 | 73.038 (26.962) | 91.168 (8.832) | 11.9M | bicubic | 224 |
|
| skresnet18 | 73.038 (26.962) | 91.168 (8.832) | 11.9M | bicubic | 224 |
|
||||||
|
| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5M | bicubic | 224 |
|
||||||
| seresnet18 | 71.742 (28.258) | 90.334 (9.666) | 11.8M | bicubic | 224 |
|
| seresnet18 | 71.742 (28.258) | 90.334 (9.666) | 11.8M | bicubic | 224 |
|
||||||
|
|
||||||
### Ported Weights
|
### Ported Weights
|
||||||
|
@ -60,7 +60,15 @@ default_cfgs = {
|
|||||||
'semnasnet_140': _cfg(url=''),
|
'semnasnet_140': _cfg(url=''),
|
||||||
'mnasnet_small': _cfg(url=''),
|
'mnasnet_small': _cfg(url=''),
|
||||||
|
|
||||||
'mobilenetv2_100': _cfg(url=''),
|
'mobilenetv2_100': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth'),
|
||||||
|
'mobilenetv2_110d': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth'),
|
||||||
|
'mobilenetv2_120d': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth'),
|
||||||
|
'mobilenetv2_140': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth'),
|
||||||
|
|
||||||
'fbnetc_100': _cfg(
|
'fbnetc_100': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
|
||||||
interpolation='bilinear'),
|
interpolation='bilinear'),
|
||||||
@ -318,6 +326,7 @@ class EfficientNet(nn.Module):
|
|||||||
# Stem
|
# Stem
|
||||||
if not fix_stem:
|
if not fix_stem:
|
||||||
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
||||||
|
print(stem_size)
|
||||||
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
|
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
|
||||||
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
||||||
self.act1 = act_layer(inplace=True)
|
self.act1 = act_layer(inplace=True)
|
||||||
@ -565,7 +574,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_mobilenet_v2(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
def _gen_mobilenet_v2(
|
||||||
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
|
||||||
""" Generate MobileNet-V2 network
|
""" Generate MobileNet-V2 network
|
||||||
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
||||||
Paper: https://arxiv.org/abs/1801.04381
|
Paper: https://arxiv.org/abs/1801.04381
|
||||||
@ -580,8 +590,10 @@ def _gen_mobilenet_v2(variant, channel_multiplier=1.0, pretrained=False, **kwarg
|
|||||||
['ir_r1_k3_s1_e6_c320'],
|
['ir_r1_k3_s1_e6_c320'],
|
||||||
]
|
]
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def),
|
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
|
||||||
|
num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None),
|
||||||
stem_size=32,
|
stem_size=32,
|
||||||
|
fix_stem=fix_stem_head,
|
||||||
channel_multiplier=channel_multiplier,
|
channel_multiplier=channel_multiplier,
|
||||||
norm_kwargs=resolve_bn_args(kwargs),
|
norm_kwargs=resolve_bn_args(kwargs),
|
||||||
act_layer=nn.ReLU6,
|
act_layer=nn.ReLU6,
|
||||||
@ -945,11 +957,34 @@ def mnasnet_small(pretrained=False, **kwargs):
|
|||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def mobilenetv2_100(pretrained=False, **kwargs):
|
def mobilenetv2_100(pretrained=False, **kwargs):
|
||||||
""" MobileNet V2 """
|
""" MobileNet V2 w/ 1.0 channel multiplier """
|
||||||
model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs)
|
model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mobilenetv2_140(pretrained=False, **kwargs):
|
||||||
|
""" MobileNet V2 w/ 1.4 channel multiplier """
|
||||||
|
model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mobilenetv2_110d(pretrained=False, **kwargs):
|
||||||
|
""" MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers"""
|
||||||
|
model = _gen_mobilenet_v2(
|
||||||
|
'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mobilenetv2_120d(pretrained=False, **kwargs):
|
||||||
|
""" MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """
|
||||||
|
model = _gen_mobilenet_v2(
|
||||||
|
'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def fbnetc_100(pretrained=False, **kwargs):
|
def fbnetc_100(pretrained=False, **kwargs):
|
||||||
""" FBNet-C """
|
""" FBNet-C """
|
||||||
|
Loading…
x
Reference in New Issue
Block a user