Add crop_pct arg to validate, extra fields to csv output, 'all' filters pretrained
parent 949b7a81c4
commit edb425ea48

validate.py (23 changed lines)
@@ -30,6 +30,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')
 parser.add_argument('--img-size', default=None, type=int,
                     metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--crop-pct', default=None, type=float,
+                    metavar='N', help='Input image center crop pct')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
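Note: --crop-pct defaults to None, so an explicit value overrides the model's default crop fraction, and test-time pooling (below) forces a full-image crop. A minimal sketch of that precedence, using a hypothetical helper name rather than the timm API:

def resolve_crop_pct(cli_crop_pct, model_default, test_time_pool):
    # Sketch only: mirrors crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    if test_time_pool:
        return 1.0  # use the whole image; pooling absorbs the larger feature map
    if cli_crop_pct is not None:
        return cli_crop_pct  # user override from the new --crop-pct arg
    return model_default  # fall back to the model's default_cfg value

assert resolve_crop_pct(None, 0.875, test_time_pool=False) == 0.875
assert resolve_crop_pct(0.95, 0.875, test_time_pool=False) == 0.95
assert resolve_crop_pct(0.95, 0.875, test_time_pool=True) == 1.0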
@@ -81,6 +83,7 @@ def validate(args):

     criterion = nn.CrossEntropyLoss().cuda()

+    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
         Dataset(args.data, load_bytes=args.tf_preprocessing),
         input_size=data_config['input_size'],
@@ -90,7 +93,7 @@ def validate(args):
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
-        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
+        crop_pct=crop_pct,
         tf_preprocessing=args.tf_preprocessing)

     batch_time = AverageMeter()
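For reference, crop_pct at eval time conventionally sets the pre-crop resize: the shorter side is resized to img_size / crop_pct and then center-cropped back to img_size. An illustrative calculation (not code from timm):

import math

def eval_resize_size(img_size, crop_pct):
    # shorter side is resized to this before the center crop
    return int(math.floor(img_size / crop_pct))

print(eval_resize_size(224, 0.875))  # 256: the classic 224-in-256 center crop
print(eval_resize_size(224, 1.0))    # 224: crop_pct=1.0 discards no border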
@@ -124,16 +127,19 @@ def validate(args):
                 'Test: [{0:>4d}/{1}] '
                 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
+                'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
+                'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                     i, len(loader), batch_time=batch_time,
                     rate_avg=input.size(0) / batch_time.avg,
                     loss=losses, top1=top1, top5=top5))

     results = OrderedDict(
-        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
-        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
-        param_count=round(param_count / 1e6, 2))
+        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
+        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
+        param_count=round(param_count / 1e6, 2),
+        img_size=data_config['input_size'][-1],
+        crop_pct=crop_pct,
+        interpolation=data_config['interpolation'])

     logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
         results['top1'], results['top1_err'], results['top5'], results['top5_err']))
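The precision change above is deliberate: the console line now prints three decimals while the CSV keeps four. Illustrative values, not a measured result:

top1_avg = 76.14837  # placeholder accuracy
print('log: {:>7.3f}'.format(top1_avg))  # console format -> ' 76.148'
print('csv:', round(top1_avg, 4))        # stored field   -> 76.1484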
@@ -155,7 +161,7 @@ def main():
     if args.model == 'all':
         # validate all models in a list of names with pretrained checkpoints
         args.pretrained = True
-        model_names = list_models()
+        model_names = list_models(pretrained=True)
         model_cfgs = [(n, '') for n in model_names]
     elif not is_model(args.model):
         # model name doesn't exist, try as wildcard filter
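list_models(pretrained=True) narrows the registry to architectures that ship a pretrained checkpoint, so --model all no longer trips over weight-less entries. Usage sketch, assuming these helpers are importable from the top-level timm namespace:

from timm import list_models

all_names = list_models()                        # every registered architecture
pretrained_names = list_models(pretrained=True)  # only those with checkpoint URLs
print(len(all_names), len(pretrained_names))     # the pretrained subset is smaller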
@@ -170,7 +176,8 @@ def main():
             args.model = m
             args.checkpoint = c
             result = OrderedDict(model=args.model)
-            result.update(validate(args))
+            r = validate(args)
+            result.update(r)
             if args.checkpoint:
                 result['checkpoint'] = args.checkpoint
             dw = csv.DictWriter(cf, fieldnames=result.keys())
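Since DictWriter takes fieldnames=result.keys(), the OrderedDict insertion order fixes the CSV column order; the two-step r = validate(args) / result.update(r) is functionally the same as the previous one-liner. A self-contained illustration with placeholder values:

import csv
import sys
from collections import OrderedDict

# Placeholder row; field names follow the new results layout, values are dummies.
result = OrderedDict(model='resnet18', top1=0.0, crop_pct=0.875)
dw = csv.DictWriter(sys.stdout, fieldnames=result.keys())
dw.writeheader()     # prints: model,top1,crop_pct
dw.writerow(result)  # prints: resnet18,0.0,0.875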