mmengine/tests/test_evaluator/test_metric.py

42 lines
1.4 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
import torch
from torch import Tensor
from mmengine.evaluator import DumpResults
from mmengine.fileio import load
class TestDumpResults(TestCase):
def test_init(self):
with self.assertRaisesRegex(ValueError,
'The output file must be a pkl file.'):
DumpResults(out_file_path='./results.json')
def test_process(self):
metric = DumpResults(out_file_path='./results.pkl')
[Refactor] Refactor data flow to make the interface more natural (#468) * [Refactor]: modify interface of Visualizer.add_datasample (#365) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader` (#323) * acollate data in dataloader * fix docstring * refine comment * fix as comment * refactor default collate and psedo collate * foramt test file * fix docstring * fix as comment * rename elem to data_item * minor fix * fix as comment * [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process is a `dict` (#360) * refine evaluator and metric * compatible with new default collate * replace default collate with pseudo * Handle data_batch in metric * fix unit test * fix unit test * fix unit test * minor refine * make data_batch optional make data_batch optional * rename outputs to predictions * fix ut * rename predictions to outputs * fix docstring * fix docstring * fix unit test * make outputs and data_batch to kwargs * fix unit test * keep signature of metric * fix ut * rename pred_sample arguments to data_sample(Visualizer) * fix loop and ut * [refactor]: Refactor model dataflow (#398) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * refactor model data flow tmp_commt tmp commit * make val_cfg and test_cfg optional * roll back runner * pass test mmdet * fix as comment fix as comment fix ci in DataPreprocessor * fix ut * fix ut * fix rebase main * [Fix]: Fix test val ddp (#462) * [Fix] Fix docstring and type hint of data flow (#463) * Fix docstring of data flow * change signature of hook * fix unit test * resolve conflicts * fix lint
2022-08-24 22:04:55 +08:00
data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
metric.process(None, data_samples)
self.assertEqual(len(metric.results), 1)
self.assertEqual(metric.results[0]['data'][0].device,
torch.device('cpu'))
def test_compute_metrics(self):
temp_dir = tempfile.TemporaryDirectory()
path = osp.join(temp_dir.name, 'results.pkl')
metric = DumpResults(out_file_path=path)
[Refactor] Refactor data flow to make the interface more natural (#468) * [Refactor]: modify interface of Visualizer.add_datasample (#365) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader` (#323) * acollate data in dataloader * fix docstring * refine comment * fix as comment * refactor default collate and psedo collate * foramt test file * fix docstring * fix as comment * rename elem to data_item * minor fix * fix as comment * [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process is a `dict` (#360) * refine evaluator and metric * compatible with new default collate * replace default collate with pseudo * Handle data_batch in metric * fix unit test * fix unit test * fix unit test * minor refine * make data_batch optional make data_batch optional * rename outputs to predictions * fix ut * rename predictions to outputs * fix docstring * fix docstring * fix unit test * make outputs and data_batch to kwargs * fix unit test * keep signature of metric * fix ut * rename pred_sample arguments to data_sample(Visualizer) * fix loop and ut * [refactor]: Refactor model dataflow (#398) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * refactor model data flow tmp_commt tmp commit * make val_cfg and test_cfg optional * roll back runner * pass test mmdet * fix as comment fix as comment fix ci in DataPreprocessor * fix ut * fix ut * fix rebase main * [Fix]: Fix test val ddp (#462) * [Fix] Fix docstring and type hint of data flow (#463) * Fix docstring of data flow * change signature of hook * fix unit test * resolve conflicts * fix lint
2022-08-24 22:04:55 +08:00
data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
metric.process(None, data_samples)
metric.compute_metrics(metric.results)
self.assertTrue(osp.isfile(path))
results = load(path)
self.assertEqual(len(results), 1)
self.assertEqual(results[0]['data'][0].device, torch.device('cpu'))
temp_dir.cleanup()