mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
commit
dc0630fd94
@ -33,6 +33,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
|
||||
* DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
|
||||
* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the efficient models that utilize similar DepthwiseSeparable and InvertedResidual blocks
|
||||
* EfficientNet (B0-B7) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
|
||||
* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) --validated w/ TF weights
|
||||
* MixNet (https://arxiv.org/abs/1907.09595) -- validated, compat with TF weights
|
||||
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
|
||||
* MobileNet-V1 (https://arxiv.org/abs/1704.04861)
|
||||
@ -71,6 +72,7 @@ I've leveraged the training scripts in this repository to train a few of the mod
|
||||
|
||||
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size |
|
||||
|---|---|---|---|---|---|
|
||||
| mixnet_xl | 80.120 (19.880) | 95.022 (4.978) | 11.90M | bicubic | 224 |
|
||||
| efficientnet_b2 | 79.760 (20.240) | 94.714 (5.286) | 9.11M | bicubic | 260 |
|
||||
| resnext50d_32x4d | 79.674 (20.326) | 94.868 (5.132) | 25.1M | bicubic | 224 |
|
||||
| mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33M | bicubic | 224 |
|
||||
@ -111,6 +113,8 @@ I've leveraged the training scripts in this repository to train a few of the mod
|
||||
| gluon_seresnext101_32x4d | 80.902 (19.098) | 95.294 (4.706) | 48.96 | bicubic | 224 | |
|
||||
| gluon_seresnext101_64x4d | 80.890 (19.110) | 95.304 (4.696) | 88.23 | bicubic | 224 | |
|
||||
| gluon_resnext101_64x4d | 80.602 (19.398) | 94.994 (5.006) | 83.46 | bicubic | 224 | |
|
||||
| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu) |
|
||||
| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu) |
|
||||
| gluon_resnet152_v1d | 80.470 (19.530) | 95.206 (4.794) | 60.21 | bicubic | 224 | |
|
||||
| gluon_resnet101_v1d | 80.424 (19.576) | 95.020 (4.980) | 44.57 | bicubic | 224 | |
|
||||
| gluon_resnext101_32x4d | 80.334 (19.666) | 94.926 (5.074) | 44.18 | bicubic | 224 | |
|
||||
@ -126,15 +130,19 @@ I've leveraged the training scripts in this repository to train a few of the mod
|
||||
| gluon_resnet101_v1b | 79.304 (20.696) | 94.524 (5.476) | 44.55 | bicubic | 224 | |
|
||||
| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
|
||||
| gluon_resnet50_v1d | 79.074 (20.926) | 94.476 (5.524) | 25.58 | bicubic | 224 | |
|
||||
| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu) |
|
||||
| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
|
||||
| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
|
||||
| gluon_inception_v3 | 78.804 (21.196) | 94.380 (5.620) | 27.16M | bicubic | 299 | [MxNet Gluon](https://gluon-cv.mxnet.io/model_zoo/classification.html) |
|
||||
| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
|
||||
| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu) |
|
||||
| gluon_resnet50_v1s | 78.712 (21.288) | 94.242 (5.758) | 25.68 | bicubic | 224 | |
|
||||
| gluon_resnet50_v1c | 78.010 (21.990) | 93.988 (6.012) | 25.58 | bicubic | 224 | |
|
||||
| tf_inception_v3 | 77.856 (22.144) | 93.644 (6.356) | 27.16M | bicubic | 299 | [Tensorflow Slim](https://github.com/tensorflow/models/tree/master/research/slim) |
|
||||
| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu) |
|
||||
| gluon_resnet50_v1b | 77.578 (22.422) | 93.718 (6.282) | 25.56 | bicubic | 224 | |
|
||||
| adv_inception_v3 | 77.576 (22.424) | 93.724 (6.276) | 27.16M | bicubic | 299 | [Tensorflow Adv models](https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models) |
|
||||
| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu) |
|
||||
| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
|
||||
| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
|
||||
| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet) |
|
||||
|
@ -138,6 +138,9 @@ default_cfgs = {
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
|
||||
'mixnet_l': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'),
|
||||
'mixnet_xl': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl-ac5fbe8d.pth'),
|
||||
'mixnet_xxl': _cfg(),
|
||||
'tf_mixnet_s': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
|
||||
'tf_mixnet_m': _cfg(
|
||||
@ -312,21 +315,59 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
|
||||
else:
|
||||
assert False, 'Unknown block type (%s)' % block_type
|
||||
|
||||
# return a list of block args expanded by num_repeat and
|
||||
# scaled by depth_multiplier
|
||||
num_repeat = int(math.ceil(num_repeat * depth_multiplier))
|
||||
return [deepcopy(block_args) for _ in range(num_repeat)]
|
||||
return block_args, num_repeat
|
||||
|
||||
|
||||
def _decode_arch_def(arch_def, depth_multiplier=1.0):
|
||||
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
""" Per-stage depth scaling
|
||||
Scales the block repeats in each stage. This depth scaling impl maintains
|
||||
compatibility with the EfficientNet scaling method, while allowing sensible
|
||||
scaling for other models that may have multiple block arg definitions in each stage.
|
||||
"""
|
||||
|
||||
# We scale the total repeat count for each stage, there may be multiple
|
||||
# block arg defs per stage so we need to sum.
|
||||
num_repeat = sum(repeats)
|
||||
if depth_trunc == 'round':
|
||||
# Truncating to int by rounding allows stages with few repeats to remain
|
||||
# proportionally smaller for longer. This is a good choice when stage definitions
|
||||
# include single repeat stages that we'd prefer to keep that way as long as possible
|
||||
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
|
||||
else:
|
||||
# The default for EfficientNet truncates repeats to int via 'ceil'.
|
||||
# Any multiplier > 1.0 will result in an increased depth for every stage.
|
||||
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
|
||||
|
||||
# Proportionally distribute repeat count scaling to each block definition in the stage.
|
||||
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
|
||||
# The first block makes less sense to repeat in most of the arch definitions.
|
||||
repeats_scaled = []
|
||||
for r in repeats[::-1]:
|
||||
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
|
||||
repeats_scaled.append(rs)
|
||||
num_repeat -= r
|
||||
num_repeat_scaled -= rs
|
||||
repeats_scaled = repeats_scaled[::-1]
|
||||
|
||||
# Apply the calculated scaling to each block arg in the stage
|
||||
sa_scaled = []
|
||||
for ba, rep in zip(stack_args, repeats_scaled):
|
||||
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
|
||||
return sa_scaled
|
||||
|
||||
|
||||
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
arch_args = []
|
||||
for stack_idx, block_strings in enumerate(arch_def):
|
||||
assert isinstance(block_strings, list)
|
||||
stack_args = []
|
||||
repeats = []
|
||||
for block_str in block_strings:
|
||||
assert isinstance(block_str, str)
|
||||
stack_args.extend(_decode_block_str(block_str, depth_multiplier))
|
||||
arch_args.append(stack_args)
|
||||
ba, rep = _decode_block_str(block_str)
|
||||
stack_args.append(ba)
|
||||
repeats.append(rep)
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
|
||||
return arch_args
|
||||
|
||||
|
||||
@ -1261,7 +1302,7 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
def _gen_mixnet_m(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
"""Creates a MixNet Medium-Large model.
|
||||
|
||||
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
|
||||
@ -1283,7 +1324,7 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
# 7x7
|
||||
]
|
||||
model = GenEfficientNet(
|
||||
_decode_arch_def(arch_def),
|
||||
_decode_arch_def(arch_def, depth_multiplier=depth_multiplier, depth_trunc='round'),
|
||||
num_classes=num_classes,
|
||||
stem_size=24,
|
||||
num_features=1536,
|
||||
@ -1876,6 +1917,36 @@ def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Creates a MixNet Extra-Large model.
|
||||
Not a paper spec, experimental def by RW w/ depth scaling.
|
||||
"""
|
||||
default_cfg = default_cfgs['mixnet_xl']
|
||||
#kwargs['drop_connect_rate'] = 0.2
|
||||
model = _gen_mixnet_m(
|
||||
channel_multiplier=1.6, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixnet_xxl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Creates a MixNet Double Extra Large model.
|
||||
Not a paper spec, experimental def by RW w/ depth scaling.
|
||||
"""
|
||||
default_cfg = default_cfgs['mixnet_xxl']
|
||||
# kwargs['drop_connect_rate'] = 0.2
|
||||
model = _gen_mixnet_m(
|
||||
channel_multiplier=2.4, depth_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Creates a MixNet Small model. Tensorflow compatible variant
|
||||
|
Loading…
x
Reference in New Issue
Block a user