[Feature] Dump predictions to a pickle file for offline evaluation. (#293)
* [Feature] Dump predictions to pickle file for offline evaluation. * print_logpull/306/head
parent
b7866021c4
commit
4cd91ffe15
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .evaluator import Evaluator
|
||||
from .metric import BaseMetric
|
||||
from .metric import BaseMetric, DumpResults
|
||||
from .utils import get_metric_value
|
||||
|
||||
__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value']
|
||||
__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults']
|
||||
|
|
|
@ -3,8 +3,14 @@ import warnings
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.dist import (broadcast_object_list, collect_results,
|
||||
is_main_process)
|
||||
from mmengine.fileio import dump
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import METRICS
|
||||
|
||||
|
||||
class BaseMetric(metaclass=ABCMeta):
|
||||
|
@ -116,3 +122,51 @@ class BaseMetric(metaclass=ABCMeta):
|
|||
# reset the results list
|
||||
self.results.clear()
|
||||
return metrics[0]
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class DumpResults(BaseMetric):
|
||||
"""Dump model predictions to a pickle file for offline evaluation.
|
||||
|
||||
Args:
|
||||
out_file_path (str): Path of the dumped file. Must end with '.pkl'
|
||||
or '.pickle'.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_file_path: str,
|
||||
collect_device: str = 'cpu') -> None:
|
||||
super().__init__(collect_device=collect_device)
|
||||
if not out_file_path.endswith(('.pkl', '.pickle')):
|
||||
raise ValueError('The output file must be a pkl file.')
|
||||
self.out_file_path = out_file_path
|
||||
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[dict]) -> None:
|
||||
"""transfer tensors in predictions to CPU."""
|
||||
self.results.extend(_to_cpu(predictions))
|
||||
|
||||
def compute_metrics(self, results: list) -> dict:
|
||||
"""dump the prediction results to a pickle file."""
|
||||
dump(results, self.out_file_path)
|
||||
print_log(
|
||||
f'Results has been saved to {self.out_file_path}.',
|
||||
logger='current')
|
||||
return {}
|
||||
|
||||
|
||||
def _to_cpu(data: Any) -> Any:
|
||||
"""transfer all tensors and BaseDataElement to cpu."""
|
||||
if isinstance(data, (Tensor, BaseDataElement)):
|
||||
return data.to('cpu')
|
||||
elif isinstance(data, list):
|
||||
return [_to_cpu(d) for d in data]
|
||||
elif isinstance(data, tuple):
|
||||
return tuple(_to_cpu(d) for d in data)
|
||||
elif isinstance(data, dict):
|
||||
return {k: _to_cpu(v) for k, v in data.items()}
|
||||
else:
|
||||
return data
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from mmengine.evaluator import DumpResults
|
||||
from mmengine.fileio import load
|
||||
|
||||
|
||||
class TestDumpResults(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'The output file must be a pkl file.'):
|
||||
DumpResults(out_file_path='./results.json')
|
||||
|
||||
def test_process(self):
|
||||
metric = DumpResults(out_file_path='./results.pkl')
|
||||
predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
|
||||
metric.process(None, predictions)
|
||||
self.assertEqual(len(metric.results), 1)
|
||||
self.assertEqual(metric.results[0]['data'][0].device,
|
||||
torch.device('cpu'))
|
||||
|
||||
def test_compute_metrics(self):
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
path = osp.join(temp_dir.name, 'results.pkl')
|
||||
metric = DumpResults(out_file_path=path)
|
||||
predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
|
||||
metric.process(None, predictions)
|
||||
metric.compute_metrics(metric.results)
|
||||
self.assertTrue(osp.isfile(path))
|
||||
|
||||
results = load(path)
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]['data'][0].device, torch.device('cpu'))
|
||||
|
||||
temp_dir.cleanup()
|
Loading…
Reference in New Issue