42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import tempfile
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from mmengine.evaluator import DumpResults
|
|
from mmengine.fileio import load
|
|
|
|
|
|
class TestDumpResults(TestCase):
    """Unit tests for the ``DumpResults`` evaluator metric."""

    def test_init(self):
        """Constructing with a non-``.pkl`` output path raises ValueError."""
        with self.assertRaisesRegex(ValueError,
                                    'The output file must be a pkl file.'):
            DumpResults(out_file_path='./results.json')

    def test_process(self):
        """``process`` records the prediction and its tensor is on the CPU."""
        metric = DumpResults(out_file_path='./results.pkl')
        predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
        metric.process(None, predictions)
        self.assertEqual(len(metric.results), 1)
        self.assertEqual(metric.results[0]['data'][0].device,
                         torch.device('cpu'))

    def test_compute_metrics(self):
        """``compute_metrics`` writes the results to the configured file.

        The dumped file is loaded back and checked for the number of entries
        and the device of the stored tensor.
        """
        # Use the context manager so the temporary directory is removed even
        # if one of the assertions below fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            path = osp.join(temp_dir, 'results.pkl')
            metric = DumpResults(out_file_path=path)
            predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
            metric.process(None, predictions)
            metric.compute_metrics(metric.results)
            self.assertTrue(osp.isfile(path))

            results = load(path)
            self.assertEqual(len(results), 1)
            self.assertEqual(results[0]['data'][0].device,
                             torch.device('cpu'))
|