# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase

import torch
from torch import Tensor

from mmengine.evaluator import DumpResults
from mmengine.fileio import load


class TestDumpResults(TestCase):
    """Unit tests for the ``DumpResults`` metric."""

    def test_init(self):
        """Constructing with a non-``.pkl`` output path raises ValueError."""
        with self.assertRaisesRegex(ValueError,
                                    'The output file must be a pkl file.'):
            DumpResults(out_file_path='./results.json')

    def test_process(self):
        """``process`` buffers one entry per prediction, tensors on CPU."""
        # This test only inspects the in-memory ``results`` buffer; it does
        # not call ``compute_metrics`` or check the output path.
        metric = DumpResults(out_file_path='./results.pkl')
        predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
        metric.process(None, predictions)
        self.assertEqual(len(metric.results), 1)
        for tensor in metric.results[0]['data']:
            self.assertEqual(tensor.device, torch.device('cpu'))

    def test_compute_metrics(self):
        """``compute_metrics`` writes the buffered results to the pkl file."""
        # Use the context manager so the temporary directory is removed even
        # when one of the assertions below fails; an explicit ``cleanup()``
        # call at the end of the test would be skipped in that case.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = osp.join(tmp_dir, 'results.pkl')
            metric = DumpResults(out_file_path=path)
            predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
            metric.process(None, predictions)
            metric.compute_metrics(metric.results)
            self.assertTrue(osp.isfile(path))

            results = load(path)
            self.assertEqual(len(results), 1)
            for tensor in results[0]['data']:
                self.assertEqual(tensor.device, torch.device('cpu'))