mirror of https://github.com/open-mmlab/mmocr.git
[Enhancement] Automatically report mean scores when applicable (#995)
parent 28c9e460d5
commit 6b180db93d
@@ -19,10 +19,12 @@ class UniformConcatDataset(ConcatDataset):
         separate_eval (bool): Whether to evaluate the results
             separately if it is used as validation dataset.
             Defaults to True.
-        show_mean_scores (bool): Whether to compute the mean evaluation
-            results, only applicable when ``separate_eval=True``. If ``True``,
-            mean results will be added to the result dictionary with keys in
-            the form of ``mean_{metric_name}``.
+        show_mean_scores (str | bool): Whether to compute the mean evaluation
+            results, only applicable when ``separate_eval=True``. Options are
+            [True, False, ``auto``]. If ``True``, mean results will be added to
+            the result dictionary with keys in the form of
+            ``mean_{metric_name}``. If 'auto', mean results will be shown only
+            when more than 1 dataset is wrapped.
         pipeline (None | list[dict] | list[list[dict]]): If ``None``,
             each dataset in datasets use its own pipeline;
             If ``list[dict]``, it will be assigned to the dataset whose
@@ -37,7 +39,7 @@ class UniformConcatDataset(ConcatDataset):
     def __init__(self,
                  datasets,
                  separate_eval=True,
-                 show_mean_scores=False,
+                 show_mean_scores='auto',
                  pipeline=None,
                  force_apply=False,
                  **kwargs):
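For orientation while reading the rest of the diff, the sketch below shows what the evaluation dictionary looks like for each allowed value of show_mean_scores, following the docstring and the new 'auto' default above. The metric name and the numbers are invented for illustration; only the key formats come from this change.

# Key shapes follow the docstring and the tests in this commit; the metric
# name 'hmean' and the scores are placeholders, not real results.

# show_mean_scores=False: per-dataset scores only.
results_false = {'0_hmean': 0.80, '1_hmean': 0.90}

# show_mean_scores=True: per-dataset scores plus their mean, stored under
# mean_{metric_name}.
results_true = {'0_hmean': 0.80, '1_hmean': 0.90, 'mean_hmean': 0.85}

# show_mean_scores='auto' (the new default): behaves like True whenever more
# than one dataset is wrapped, like False for a single wrapped dataset.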
@@ -70,8 +72,12 @@ class UniformConcatDataset(ConcatDataset):
                 'Evaluating datasets as a whole is not'
                 ' supported yet. Please use "separate_eval=True"')
 
+        assert isinstance(show_mean_scores, bool) or show_mean_scores == 'auto'
+        if show_mean_scores == 'auto':
+            show_mean_scores = len(self.datasets) > 1
         self.show_mean_scores = show_mean_scores
-        if show_mean_scores:
+        if show_mean_scores is True or show_mean_scores == 'auto' and len(
+                self.datasets) > 1:
             if len(set([type(ds) for ds in self.datasets])) != 1:
                 raise NotImplementedError(
                     'To compute mean evaluation scores, all datasets'
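The three added lines above normalize the flag before it is stored on the wrapper. Here is the same rule pulled out into a standalone helper for readers tracing the behaviour; the helper name is illustrative, not part of the commit.

def resolve_show_mean_scores(flag, num_datasets):
    """Standalone mirror of the normalization added above: 'auto' means
    report means only when more than one dataset is wrapped."""
    assert isinstance(flag, bool) or flag == 'auto'
    if flag == 'auto':
        flag = num_datasets > 1
    return flag


# 'auto' resolves against the number of wrapped datasets; booleans pass through.
assert resolve_show_mean_scores('auto', 1) is False
assert resolve_show_mean_scores('auto', 3) is True
assert resolve_show_mean_scores(False, 3) is False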
@@ -78,19 +78,32 @@ def test_uniform_concat_dataset_eval():
         def evaluate(self, res, logger, **kwargs):
             return dict(n=res[0])
 
+    # Test 'auto'
+    fake_inputs = [10]
+    datasets = [dict(type='DummyDataset')]
+    tmp_dataset = UniformConcatDataset(datasets)
+    results = tmp_dataset.evaluate(fake_inputs)
+    assert results['0_n'] == 10
+    assert 'mean_n' not in results
+
+    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=True)
+    results = tmp_dataset.evaluate(fake_inputs)
+    assert results['mean_n'] == 10
+
     fake_inputs = [10, 20]
     datasets = [dict(type='DummyDataset'), dict(type='DummyDataset')]
 
-    tmp_dataset = UniformConcatDataset(datasets)
+    tmp_dataset = UniformConcatDataset(datasets)
     results = tmp_dataset.evaluate(fake_inputs)
     assert results['0_n'] == 10
     assert results['1_n'] == 20
+    assert results['mean_n'] == 15
 
-    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=True)
+    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=False)
     results = tmp_dataset.evaluate(fake_inputs)
     assert results['0_n'] == 10
     assert results['1_n'] == 20
-    assert results['mean_n'] == 15
+    assert 'mean_n' not in results
 
     with pytest.raises(NotImplementedError):
         ds = UniformConcatDataset(datasets, separate_eval=False)
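The updated test assumes the wrapper averages each metric over the per-dataset results when mean reporting is enabled. Below is a minimal standalone sketch of that aggregation, written independently of mmocr with placeholder names; it is not the library's actual implementation, only an illustration of the key shapes asserted above.

def add_mean_scores(per_dataset_results):
    """Merge per-dataset result dicts the way the test expects: prefix each
    key with the dataset index and append a mean_{metric_name} entry."""
    merged = {}
    metric_names = list(per_dataset_results[0])
    for idx, res in enumerate(per_dataset_results):
        for name, value in res.items():
            merged[f'{idx}_{name}'] = value
    for name in metric_names:
        values = [res[name] for res in per_dataset_results]
        merged[f'mean_{name}'] = sum(values) / len(values)
    return merged


# Mirrors the two-dataset case above: n=10 and n=20 average to 15.
assert add_mean_scores([dict(n=10), dict(n=20)]) == {
    '0_n': 10, '1_n': 20, 'mean_n': 15.0
}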