mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Replace fp16 with amp support for validate.py script
This commit is contained in:
parent
e6f24e5578
commit
02a30411ad
19
validate.py
19
validate.py
@ -18,6 +18,12 @@ import torch.nn as nn
|
|||||||
import torch.nn.parallel
|
import torch.nn.parallel
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
try:
|
||||||
|
from apex import amp
|
||||||
|
has_apex = True
|
||||||
|
except ImportError:
|
||||||
|
has_apex = False
|
||||||
|
|
||||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
||||||
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
|
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
|
||||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
||||||
@ -61,8 +67,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|||||||
help='disable fast prefetcher')
|
help='disable fast prefetcher')
|
||||||
parser.add_argument('--pin-mem', action='store_true', default=False,
|
parser.add_argument('--pin-mem', action='store_true', default=False,
|
||||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
||||||
parser.add_argument('--fp16', action='store_true', default=False,
|
parser.add_argument('--amp', action='store_true', default=False,
|
||||||
help='Use half precision (fp16)')
|
help='Use AMP mixed precision')
|
||||||
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
|
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
|
||||||
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
|
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
|
||||||
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
||||||
@ -98,13 +104,13 @@ def validate(args):
|
|||||||
torch.jit.optimized_execution(True)
|
torch.jit.optimized_execution(True)
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
|
|
||||||
if args.num_gpu > 1:
|
if args.amp:
|
||||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
model = amp.initialize(model.cuda(), opt_level='O1')
|
||||||
else:
|
else:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
|
||||||
if args.fp16:
|
if args.num_gpu > 1:
|
||||||
model = model.half()
|
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
|
||||||
|
|
||||||
criterion = nn.CrossEntropyLoss().cuda()
|
criterion = nn.CrossEntropyLoss().cuda()
|
||||||
|
|
||||||
@ -127,7 +133,6 @@ def validate(args):
|
|||||||
num_workers=args.workers,
|
num_workers=args.workers,
|
||||||
crop_pct=crop_pct,
|
crop_pct=crop_pct,
|
||||||
pin_memory=args.pin_mem,
|
pin_memory=args.pin_mem,
|
||||||
fp16=args.fp16,
|
|
||||||
tf_preprocessing=args.tf_preprocessing)
|
tf_preprocessing=args.tf_preprocessing)
|
||||||
|
|
||||||
batch_time = AverageMeter()
|
batch_time = AverageMeter()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user