[Enhancement] Automatically report mean scores when applicable (#995)
commit 6b180db93d
parent 28c9e460d5
@@ -19,10 +19,12 @@ class UniformConcatDataset(ConcatDataset):
         separate_eval (bool): Whether to evaluate the results
             separately if it is used as validation dataset.
             Defaults to True.
-        show_mean_scores (bool): Whether to compute the mean evaluation
-            results, only applicable when ``separate_eval=True``. If ``True``,
-            mean results will be added to the result dictionary with keys in
-            the form of ``mean_{metric_name}``.
+        show_mean_scores (str | bool): Whether to compute the mean evaluation
+            results, only applicable when ``separate_eval=True``. Options are
+            [True, False, ``auto``]. If ``True``, mean results will be added to
+            the result dictionary with keys in the form of
+            ``mean_{metric_name}``. If 'auto', mean results will be shown only
+            when more than 1 dataset is wrapped.
         pipeline (None | list[dict] | list[list[dict]]): If ``None``,
             each dataset in datasets use its own pipeline;
             If ``list[dict]``, it will be assigned to the dataset whose
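
As a rough usage sketch of the documented option (the dataset entries below are placeholders and not part of this commit; only show_mean_scores comes from the change), a validation config wrapping two datasets could look like:

# Hypothetical dataset configs; type/ann_file values are illustrative only.
val1 = dict(type='IcdarDataset', ann_file='instances_val1.json', pipeline=[])
val2 = dict(type='IcdarDataset', ann_file='instances_val2.json', pipeline=[])

val = dict(
    type='UniformConcatDataset',
    datasets=[val1, val2],
    separate_eval=True,
    # New default 'auto': mean_{metric_name} keys are reported here because
    # more than one dataset is wrapped; True forces them, False hides them.
    show_mean_scores='auto',
    pipeline=None)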
@@ -37,7 +39,7 @@ class UniformConcatDataset(ConcatDataset):
     def __init__(self,
                  datasets,
                  separate_eval=True,
-                 show_mean_scores=False,
+                 show_mean_scores='auto',
                  pipeline=None,
                  force_apply=False,
                  **kwargs):
@@ -70,8 +72,12 @@ class UniformConcatDataset(ConcatDataset):
                 'Evaluating datasets as a whole is not'
                 ' supported yet. Please use "separate_eval=True"')

+        assert isinstance(show_mean_scores, bool) or show_mean_scores == 'auto'
+        if show_mean_scores == 'auto':
+            show_mean_scores = len(self.datasets) > 1
         self.show_mean_scores = show_mean_scores
-        if show_mean_scores:
+        if show_mean_scores is True or show_mean_scores == 'auto' and len(
+                self.datasets) > 1:
             if len(set([type(ds) for ds in self.datasets])) != 1:
                 raise NotImplementedError(
                     'To compute mean evaluation scores, all datasets'
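
This hunk only resolves the flag in __init__; the averaging itself happens in evaluate(), which is outside this diff. A minimal sketch of that behaviour (the helper name and signature are illustrative, not the actual mmocr implementation), consistent with the key format the docstring and the tests below expect:

from collections import defaultdict


def _add_mean_scores(results_per_dataset):
    """Illustrative only: prefix each per-dataset metric with its dataset
    index and add averaged ``mean_{metric_name}`` entries."""
    merged = {}
    sums = defaultdict(float)
    for idx, res in enumerate(results_per_dataset):
        for name, value in res.items():
            merged[f'{idx}_{name}'] = value
            sums[name] += value
    for name, total in sums.items():
        merged[f'mean_{name}'] = total / len(results_per_dataset)
    return merged


# _add_mean_scores([dict(n=10), dict(n=20)]) -> {'0_n': 10, '1_n': 20, 'mean_n': 15.0}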
@@ -78,19 +78,32 @@ def test_uniform_concat_dataset_eval():
         def evaluate(self, res, logger, **kwargs):
             return dict(n=res[0])

+    # Test 'auto'
+    fake_inputs = [10]
+    datasets = [dict(type='DummyDataset')]
+    tmp_dataset = UniformConcatDataset(datasets)
+    results = tmp_dataset.evaluate(fake_inputs)
+    assert results['0_n'] == 10
+    assert 'mean_n' not in results
+
+    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=True)
+    results = tmp_dataset.evaluate(fake_inputs)
+    assert results['mean_n'] == 10
+
     fake_inputs = [10, 20]
     datasets = [dict(type='DummyDataset'), dict(type='DummyDataset')]
+
     tmp_dataset = UniformConcatDataset(datasets)
     results = tmp_dataset.evaluate(fake_inputs)
     assert results['0_n'] == 10
     assert results['1_n'] == 20
+    assert results['mean_n'] == 15

-    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=True)
+    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=False)
     results = tmp_dataset.evaluate(fake_inputs)
     assert results['0_n'] == 10
     assert results['1_n'] == 20
-    assert results['mean_n'] == 15
+    assert 'mean_n' not in results

     with pytest.raises(NotImplementedError):
         ds = UniformConcatDataset(datasets, separate_eval=False)