mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Replace fp16 with amp support for validate.py script
This commit is contained in:
parent
e6f24e5578
commit
02a30411ad
19
validate.py
19
validate.py
@ -18,6 +18,12 @@ import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
from collections import OrderedDict
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
has_apex = True
|
||||
except ImportError:
|
||||
has_apex = False
|
||||
|
||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
||||
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
|
||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
||||
@ -61,8 +67,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
||||
help='disable fast prefetcher')
|
||||
parser.add_argument('--pin-mem', action='store_true', default=False,
|
||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
||||
parser.add_argument('--fp16', action='store_true', default=False,
|
||||
help='Use half precision (fp16)')
|
||||
parser.add_argument('--amp', action='store_true', default=False,
|
||||
help='Use AMP mixed precision')
|
||||
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
|
||||
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
|
||||
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
||||
@ -98,13 +104,13 @@ def validate(args):
|
||||
torch.jit.optimized_execution(True)
|
||||
model = torch.jit.script(model)
|
||||
|
||||
if args.num_gpu > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||
if args.amp:
|
||||
model = amp.initialize(model.cuda(), opt_level='O1')
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
||||
if args.fp16:
|
||||
model = model.half()
|
||||
if args.num_gpu > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
|
||||
|
||||
criterion = nn.CrossEntropyLoss().cuda()
|
||||
|
||||
@ -127,7 +133,6 @@ def validate(args):
|
||||
num_workers=args.workers,
|
||||
crop_pct=crop_pct,
|
||||
pin_memory=args.pin_mem,
|
||||
fp16=args.fp16,
|
||||
tf_preprocessing=args.tf_preprocessing)
|
||||
|
||||
batch_time = AverageMeter()
|
||||
|
Loading…
x
Reference in New Issue
Block a user