From 188aa6ed5dae184024818a6c3a72e5b0c733888d Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Thu, 9 Dec 2021 14:06:08 +0800 Subject: [PATCH] [Enhance] Improve accuracy calculation performance. (#592) * Imporve accuracy calculate performance. * Add unit tests for accuracy * Reuse state_inds --- mmcls/models/losses/accuracy.py | 14 ++++++++++---- tests/test_metrics/test_metrics.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/mmcls/models/losses/accuracy.py b/mmcls/models/losses/accuracy.py index 3ecdbbecf..873e579b8 100644 --- a/mmcls/models/losses/accuracy.py +++ b/mmcls/models/losses/accuracy.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -def accuracy_numpy(pred, target, topk=1, thrs=0.): +def accuracy_numpy(pred, target, topk=(1, ), thrs=0.): if isinstance(thrs, Number): thrs = (thrs, ) res_single = True @@ -19,8 +19,14 @@ def accuracy_numpy(pred, target, topk=1, thrs=0.): res = [] maxk = max(topk) num = pred.shape[0] - pred_label = pred.argsort(axis=1)[:, -maxk:][:, ::-1] - pred_score = np.sort(pred, axis=1)[:, -maxk:][:, ::-1] + + static_inds = np.indices((num, maxk))[0] + pred_label = pred.argpartition(-maxk, axis=1)[:, -maxk:] + pred_score = pred[static_inds, pred_label] + + sort_inds = np.argsort(pred_score, axis=1)[:, ::-1] + pred_label = pred_label[static_inds, sort_inds] + pred_score = pred_score[static_inds, sort_inds] for k in topk: correct_k = pred_label[:, :k] == target.reshape(-1, 1) @@ -37,7 +43,7 @@ def accuracy_numpy(pred, target, topk=1, thrs=0.): return res -def accuracy_torch(pred, target, topk=1, thrs=0.): +def accuracy_torch(pred, target, topk=(1, ), thrs=0.): if isinstance(thrs, Number): thrs = (thrs, ) res_single = True diff --git a/tests/test_metrics/test_metrics.py b/tests/test_metrics/test_metrics.py index 8906af8a9..df06e2596 100644 --- a/tests/test_metrics/test_metrics.py +++ b/tests/test_metrics/test_metrics.py @@ -3,6 +3,7 @@ import pytest import torch from mmcls.core import average_performance, mAP +from mmcls.models.losses.accuracy import Accuracy def test_mAP(): @@ -55,3 +56,31 @@ def test_average_performance(): assert average_performance( pred, target, k=2) == pytest.approx( (43.75, 50.00, 46.67, 40.00, 57.14, 47.06), rel=1e-2) + + +def test_accuracy(): + pred_tensor = torch.tensor([[0.1, 0.2, 0.4], [0.2, 0.5, 0.3], + [0.4, 0.3, 0.1], [0.8, 0.9, 0.0]]) + target_tensor = torch.tensor([2, 0, 0, 0]) + pred_array = pred_tensor.numpy() + target_array = target_tensor.numpy() + + acc_top1 = 50. + acc_top2 = 75. + + compute_acc = Accuracy(topk=1) + assert compute_acc(pred_tensor, target_tensor) == acc_top1 + assert compute_acc(pred_array, target_array) == acc_top1 + + compute_acc = Accuracy(topk=(1, )) + assert compute_acc(pred_tensor, target_tensor)[0] == acc_top1 + assert compute_acc(pred_array, target_array)[0] == acc_top1 + + compute_acc = Accuracy(topk=(1, 2)) + assert compute_acc(pred_tensor, target_tensor)[0] == acc_top1 + assert compute_acc(pred_tensor, target_tensor)[1] == acc_top2 + assert compute_acc(pred_array, target_array)[0] == acc_top1 + assert compute_acc(pred_array, target_array)[1] == acc_top2 + + with pytest.raises(TypeError): + compute_acc(pred_tensor, target_array)