mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove poorly named metrics from torch imagenet example origins. Use top1/top5 in csv output for consistency with existing validation results files, acc elsewhere. Fixes #111
This commit is contained in:
parent
56608c9070
commit
13cf68850b
@ -170,10 +170,9 @@ class AverageMeter:
|
|||||||
|
|
||||||
|
|
||||||
def accuracy(output, target, topk=(1,)):
|
def accuracy(output, target, topk=(1,)):
|
||||||
"""Computes the precision@k for the specified values of k"""
|
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
||||||
maxk = max(topk)
|
maxk = max(topk)
|
||||||
batch_size = target.size(0)
|
batch_size = target.size(0)
|
||||||
|
|
||||||
_, pred = output.topk(maxk, 1, True, True)
|
_, pred = output.topk(maxk, 1, True, True)
|
||||||
pred = pred.t()
|
pred = pred.t()
|
||||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||||
|
29
train.py
29
train.py
@ -193,8 +193,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|||||||
help='disable fast prefetcher')
|
help='disable fast prefetcher')
|
||||||
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
||||||
help='path to output folder (default: none, current dir)')
|
help='path to output folder (default: none, current dir)')
|
||||||
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
|
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
|
||||||
help='Best metric (default: "prec1"')
|
help='Best metric (default: "top1"')
|
||||||
parser.add_argument('--tta', type=int, default=0, metavar='N',
|
parser.add_argument('--tta', type=int, default=0, metavar='N',
|
||||||
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
||||||
parser.add_argument("--local_rank", default=0, type=int)
|
parser.add_argument("--local_rank", default=0, type=int)
|
||||||
@ -596,8 +596,8 @@ def train_epoch(
|
|||||||
def validate(model, loader, loss_fn, args, log_suffix=''):
|
def validate(model, loader, loss_fn, args, log_suffix=''):
|
||||||
batch_time_m = AverageMeter()
|
batch_time_m = AverageMeter()
|
||||||
losses_m = AverageMeter()
|
losses_m = AverageMeter()
|
||||||
prec1_m = AverageMeter()
|
top1_m = AverageMeter()
|
||||||
prec5_m = AverageMeter()
|
top5_m = AverageMeter()
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -621,20 +621,20 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
|
|||||||
target = target[0:target.size(0):reduce_factor]
|
target = target[0:target.size(0):reduce_factor]
|
||||||
|
|
||||||
loss = loss_fn(output, target)
|
loss = loss_fn(output, target)
|
||||||
prec1, prec5 = accuracy(output, target, topk=(1, 5))
|
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
||||||
|
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
reduced_loss = reduce_tensor(loss.data, args.world_size)
|
reduced_loss = reduce_tensor(loss.data, args.world_size)
|
||||||
prec1 = reduce_tensor(prec1, args.world_size)
|
acc1 = reduce_tensor(acc1, args.world_size)
|
||||||
prec5 = reduce_tensor(prec5, args.world_size)
|
acc5 = reduce_tensor(acc5, args.world_size)
|
||||||
else:
|
else:
|
||||||
reduced_loss = loss.data
|
reduced_loss = loss.data
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
losses_m.update(reduced_loss.item(), input.size(0))
|
losses_m.update(reduced_loss.item(), input.size(0))
|
||||||
prec1_m.update(prec1.item(), output.size(0))
|
top1_m.update(acc1.item(), output.size(0))
|
||||||
prec5_m.update(prec5.item(), output.size(0))
|
top5_m.update(acc5.item(), output.size(0))
|
||||||
|
|
||||||
batch_time_m.update(time.time() - end)
|
batch_time_m.update(time.time() - end)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
@ -644,13 +644,12 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
|
|||||||
'{0}: [{1:>4d}/{2}] '
|
'{0}: [{1:>4d}/{2}] '
|
||||||
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
||||||
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
||||||
'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
|
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
|
||||||
'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
|
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
|
||||||
log_name, batch_idx, last_idx,
|
log_name, batch_idx, last_idx, batch_time=batch_time_m,
|
||||||
batch_time=batch_time_m, loss=losses_m,
|
loss=losses_m, top1=top1_m, top5=top5_m))
|
||||||
top1=prec1_m, top5=prec5_m))
|
|
||||||
|
|
||||||
metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])
|
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
|
||||||
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
12
validate.py
12
validate.py
@ -150,10 +150,10 @@ def validate(args):
|
|||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
|
|
||||||
# measure accuracy and record loss
|
# measure accuracy and record loss
|
||||||
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
|
acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
|
||||||
losses.update(loss.item(), input.size(0))
|
losses.update(loss.item(), input.size(0))
|
||||||
top1.update(prec1.item(), input.size(0))
|
top1.update(acc1.item(), input.size(0))
|
||||||
top5.update(prec5.item(), input.size(0))
|
top5.update(acc5.item(), input.size(0))
|
||||||
|
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
batch_time.update(time.time() - end)
|
batch_time.update(time.time() - end)
|
||||||
@ -164,8 +164,8 @@ def validate(args):
|
|||||||
'Test: [{0:>4d}/{1}] '
|
'Test: [{0:>4d}/{1}] '
|
||||||
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
|
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
|
||||||
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
||||||
'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
|
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
|
||||||
'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
|
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
|
||||||
i, len(loader), batch_time=batch_time,
|
i, len(loader), batch_time=batch_time,
|
||||||
rate_avg=input.size(0) / batch_time.avg,
|
rate_avg=input.size(0) / batch_time.avg,
|
||||||
loss=losses, top1=top1, top5=top5))
|
loss=losses, top1=top1, top5=top5))
|
||||||
@ -178,7 +178,7 @@ def validate(args):
|
|||||||
cropt_pct=crop_pct,
|
cropt_pct=crop_pct,
|
||||||
interpolation=data_config['interpolation'])
|
interpolation=data_config['interpolation'])
|
||||||
|
|
||||||
logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
|
logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
|
||||||
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
|
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
Loading…
x
Reference in New Issue
Block a user