[Enhance] Improve accuracy calculation performance. (#592)

* Imporve accuracy calculate performance.

* Add unit tests for accuracy

* Reuse state_inds
pull/597/head
Ma Zerun 2021-12-09 14:06:08 +08:00 committed by GitHub
parent 72b0da8bd7
commit 188aa6ed5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 4 deletions

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn
def accuracy_numpy(pred, target, topk=1, thrs=0.):
def accuracy_numpy(pred, target, topk=(1, ), thrs=0.):
if isinstance(thrs, Number):
thrs = (thrs, )
res_single = True
@ -19,8 +19,14 @@ def accuracy_numpy(pred, target, topk=1, thrs=0.):
res = []
maxk = max(topk)
num = pred.shape[0]
pred_label = pred.argsort(axis=1)[:, -maxk:][:, ::-1]
pred_score = np.sort(pred, axis=1)[:, -maxk:][:, ::-1]
static_inds = np.indices((num, maxk))[0]
pred_label = pred.argpartition(-maxk, axis=1)[:, -maxk:]
pred_score = pred[static_inds, pred_label]
sort_inds = np.argsort(pred_score, axis=1)[:, ::-1]
pred_label = pred_label[static_inds, sort_inds]
pred_score = pred_score[static_inds, sort_inds]
for k in topk:
correct_k = pred_label[:, :k] == target.reshape(-1, 1)
@ -37,7 +43,7 @@ def accuracy_numpy(pred, target, topk=1, thrs=0.):
return res
def accuracy_torch(pred, target, topk=1, thrs=0.):
def accuracy_torch(pred, target, topk=(1, ), thrs=0.):
if isinstance(thrs, Number):
thrs = (thrs, )
res_single = True

View File

@ -3,6 +3,7 @@ import pytest
import torch
from mmcls.core import average_performance, mAP
from mmcls.models.losses.accuracy import Accuracy
def test_mAP():
@ -55,3 +56,31 @@ def test_average_performance():
assert average_performance(
pred, target, k=2) == pytest.approx(
(43.75, 50.00, 46.67, 40.00, 57.14, 47.06), rel=1e-2)
def test_accuracy():
pred_tensor = torch.tensor([[0.1, 0.2, 0.4], [0.2, 0.5, 0.3],
[0.4, 0.3, 0.1], [0.8, 0.9, 0.0]])
target_tensor = torch.tensor([2, 0, 0, 0])
pred_array = pred_tensor.numpy()
target_array = target_tensor.numpy()
acc_top1 = 50.
acc_top2 = 75.
compute_acc = Accuracy(topk=1)
assert compute_acc(pred_tensor, target_tensor) == acc_top1
assert compute_acc(pred_array, target_array) == acc_top1
compute_acc = Accuracy(topk=(1, ))
assert compute_acc(pred_tensor, target_tensor)[0] == acc_top1
assert compute_acc(pred_array, target_array)[0] == acc_top1
compute_acc = Accuracy(topk=(1, 2))
assert compute_acc(pred_tensor, target_tensor)[0] == acc_top1
assert compute_acc(pred_tensor, target_tensor)[1] == acc_top2
assert compute_acc(pred_array, target_array)[0] == acc_top1
assert compute_acc(pred_array, target_array)[1] == acc_top2
with pytest.raises(TypeError):
compute_acc(pred_tensor, target_array)