Merge remote-tracking branch 'upstream/master'
commit b0d45fd09c

README.md (19 changes)
@@ -2,6 +2,11 @@
 
 ## What's New
 
+### Feb 1/2, 2020
+* Port new EfficientNet-B8 (RandAugment) weights; these differ from the B8 AdvProp weights and use different input normalization.
+* Update results csv files for all models on ImageNet validation and three other test sets
+* Push PyPi package update
+
 ### Jan 31, 2020
 * Update ResNet50 weights with a new 79.038 result from further JSD / AugMix experiments. Full command line for reproduction in training section below.
 
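Since the B8 RandAugment and AdvProp weights expect different input normalization, it is worth checking what a given checkpoint wants before building an eval pipeline. A minimal sketch using the timm package this update publishes (`resolve_data_config` reads the model's default config; exact values depend on the installed version):

```python
# Sketch: inspect the input config bundled with a pretrained model's weights.
# The RandAugment and AdvProp B8 checkpoints report different mean/std here.
import timm
from timm.data import resolve_data_config

model = timm.create_model('tf_efficientnet_b8', pretrained=True)
config = resolve_data_config({}, model=model)  # mean/std, input size, interpolation
print(config['mean'], config['std'], config['input_size'])
```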
@@ -87,9 +92,9 @@ Included models:
 * Original variant from [Cadene](https://github.com/Cadene/pretrained-models.pytorch)
 * MXNet Gluon 'modified aligned' Xception-65 and 71 models from [Gluon ModelZoo](https://github.com/dmlc/gluon-cv/tree/master/gluoncv/model_zoo)
 * PNasNet & NASNet-A (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
-* DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
+* DPN (from [myself](https://github.com/rwightman/pytorch-dpn-pretrained))
   * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
-* EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the efficient models that utilize similar DepthwiseSeparable and InvertedResidual blocks
+* EfficientNet (from my standalone [GenEfficientNet](https://github.com/rwightman/gen-efficientnet-pytorch)) - A generic model that implements many of the efficient models that utilize similar DepthwiseSeparable and InvertedResidual blocks
 * EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665) -- TF weights ported
 * EfficientNet (B0-B7) (https://arxiv.org/abs/1905.11946) -- TF weights ported, B0-B2 finetuned PyTorch
 * EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) -- TF weights ported
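With this many families in one registry, enumerating what is available helps. A small sketch against timm's model registry (wildcard syntax per the timm API; the exact list depends on the installed version):

```python
# Sketch: enumerate registered EfficientNet variants by wildcard.
import timm

print(timm.list_models('*efficientnet_b8*'))  # e.g. the B8 and AdvProp (_ap) variants
print(len(timm.list_models()))                # total architectures in this install
```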
@@ -136,8 +141,8 @@ I've leveraged the training scripts in this repository to train a few of the mod
 
 |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size |
 |---|---|---|---|---|---|
-| efficientnet_b3a | 81.874 (18.126) | 95.840 (4.160) | 9.11M | bicubic | 320 (1.0 crop) |
-| efficientnet_b3 | 81.498 (18.502) | 95.718 (4.282) | 9.11M | bicubic | 300 |
+| efficientnet_b3a | 81.874 (18.126) | 95.840 (4.160) | 12.23M | bicubic | 320 (1.0 crop) |
+| efficientnet_b3 | 81.498 (18.502) | 95.718 (4.282) | 12.23M | bicubic | 300 |
 | efficientnet_b2a | 80.608 (19.392) | 95.310 (4.690) | 9.11M | bicubic | 288 (1.0 crop) |
 | mixnet_xl | 80.478 (19.522) | 94.932 (5.068) | 11.90M | bicubic | 224 |
 | efficientnet_b2 | 80.402 (19.598) | 95.076 (4.924) | 9.11M | bicubic | 260 |
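The b3 rows above were corrected from 9.11M to 12.23M parameters. As a sketch, a Param # entry can be reproduced by counting tensor elements (model name from the table; rounding assumed to match the table's convention):

```python
# Sketch: recompute the Param # column for one table row.
import timm

model = timm.create_model('efficientnet_b3')
n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params / 1e6:.2f}M')  # expected ~12.23M, matching the corrected row
```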
@@ -170,6 +175,8 @@ For the models below, the model code and weight porting from Tensorflow or MXNet
 | Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size |
 |---|---|---|---|---|---|
+| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 |
 | tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 |
 | tf_efficientnet_b8 | 85.37 (14.63) | 97.39 (2.61) | 87.4 | bicubic | 672 |
+| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 |
 | tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 |
 | tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 |
@@ -309,13 +316,13 @@ Trained on two older 1080Ti cards, this took a while. Only slightly, non statist
 
 All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x and 3.7.x. Little to no care has been taken to be Python 2.x friendly and I don't plan to support it. If you run into any challenges running on Windows, or other OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment.
 
-PyTorch versions 1.2 and 1.3.1 have been tested with this code.
+PyTorch versions 1.2, 1.3.1, and 1.4 have been tested with this code.
 
 I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda:
 ```
 conda create -n torch-env
 conda activate torch-env
-conda install -c pytorch pytorch torchvision cudatoolkit=10
+conda install -c pytorch pytorch torchvision cudatoolkit=10.1
 conda install pyyaml
 ```
 
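A quick sanity check, after the conda steps above, that the environment matches the tested versions (1.2, 1.3.1, 1.4) and the pinned CUDA 10.1 toolkit:

```python
# Sketch: verify the environment created above.
import torch

print(torch.__version__)          # expect 1.2.x, 1.3.1, or 1.4.x per the note above
print(torch.version.cuda)         # expect '10.1' for the cudatoolkit pinned above
print(torch.cuda.is_available())  # True if the GPU driver matches the toolkit
```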
@@ -69,7 +69,7 @@ def transforms_imagenet_train(
     else:
         # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
         color_jitter = (float(color_jitter),) * 3
-        secondary_tfl += [transforms.ColorJitter(*color_jitter)]
+    secondary_tfl += [transforms.ColorJitter(*color_jitter)]
 
     final_tfl = []
     if use_prefetcher:
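The two `secondary_tfl` lines in this hunk match textually, so the change is the indentation: the append moves out of the `else` branch, applying ColorJitter for tuple/list inputs as well as scalars. A self-contained sketch of the resulting control flow; only the hunk lines are verbatim, the tuple-handling guard is assumed context:

```python
# Sketch of the fixed control flow; the isinstance branch is assumed context,
# not shown in the hunk above.
from torchvision import transforms

def build_color_jitter(color_jitter=0.4):
    secondary_tfl = []
    if isinstance(color_jitter, (list, tuple)):
        # assumed: tuple/list passes through as (brightness, contrast, saturation[, hue])
        assert len(color_jitter) in (3, 4)
    else:
        # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
        color_jitter = (float(color_jitter),) * 3
    # post-fix position: applied for both branches instead of only the scalar case
    secondary_tfl += [transforms.ColorJitter(*color_jitter)]
    return secondary_tfl
```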
@@ -1 +1 @@
-__version__ = '0.1.14'
+__version__ = '0.1.16'
train.py (4 changes)
@@ -79,6 +79,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                     help='input batch size for training (default: 32)')
+parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
+                    help='ratio of validation batch size to training batch size (default: 1)')
 parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                     help='Dropout rate (default: 0.)')
 parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
@@ -388,7 +390,7 @@ def main():
     loader_eval = create_loader(
         dataset_eval,
         input_size=data_config['input_size'],
-        batch_size=4 * args.batch_size,
+        batch_size=args.validation_batch_size_multiplier * args.batch_size,
         is_training=False,
         use_prefetcher=args.prefetcher,
         interpolation=data_config['interpolation'],
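Taken together, the two train.py hunks replace the hard-coded 4x validation batch size with the new `-vb` multiplier. A hypothetical invocation (dataset path and model are placeholders, not from this diff); here the eval loader runs at 2 * 64 = 128 per batch:

```
python train.py /path/to/imagenet --model efficientnet_b2 -b 64 -vb 2
```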