# mmpretrain/tests/test_evaluation/test_metrics/test_retrieval.py
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch

from mmcls.evaluation.metrics import RetrievalRecall
from mmcls.registry import METRICS
from mmcls.structures import ClsDataSample


class TestRetrievalRecall(TestCase):
    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""
        pred = [
            ClsDataSample().set_pred_score(i).set_gt_label(k).to_dict()
            for i, k in zip([
                torch.tensor([0.7, 0.0, 0.3]),
                torch.tensor([0.5, 0.2, 0.3]),
                torch.tensor([0.4, 0.5, 0.1]),
                torch.tensor([0.0, 0.0, 1.0]),
                torch.tensor([0.0, 0.0, 1.0]),
                torch.tensor([0.0, 0.0, 1.0]),
            ], [[0], [0, 1], [1], [2], [1, 2], [0, 1]])
        ]

        # Test with score (use score instead of label if score exists)
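        # For these six samples the top-1 prediction hits a gt label in five
        # cases; only the last sample (argmax 2 vs labels {0, 1}) misses,
        # hence the expected Recall@1 of 5 / 6.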
        metric = METRICS.build(dict(type='RetrievalRecall', topk=1))
        metric.process(None, pred)
        recall = metric.evaluate(6)

        self.assertIsInstance(recall, dict)
        self.assertAlmostEqual(
            recall['retrieval/Recall@1'], 5 / 6 * 100, places=4)

        # Test with invalid topk
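        # topk=10 exceeds the 3 classes in `pred`, so the top-k selection
        # fails at evaluation time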
        with self.assertRaisesRegex(RuntimeError, 'selected index k'):
            metric = METRICS.build(dict(type='RetrievalRecall', topk=10))
            metric.process(None, pred)
            metric.evaluate(6)
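
        # a non-positive topk is rejected already at construction time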
        with self.assertRaisesRegex(ValueError, '`topk` must be a'):
            METRICS.build(dict(type='RetrievalRecall', topk=-1))

        # Test initialization with an integer topk
        metric = METRICS.build(dict(type='RetrievalRecall', topk=5))
        self.assertEqual(metric.topk, (5, ))

        # Test initialization with a tuple topk
        metric = METRICS.build(dict(type='RetrievalRecall', topk=(1, 2, 5)))
        self.assertEqual(metric.topk, (1, 2, 5))

    def test_calculate(self):
        """Test using the metric's static `calculate` method."""
        # sequence-of-indices format: each pred is a ranked list of retrieved
        # indices and each target lists the relevant indices
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10)] * 2

        # test with topk=1 on the indices format
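        # np.arange(10) ranks index 0 first; 0 is relevant only for the first
        # sample, so the expected Recall@1 is 1 / 2 = 50%.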
        recall_score = RetrievalRecall.calculate(
            y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
        expect_recall = 50.
        self.assertEqual(recall_score[0].item(), expect_recall)

        # test with tensor input
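        # the linspace scores decrease monotonically, so the implied ranking
        # is 0, 1, ..., 9; index 0 is relevant only for the first row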
        y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                               [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=1)
        expect_recall = 50.
        self.assertEqual(recall_score[0].item(), expect_recall)

        # test with topk=2
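        # the top-2 indices are {0, 1}: index 0 is relevant for the first row
        # and index 1 for the second, so the expected Recall@2 is 100%.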
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=2)
        expect_recall = 100.
        self.assertEqual(recall_score[0].item(), expect_recall)

        # test with topk=(1, 5)
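        # Recall@1 is 50% as above; with the top-5 indices {0, ..., 4} both
        # rows contain at least one relevant index, so Recall@5 is 100%.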
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=(1, 5))
        expect_recalls = [50., 100.]
        self.assertEqual(len(recall_score), len(expect_recalls))
        for i in range(len(expect_recalls)):
            self.assertEqual(recall_score[i].item(), expect_recalls[i])

        # Test with invalid pred: not a sequence
        y_pred = dict()
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        with self.assertRaisesRegex(AssertionError, '`pred` must be Seq'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        # Test with invalid target: not a sequence
        y_true = dict()
        y_pred = [np.arange(10)] * 2
        with self.assertRaisesRegex(AssertionError, '`target` must be Seq'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        # Test with mismatched lengths of `pred` and `target`
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10)] * 3
        with self.assertRaisesRegex(AssertionError, 'Length of `pred`'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        # Test with an invalid target element
        y_true = [[0, 2, 5, 8, 9], dict()]
        y_pred = [np.arange(10)] * 2
        with self.assertRaisesRegex(AssertionError, '`target` should be'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        # Test with an invalid pred element
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10), dict()]
        with self.assertRaisesRegex(AssertionError, '`pred` should be'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
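
# A minimal usage sketch (illustrative assumption, not part of the tests): in
# an mmcls-style config, the metric exercised above would typically be enabled
# as, e.g.
#
#     val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
#
# producing result keys of the form `retrieval/Recall@1` and
# `retrieval/Recall@5`, as asserted in `test_evaluate`.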