[Enhancement] Automatically report mean scores when applicable (#995)

Authored by Tong Gao on 2022-05-05 12:57:19 +08:00, committed by GitHub
parent 28c9e460d5
commit 6b180db93d
2 changed files with 28 additions and 9 deletions

@@ -19,10 +19,12 @@ class UniformConcatDataset(ConcatDataset):
         separate_eval (bool): Whether to evaluate the results
             separately if it is used as validation dataset.
             Defaults to True.
-        show_mean_scores (bool): Whether to compute the mean evaluation
-            results, only applicable when ``separate_eval=True``. If ``True``,
-            mean results will be added to the result dictionary with keys in
-            the form of ``mean_{metric_name}``.
+        show_mean_scores (str | bool): Whether to compute the mean evaluation
+            results, only applicable when ``separate_eval=True``. Options are
+            [True, False, ``auto``]. If ``True``, mean results will be added to
+            the result dictionary with keys in the form of
+            ``mean_{metric_name}``. If 'auto', mean results will be shown only
+            when more than 1 dataset is wrapped.
         pipeline (None | list[dict] | list[list[dict]]): If ``None``,
             each dataset in datasets use its own pipeline;
             If ``list[dict]``, it will be assigned to the dataset whose
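
A usage sketch (not part of this diff) of the option documented above, assuming the MMCV-style config dicts used throughout MMOCR; dataset_a, dataset_b and val_pipeline are hypothetical placeholders:

    # Hypothetical validation config wrapping two dataset cfgs. With the new
    # default show_mean_scores='auto' and more than one dataset, mean_* keys
    # are reported alongside the per-dataset 0_*/1_* keys.
    val = dict(
        type='UniformConcatDataset',
        datasets=[dataset_a, dataset_b],   # placeholder dataset cfgs
        separate_eval=True,
        show_mean_scores='auto',
        pipeline=val_pipeline)             # placeholder shared pipeline
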
@@ -37,7 +39,7 @@ class UniformConcatDataset(ConcatDataset):
     def __init__(self,
                  datasets,
                  separate_eval=True,
-                 show_mean_scores=False,
+                 show_mean_scores='auto',
                  pipeline=None,
                  force_apply=False,
                  **kwargs):
@@ -70,8 +72,12 @@ class UniformConcatDataset(ConcatDataset):
                 'Evaluating datasets as a whole is not'
                 ' supported yet. Please use "separate_eval=True"')
+        assert isinstance(show_mean_scores, bool) or show_mean_scores == 'auto'
+        if show_mean_scores == 'auto':
+            show_mean_scores = len(self.datasets) > 1
         self.show_mean_scores = show_mean_scores
-        if show_mean_scores:
+        if show_mean_scores is True or show_mean_scores == 'auto' and len(
+                self.datasets) > 1:
             if len(set([type(ds) for ds in self.datasets])) != 1:
                 raise NotImplementedError(
                     'To compute mean evaluation scores, all datasets'
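
For reference, a standalone sketch of the 'auto' resolution introduced above; resolve_show_mean_scores is a hypothetical helper that mirrors the added constructor lines, not the library code itself:

    def resolve_show_mean_scores(show_mean_scores, num_datasets):
        # 'auto' turns into True only when more than one dataset is wrapped;
        # plain booleans pass through unchanged.
        assert isinstance(show_mean_scores, bool) or show_mean_scores == 'auto'
        if show_mean_scores == 'auto':
            show_mean_scores = num_datasets > 1
        return show_mean_scores

    assert resolve_show_mean_scores('auto', 1) is False
    assert resolve_show_mean_scores('auto', 2) is True
    assert resolve_show_mean_scores(True, 1) is True
    assert resolve_show_mean_scores(False, 3) is False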

@@ -78,19 +78,32 @@ def test_uniform_concat_dataset_eval():
         def evaluate(self, res, logger, **kwargs):
             return dict(n=res[0])
+    # Test 'auto'
+    fake_inputs = [10]
+    datasets = [dict(type='DummyDataset')]
+    tmp_dataset = UniformConcatDataset(datasets)
+    results = tmp_dataset.evaluate(fake_inputs)
+    assert results['0_n'] == 10
+    assert 'mean_n' not in results
+    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=True)
+    results = tmp_dataset.evaluate(fake_inputs)
+    assert results['mean_n'] == 10
     fake_inputs = [10, 20]
     datasets = [dict(type='DummyDataset'), dict(type='DummyDataset')]
     tmp_dataset = UniformConcatDataset(datasets)
     results = tmp_dataset.evaluate(fake_inputs)
     assert results['0_n'] == 10
     assert results['1_n'] == 20
+    assert results['mean_n'] == 15
-    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=True)
+    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=False)
     results = tmp_dataset.evaluate(fake_inputs)
     assert results['0_n'] == 10
     assert results['1_n'] == 20
-    assert results['mean_n'] == 15
+    assert 'mean_n' not in results
     with pytest.raises(NotImplementedError):
         ds = UniformConcatDataset(datasets, separate_eval=False)
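
The keys asserted in the updated test follow the {index}_{metric} / mean_{metric} scheme described in the docstring. Below is a small sketch of how such a merged result dict can be produced, independent of the test harness; per_dataset is a hypothetical list of per-dataset metric dicts, not the wrapper's actual implementation:

    per_dataset = [{'n': 10}, {'n': 20}]  # hypothetical per-dataset eval results

    # Prefix each metric with its dataset index, then append mean_* entries.
    results = {
        f'{i}_{name}': value
        for i, res in enumerate(per_dataset)
        for name, value in res.items()
    }
    for name in per_dataset[0]:
        results[f'mean_{name}'] = (
            sum(res[name] for res in per_dataset) / len(per_dataset))

    print(results)  # {'0_n': 10, '1_n': 20, 'mean_n': 15.0}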