[Feature] Dump predictions to a pickle file for offline evaluation. (#293)

* [Feature] Dump predictions to a pickle file for offline evaluation.
* print_log

parent b7866021c4
commit 4cd91ffe15
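
In short, the new DumpResults metric does no scoring of its own: process() moves each prediction to CPU and buffers it, and compute_metrics() writes the buffered list to the given '.pkl'/'.pickle' path. A minimal usage sketch mirroring the calls exercised by the new unit tests; the output path and the hand-built predictions list are placeholders, not part of this commit:

    from torch import Tensor

    from mmengine.evaluator import DumpResults

    # Stand-in for predictions gathered from an inference loop.
    predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]

    metric = DumpResults(out_file_path='work_dirs/predictions.pkl')
    metric.process(None, predictions)       # tensors are moved to CPU and buffered
    metric.compute_metrics(metric.results)  # buffered results are written to the pickle file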
mmengine/evaluator/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .evaluator import Evaluator
-from .metric import BaseMetric
+from .metric import BaseMetric, DumpResults
 from .utils import get_metric_value
 
-__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value']
+__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults']
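
This re-export is what makes the new metric importable from the package root, which is how the tests added below pull it in. A quick, purely illustrative check:

    import mmengine.evaluator

    assert 'DumpResults' in mmengine.evaluator.__all__  # added by this change
    from mmengine.evaluator import DumpResults          # package-root import used by the tests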
mmengine/evaluator/metric.py
@@ -3,8 +3,14 @@ import warnings
 from abc import ABCMeta, abstractmethod
 from typing import Any, List, Optional, Sequence, Union
 
+from torch import Tensor
+
+from mmengine.data import BaseDataElement
 from mmengine.dist import (broadcast_object_list, collect_results,
                            is_main_process)
+from mmengine.fileio import dump
+from mmengine.logging import print_log
+from mmengine.registry import METRICS
 
 
 class BaseMetric(metaclass=ABCMeta):
@@ -116,3 +122,51 @@ class BaseMetric(metaclass=ABCMeta):
         # reset the results list
         self.results.clear()
         return metrics[0]
+
+
+@METRICS.register_module()
+class DumpResults(BaseMetric):
+    """Dump model predictions to a pickle file for offline evaluation.
+
+    Args:
+        out_file_path (str): Path of the dumped file. Must end with '.pkl'
+            or '.pickle'.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+    """
+
+    def __init__(self,
+                 out_file_path: str,
+                 collect_device: str = 'cpu') -> None:
+        super().__init__(collect_device=collect_device)
+        if not out_file_path.endswith(('.pkl', '.pickle')):
+            raise ValueError('The output file must be a pkl file.')
+        self.out_file_path = out_file_path
+
+    def process(self, data_batch: Sequence[dict],
+                predictions: Sequence[dict]) -> None:
+        """Transfer tensors in predictions to CPU."""
+        self.results.extend(_to_cpu(predictions))
+
+    def compute_metrics(self, results: list) -> dict:
+        """Dump the prediction results to a pickle file."""
+        dump(results, self.out_file_path)
+        print_log(
+            f'Results have been saved to {self.out_file_path}.',
+            logger='current')
+        return {}
+
+
+def _to_cpu(data: Any) -> Any:
+    """Transfer all tensors and BaseDataElement to cpu."""
+    if isinstance(data, (Tensor, BaseDataElement)):
+        return data.to('cpu')
+    elif isinstance(data, list):
+        return [_to_cpu(d) for d in data]
+    elif isinstance(data, tuple):
+        return tuple(_to_cpu(d) for d in data)
+    elif isinstance(data, dict):
+        return {k: _to_cpu(v) for k, v in data.items()}
+    else:
+        return data
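
Because DumpResults is decorated with @METRICS.register_module(), it can also be built through the registry from a plain config dict, which is how metrics are usually specified in mmengine configs. A sketch under that assumption; the config layout and output path are illustrative, not taken from this commit:

    from mmengine.registry import METRICS

    cfg = dict(type='DumpResults', out_file_path='work_dirs/predictions.pkl')
    metric = METRICS.build(cfg)  # resolves 'DumpResults' in the METRICS registry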
tests/test_evaluator/test_metric.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+from unittest import TestCase
+
+import torch
+from torch import Tensor
+
+from mmengine.evaluator import DumpResults
+from mmengine.fileio import load
+
+
+class TestDumpResults(TestCase):
+
+    def test_init(self):
+        with self.assertRaisesRegex(ValueError,
+                                    'The output file must be a pkl file.'):
+            DumpResults(out_file_path='./results.json')
+
+    def test_process(self):
+        metric = DumpResults(out_file_path='./results.pkl')
+        predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
+        metric.process(None, predictions)
+        self.assertEqual(len(metric.results), 1)
+        self.assertEqual(metric.results[0]['data'][0].device,
+                         torch.device('cpu'))
+
+    def test_compute_metrics(self):
+        temp_dir = tempfile.TemporaryDirectory()
+        path = osp.join(temp_dir.name, 'results.pkl')
+        metric = DumpResults(out_file_path=path)
+        predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
+        metric.process(None, predictions)
+        metric.compute_metrics(metric.results)
+        self.assertTrue(osp.isfile(path))
+
+        results = load(path)
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0]['data'][0].device, torch.device('cpu'))
+
+        temp_dir.cleanup()
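
The round trip that test_compute_metrics checks is also the intended offline workflow: read the dumped file back with mmengine.fileio.load and hand the predictions to whatever offline evaluation script you maintain. A small sketch of the reload side; the path and the inspection loop are assumptions, not part of this commit:

    from mmengine.fileio import load

    results = load('work_dirs/predictions.pkl')  # list of per-sample dicts with CPU tensors
    print(f'loaded {len(results)} predictions')
    for pred in results:
        # Each entry is whatever dict the model produced during inference.
        print({k: type(v).__name__ for k, v in pred.items()})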