Mirror of https://github.com/huggingface/pytorch-image-models.git
Default to native PyTorch AMP instead of APEX amp. Too many APEX issues cropping up lately.
This commit is contained in:
parent b4e216e377
commit 0356e773f5
@@ -177,7 +177,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
     if cfg is None or 'url' not in cfg or not cfg['url']:
-        _logger.warning("Pretrained model URL does not exist, using random initialization.")
+        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
 
     state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
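For context, a minimal standalone sketch of the helper logic around this hunk, assuming torch.hub's cached downloader; load_pretrained_sketch is an illustrative name, and the real helper additionally applies filter_fn and input/classifier adaptation before loading:

import logging
import torch.nn as nn
from torch.hub import load_state_dict_from_url

_logger = logging.getLogger(__name__)

def load_pretrained_sketch(model: nn.Module, cfg: dict = None, progress: bool = False):
    # Fall back to the config attached to the model, as the real helper does.
    if cfg is None:
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        # The branch whose message this commit rewords: nothing to download.
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    # Happy path: download once (cached locally) and load the weights on CPU.
    state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
    model.load_state_dict(state_dict)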
train.py (8 changed lines)
@@ -310,11 +310,11 @@ def main():
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     if args.amp:
-        # for backwards compat, `--amp` arg tries apex before native amp
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
+        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
+        if has_native_amp:
             args.native_amp = True
+        elif has_apex:
+            args.apex_amp = True
     if args.apex_amp and has_apex:
         use_amp = 'apex'
     elif args.native_amp and has_native_amp:
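Downstream of this hunk, the resolved use_amp value selects the training setup. A minimal sketch of how the native path is typically consumed, assuming a standard model/criterion/optimizer; train_step and the surrounding names are illustrative, not lines from train.py:

import torch
from contextlib import suppress

use_amp = 'native'  # illustrative result of the resolution above: 'apex', 'native', or None

amp_autocast = suppress  # no-op context manager when AMP is disabled
loss_scaler = None
if use_amp == 'native':
    amp_autocast = torch.cuda.amp.autocast
    loss_scaler = torch.cuda.amp.GradScaler()  # scales losses to avoid fp16 gradient underflow

def train_step(model, criterion, optimizer, input, target):
    optimizer.zero_grad()
    with amp_autocast():
        output = model(input)
        loss = criterion(output, target)
    if loss_scaler is not None:
        loss_scaler.scale(loss).backward()
        loss_scaler.step(optimizer)  # unscales gradients, skips the step on inf/nan
        loss_scaler.update()
    else:
        loss.backward()
        optimizer.step()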
validate.py (13 changed lines)
@@ -116,15 +116,20 @@ def validate(args):
     args.prefetcher = not args.no_prefetcher
     amp_autocast = suppress  # do nothing
     if args.amp:
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
+        if has_native_amp:
             args.native_amp = True
+        elif has_apex:
+            args.apex_amp = True
         else:
-            _logger.warning("Neither APEX or Native Torch AMP is available, using FP32.")
+            _logger.warning("Neither APEX or Native Torch AMP is available.")
+    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
     if args.native_amp:
         amp_autocast = torch.cuda.amp.autocast
+        _logger.info('Validating in mixed precision with native PyTorch AMP.')
+    elif args.apex_amp:
+        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
+    else:
+        _logger.info('Validating in float32. AMP not enabled.')
 
     if args.legacy_jit:
         set_jit_legacy()
 
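The amp_autocast callable resolved above is then used as a context manager around the forward pass. A minimal sketch of the consuming loop, where run_validation, loader, and criterion are stand-ins rather than lines from this commit:

import torch
from contextlib import suppress

def run_validation(model, loader, criterion, amp_autocast=suppress):
    # amp_autocast is contextlib.suppress (fp32 no-op) or torch.cuda.amp.autocast.
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for input, target in loader:
            with amp_autocast():
                output = model(input)
                loss = criterion(output, target)
            total_loss += loss.item() * input.size(0)
            total_correct += (output.argmax(dim=1) == target).sum().item()
            total_seen += input.size(0)
    return total_loss / total_seen, total_correct / total_seen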