Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
Remove deprecated bn-tf train arg and create_model handler. Add evos/evob models back into fx test filter until norm_norm_norm branch merged.
This commit is contained in:
parent
b9a715c86a
commit
5ccf682a8f
@@ -3,12 +3,12 @@

 ## EfficientNet-B2 with RandAugment - 80.4 top-1, 95.1 top-5

 These params are for dual Titan RTX cards with NVIDIA Apex installed:

-`./distributed_train.sh 2 /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016`
+`./distributed_train.sh 2 /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016`

 ## MixNet-XL with RandAugment - 80.5 top-1, 94.9 top-5

 These params are for dual Titan RTX cards with NVIDIA Apex installed:

-`./distributed_train.sh 2 /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce`
+`./distributed_train.sh 2 /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce`

 ## SE-ResNeXt-26-D and SE-ResNeXt-26-T

 These hparams (or similar) work well for a wide range of ResNet architectures. It is generally a good idea to increase the epoch count as the model size increases, i.e. approx. 180-200 for ResNe(X)t50 and 220+ for larger models. Increase batch size and LR proportionally for better GPUs or with AMP enabled. These params were for 2 1080Ti cards:
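Editorial note: every command-line change in this docs file is the same rename, `--drop-connect 0.2` becoming `--drop-path 0.2`; the rate itself is unchanged. The flag configures stochastic depth ("drop path") regularization. A minimal sketch of the standard per-sample formulation (an illustration, not timm's implementation):

```python
import torch

def drop_path(x: torch.Tensor, drop_prob: float = 0.2, training: bool = True) -> torch.Tensor:
    """Stochastic depth: randomly drop entire residual branches per sample."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
    # Rescale survivors so the expected activation is unchanged.
    return x * mask / keep_prob
```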
@@ -21,7 +21,7 @@ The training of this model started with the same command line as EfficientNet-B2

 ## EfficientNet-B0 with RandAugment - 77.7 top-1, 95.3 top-5

 [Michael Klachko](https://github.com/michaelklachko) achieved these results with the command line for B2 adapted for a larger batch size, with the recommended B0 dropout rate of 0.2.

-`./distributed_train.sh 2 /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048`
+`./distributed_train.sh 2 /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048`
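Editorial note: the B0 learning rate is consistent with the linear scaling rule mentioned in the SE-ResNeXt section above ("increase batch size and LR proportionally"). A worked check against the numbers in this file:

```python
def scale_lr(base_lr: float, base_global_batch: int, new_global_batch: int) -> float:
    # Linear scaling rule: keep lr / global_batch_size constant.
    return base_lr * new_global_batch / base_global_batch

# B2 recipe: 2 GPUs x 128 per GPU = 256 global batch at lr .016
# B0 recipe: 2 GPUs x 384 per GPU = 768 global batch -> lr .048
assert abs(scale_lr(0.016, 2 * 128, 2 * 384) - 0.048) < 1e-9
```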

 ## ResNet50 with JSD loss and RandAugment (clean + 2x RA augs) - 79.04 top-1, 94.39 top-5

@@ -32,11 +32,11 @@ Trained on two older 1080Ti cards, this took a while. Only slightly, non statist

 ## EfficientNet-ES (EdgeTPU-Small) with RandAugment - 78.066 top-1, 93.926 top-5

 Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model EMA was not used; the final checkpoint is the average of the 8 best checkpoints during training.

-`./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064`
+`./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064`
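Editorial note: "average of the 8 best checkpoints" is weight-space averaging; the repository ships an `avg_checkpoints.py` script for this. A minimal sketch of the idea, assuming timm-style checkpoint files with an optional `'state_dict'` entry:

```python
import torch

def average_checkpoints(paths):
    # Sum parameter tensors across checkpoints, then divide; a simple
    # stand-in for EMA when it was not enabled during training.
    avg = None
    for path in paths:
        ckpt = torch.load(path, map_location='cpu')
        sd = ckpt.get('state_dict', ckpt)
        if avg is None:
            avg = {k: v.float().clone() for k, v in sd.items()}
        else:
            for k, v in sd.items():
                avg[k] += v.float()
    return {k: v / len(paths) for k, v in avg.items()}
```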

 ## MobileNetV3-Large-100 - 75.766 top-1, 92.542 top-5

-`./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9`
+`./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9`
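Editorial note: most of these recipes enable `--model-ema --model-ema-decay 0.9999`, so evaluation uses an exponential moving average of the weights, `ema = decay * ema + (1 - decay) * w` after each optimizer step. A minimal sketch of that update (the core idea only, not timm's `ModelEma` helper):

```python
import copy
import torch

class SimpleEma:
    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.module = copy.deepcopy(model).eval()  # frozen averaged copy
        self.decay = decay

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        for e, m in zip(self.module.state_dict().values(), model.state_dict().values()):
            if e.is_floating_point():
                e.mul_(self.decay).add_(m, alpha=1.0 - self.decay)
            else:
                e.copy_(m)  # integer buffers (e.g. num_batches_tracked) are copied
```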

 ## ResNeXt-50 32x4d w/ RandAugment - 79.762 top-1, 94.60 top-5

@@ -427,6 +427,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
     'deit_*_distilled_patch16_224',
     'levit*',
     'pit_*_distilled_224',
+    '*evob', '*evos',  # until norm_norm_norm branch is merged
 ] + EXCLUDE_FX_FILTERS
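Editorial note: these exclusion entries are shell-style wildcards matched against model names, so `'*evob'` and `'*evos'` catch the EvoNorm-B/EvoNorm-S model variants until the norm_norm_norm branch lands. A small illustration of the matching (`EXCLUDE_FX_FILTERS` stubbed, model names hypothetical):

```python
import fnmatch

EXCLUDE_FX_FILTERS = []  # stub for the list the real test module builds
filters = ['deit_*_distilled_patch16_224', 'levit*', 'pit_*_distilled_224',
           '*evob', '*evos'] + EXCLUDE_FX_FILTERS

def excluded(model_name: str) -> bool:
    return any(fnmatch.fnmatch(model_name, f) for f in filters)

assert excluded('resnet51q_evob')  # hypothetical EvoNorm-B variant name
assert not excluded('resnet50')
```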
@@ -40,7 +40,7 @@ def get_bn_args_tf():


 def resolve_bn_args(kwargs):
-    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
+    bn_args = {}
     bn_momentum = kwargs.pop('bn_momentum', None)
     if bn_momentum is not None:
         bn_args['momentum'] = bn_momentum
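Editorial note: after this change `resolve_bn_args` no longer seeds TF BatchNorm defaults; it only collects explicit overrides. A hedged usage sketch (assuming the function lives in `timm.models.efficientnet_builder` and handles `bn_eps` below the hunk the same way it handles `bn_momentum`):

```python
from timm.models.efficientnet_builder import resolve_bn_args  # assumed import path

kwargs = dict(bn_momentum=0.01, bn_eps=1e-3, drop_rate=0.2)
bn_args = resolve_bn_args(kwargs)
# bn_args == {'momentum': 0.01, 'eps': 0.001}
# 'bn_momentum'/'bn_eps' are popped from kwargs; unrelated keys survive:
# kwargs == {'drop_rate': 0.2}
```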
@@ -47,13 +47,6 @@ def create_model(
     """
     source_name, model_name = split_model_name(model_name)

-    # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
-    is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
-    if not is_efficientnet:
-        kwargs.pop('bn_tf', None)
-        kwargs.pop('bn_momentum', None)
-        kwargs.pop('bn_eps', None)
-
     # handle backwards compat with drop_connect -> drop_path change
     drop_connect_rate = kwargs.pop('drop_connect_rate', None)
     if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
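Editorial note: the hunk is cut off inside the drop_connect backwards-compat branch. Presumably the body maps the deprecated kwarg onto the new one; a minimal sketch of such a shim (not the verbatim continuation):

```python
import warnings

kwargs = dict(drop_connect_rate=0.2)  # stand-in for create_model's **kwargs

drop_connect_rate = kwargs.pop('drop_connect_rate', None)
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
    warnings.warn("'drop_connect_rate' is deprecated, use 'drop_path_rate'", DeprecationWarning)
    kwargs['drop_path_rate'] = drop_connect_rate
```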
train.py (3 changed lines)
@@ -234,8 +234,6 @@ parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                     help='Drop block rate (default: None)')

 # Batch norm parameters (only works with gen_efficientnet based models currently)
-parser.add_argument('--bn-tf', action='store_true', default=False,
-                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
 parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
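Editorial note: with `--bn-tf` deleted, launch scripts that still pass it will exit with argparse's "unrecognized arguments" error. A minimal repro with a stub parser (not the real train.py one):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--bn-momentum', type=float, default=None)
parser.add_argument('--bn-eps', type=float, default=None)

args = parser.parse_args(['--bn-momentum', '0.01'])
print(args.bn_momentum)  # 0.01

# parser.parse_args(['--bn-tf'])  # SystemExit: error: unrecognized arguments: --bn-tf
```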
@@ -375,7 +373,6 @@ def main():
         drop_path_rate=args.drop_path,
         drop_block_rate=args.drop_block,
         global_pool=args.gp,
-        bn_tf=args.bn_tf,
         bn_momentum=args.bn_momentum,
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
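Editorial note: after the removal, callers pass only the surviving BatchNorm overrides. A hedged sketch of the updated call through the public API (values illustrative; per the train.py comment, the overrides are only honored by gen_efficientnet-based models):

```python
import timm

model = timm.create_model(
    'efficientnet_b0',
    drop_path_rate=0.2,  # replaces the deprecated drop_connect_rate kwarg
    bn_momentum=0.01,    # still-supported override
    bn_eps=1e-3,         # still-supported override
)
```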