diff --git a/mmocr/datasets/uniform_concat_dataset.py b/mmocr/datasets/uniform_concat_dataset.py
index 92a88752..9fdc6063 100644
--- a/mmocr/datasets/uniform_concat_dataset.py
+++ b/mmocr/datasets/uniform_concat_dataset.py
@@ -19,10 +19,12 @@ class UniformConcatDataset(ConcatDataset):
         separate_eval (bool): Whether to evaluate the results
             separately if it is used as validation dataset.
             Defaults to True.
-        show_mean_scores (bool): Whether to compute the mean evaluation
-            results, only applicable when ``separate_eval=True``. If ``True``,
-            mean results will be added to the result dictionary with keys in
-            the form of ``mean_{metric_name}``.
+        show_mean_scores (str | bool): Whether to compute the mean evaluation
+            results, only applicable when ``separate_eval=True``. Options are
+            [True, False, ``'auto'``]. If ``True``, mean results will be added
+            to the result dictionary with keys in the form of
+            ``mean_{metric_name}``. If ``'auto'`` (default), mean results
+            will be shown only when more than one dataset is wrapped.
         pipeline (None | list[dict] | list[list[dict]]): If ``None``,
             each dataset in datasets use its own pipeline;
             If ``list[dict]``, it will be assigned to the dataset whose
@@ -37,7 +39,7 @@ class UniformConcatDataset(ConcatDataset):
     def __init__(self,
                  datasets,
                  separate_eval=True,
-                 show_mean_scores=False,
+                 show_mean_scores='auto',
                  pipeline=None,
                  force_apply=False,
                  **kwargs):
@@ -70,8 +72,11 @@ class UniformConcatDataset(ConcatDataset):
                 'Evaluating datasets as a whole is not'
                 ' supported yet. Please use "separate_eval=True"')
 
+        assert isinstance(show_mean_scores, bool) or show_mean_scores == 'auto'
+        if show_mean_scores == 'auto':
+            show_mean_scores = len(self.datasets) > 1
         self.show_mean_scores = show_mean_scores
         if show_mean_scores:
             if len(set([type(ds) for ds in self.datasets])) != 1:
                 raise NotImplementedError(
                     'To compute mean evaluation scores, all datasets'
diff --git a/tests/test_dataset/test_uniform_concat_dataset.py b/tests/test_dataset/test_uniform_concat_dataset.py
index f45f44ea..7f001295 100644
--- a/tests/test_dataset/test_uniform_concat_dataset.py
+++ b/tests/test_dataset/test_uniform_concat_dataset.py
@@ -78,19 +78,32 @@ def test_uniform_concat_dataset_eval():
         def evaluate(self, res, logger, **kwargs):
             return dict(n=res[0])
 
+    # Test 'auto'
+    fake_inputs = [10]
+    datasets = [dict(type='DummyDataset')]
+    tmp_dataset = UniformConcatDataset(datasets)
+    results = tmp_dataset.evaluate(fake_inputs)
+    assert results['0_n'] == 10
+    assert 'mean_n' not in results
+
+    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=True)
+    results = tmp_dataset.evaluate(fake_inputs)
+    assert results['mean_n'] == 10
+
     fake_inputs = [10, 20]
     datasets = [dict(type='DummyDataset'), dict(type='DummyDataset')]
 
     tmp_dataset = UniformConcatDataset(datasets)
     results = tmp_dataset.evaluate(fake_inputs)
     assert results['0_n'] == 10
     assert results['1_n'] == 20
+    assert results['mean_n'] == 15
 
-    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=True)
+    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=False)
     results = tmp_dataset.evaluate(fake_inputs)
     assert results['0_n'] == 10
     assert results['1_n'] == 20
-    assert results['mean_n'] == 15
+    assert 'mean_n' not in results
 
     with pytest.raises(NotImplementedError):
         ds = UniformConcatDataset(datasets, separate_eval=False)
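
Below is a quick usage sketch of the changed option, not part of the diff itself. It mirrors the `DummyDataset` stub registered in the test above; the `StubDataset` name and the import paths are assumptions based on the mmocr 0.x layout (where the dataset registry comes from mmdet) and may differ between versions.

```python
# Usage sketch for show_mean_scores: illustrative only, not part of the PR.
# StubDataset and the import paths below are assumptions (mmocr 0.x layout).
from mmdet.datasets import DATASETS
from torch.utils.data import Dataset

from mmocr.datasets import UniformConcatDataset


@DATASETS.register_module()
class StubDataset(Dataset):
    """Single-sample dataset whose evaluate() reports one metric, 'n'."""

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return index

    def evaluate(self, res, logger=None, **kwargs):
        return dict(n=res[0])


datasets = [dict(type='StubDataset'), dict(type='StubDataset')]

# New default 'auto': more than one dataset is wrapped, so the mean appears.
ds = UniformConcatDataset(datasets)
results = ds.evaluate([10, 20])
assert results['mean_n'] == 15  # mean of 0_n == 10 and 1_n == 20

# Explicit False suppresses the mean even with multiple datasets.
ds = UniformConcatDataset(datasets, show_mean_scores=False)
assert 'mean_n' not in ds.evaluate([10, 20])

# Explicit True forces the mean even for a single wrapped dataset.
ds = UniformConcatDataset([dict(type='StubDataset')], show_mean_scores=True)
assert ds.evaluate([10])['mean_n'] == 10
```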