Update accuracy.py (#104)
Co-authored-by: Ülkü Tuncer Küçüktaş <UlkuTuncerKucuktas@users.noreply.github.com>pull/108/head
parent
5de2d2c0e9
commit
6f7698cb7c
|
@ -25,7 +25,7 @@ def accuracy_torch(pred, target, topk=1):
|
|||
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
|
||||
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100. / num))
|
||||
return res
|
||||
|
||||
|
|
Loading…
Reference in New Issue