[Feature] Add `evaluate` function for ConcatDataset. ()

* Add `evaluate` function for ConcatDataset

* Remove newline in log.

* Fix lint

* Specify mmcv version in Windows CI
Ma Zerun 2022-02-28 12:46:17 +08:00 committed by GitHub
parent 1214df083d
commit 1a28f9ace6
5 changed files with 112 additions and 8 deletions


@@ -172,7 +172,7 @@ jobs:
- name: Install MMCV & OpenCV
run: |
pip install opencv-python
-pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
+pip install mmcv-full==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
python -c 'import mmcv; print(mmcv.__version__)'
- name: Install mmcls dependencies
run: |


@@ -59,7 +59,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
"""Get all ground-truth labels (categories).
Returns:
-list[int]: categories for all images.
+np.ndarray: categories for all images.
"""
gt_labels = np.array([data['gt_label'] for data in self.data_infos])
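
A tiny self-contained sketch (not part of the diff) of why the return annotation changes: the labels are gathered into a NumPy array, not a Python list. The toy data_infos below is hypothetical.

import numpy as np

# Toy stand-in for BaseDataset.data_infos; 'gt_label' mirrors the field above.
data_infos = [{'gt_label': 0}, {'gt_label': 2}, {'gt_label': 1}]
gt_labels = np.array([data['gt_label'] for data in data_infos])
print(type(gt_labels))  # <class 'numpy.ndarray'>, hence the docstring fix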


@@ -29,6 +29,10 @@ def build_dataset(cfg, default_args=None):
KFoldDataset, RepeatDataset)
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'ConcatDataset':
dataset = ConcatDataset(
[build_dataset(c, default_args) for c in cfg['datasets']],
separate_eval=cfg.get('separate_eval', True))
elif cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(
build_dataset(cfg['dataset'], default_args), cfg['times'])
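
For reference, a hedged config sketch (not part of the diff) of how the new ConcatDataset branch above could be selected. The dataset types, data prefixes, and empty pipelines are placeholders, and separate_eval may be omitted to fall back to True.

# Hypothetical test-set config exercising the ConcatDataset branch above.
test_dataset_cfg = dict(
    type='ConcatDataset',
    datasets=[
        dict(type='ImageNet', data_prefix='data/val_part1', pipeline=[]),
        dict(type='ImageNet', data_prefix='data/val_part2', pipeline=[]),
    ],
    separate_eval=True)  # optional; defaults to True

# build_dataset(test_dataset_cfg) would then return a ConcatDataset whose
# evaluate() reports per-sub-dataset metrics with keys prefixed '0_', '1_', ...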


@@ -4,6 +4,7 @@ import math
from collections import defaultdict
import numpy as np
+from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from .builder import DATASETS
@@ -18,12 +19,23 @@ class ConcatDataset(_ConcatDataset):
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
separate_eval (bool): Whether to evaluate the results
separately if it is used as a validation dataset.
Defaults to True.
"""
-def __init__(self, datasets):
+def __init__(self, datasets, separate_eval=True):
super(ConcatDataset, self).__init__(datasets)
self.separate_eval = separate_eval
self.CLASSES = datasets[0].CLASSES
if not separate_eval:
if len(set([type(ds) for ds in datasets])) != 1:
raise NotImplementedError(
'To evaluate a concat dataset non-separately, '
'all the datasets should have the same type')
def get_cat_ids(self, idx):
if idx < 0:
if -idx > len(self):
@@ -37,6 +49,63 @@ class ConcatDataset(_ConcatDataset):
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx].get_cat_ids(sample_idx)
def evaluate(self, results, *args, indices=None, logger=None, **kwargs):
"""Evaluate the results.
Args:
results (list[list | tuple]): Testing results of the dataset.
indices (list, optional): The indices of samples corresponding to
the results. Not supported by ConcatDataset.
Defaults to None.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Defaults to None.
Returns:
dict[str, float]: Evaluation results of the total dataset, or of each
separate dataset if `self.separate_eval=True`.
"""
if indices is not None:
raise NotImplementedError(
'Using indices to evaluate specific samples in a ConcatDataset '
'is not supported for now.')
assert len(results) == len(self), \
('Dataset and results have different sizes: '
f'{len(self)} vs. {len(results)}')
# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f"{type(dataset)} haven't implemented the evaluate function."
if self.separate_eval:
total_eval_results = dict()
for dataset_idx, dataset in enumerate(self.datasets):
start_idx = 0 if dataset_idx == 0 else \
self.cumulative_sizes[dataset_idx-1]
end_idx = self.cumulative_sizes[dataset_idx]
results_per_dataset = results[start_idx:end_idx]
print_log(
f'Evaluating dataset-{dataset_idx} with '
f'{len(results_per_dataset)} images now',
logger=logger)
eval_results_per_dataset = dataset.evaluate(
results_per_dataset, *args, logger=logger, **kwargs)
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})
return total_eval_results
else:
original_data_infos = self.datasets[0].data_infos
self.datasets[0].data_infos = sum(
[dataset.data_infos for dataset in self.datasets], [])
eval_results = self.datasets[0].evaluate(
results, logger=logger, **kwargs)
self.datasets[0].data_infos = original_data_infos
return eval_results
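
As a reading aid (not part of the diff), here is a minimal self-contained sketch of the separate_eval=True path: results are sliced per sub-dataset using the cumulative sizes, and each metric key is prefixed with the sub-dataset index. The sizes, per-sample results, and the mean "metric" below are all stand-ins.

from itertools import accumulate

# Two hypothetical sub-datasets with 3 and 2 samples, one result per sample.
dataset_sizes = [3, 2]
cumulative_sizes = list(accumulate(dataset_sizes))  # [3, 5]
results = [1.0, 0.5, 0.75, 0.25, 0.75]

total_eval_results = {}
for dataset_idx, end_idx in enumerate(cumulative_sizes):
    start_idx = 0 if dataset_idx == 0 else cumulative_sizes[dataset_idx - 1]
    results_per_dataset = results[start_idx:end_idx]
    # Stand-in metric (mean) in place of each sub-dataset's own evaluate().
    total_eval_results[f'{dataset_idx}_mean'] = (
        sum(results_per_dataset) / len(results_per_dataset))

print(total_eval_results)  # {'0_mean': 0.75, '1_mean': 0.5}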
@DATASETS.register_module()
class RepeatDataset(object):
@@ -68,6 +137,20 @@ class RepeatDataset(object):
def __len__(self):
return self.times * self._ori_len
def evaluate(self, *args, **kwargs):
raise NotImplementedError(
'Evaluating results on a repeated dataset is not meaningful. '
'Please run inference and evaluation on the original dataset.')
def __repr__(self):
"""Print the number of instance number."""
dataset_type = 'Test' if self.dataset.test_mode else 'Train'
result = (
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
f'{dataset_type} dataset with total number of samples {len(self)}.'
)
return result
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@DATASETS.register_module()
@@ -171,6 +254,20 @@ class ClassBalancedDataset(object):
def __len__(self):
return len(self.repeat_indices)
def evaluate(self, *args, **kwargs):
raise NotImplementedError(
'Evaluating results on a class-balanced dataset is not meaningful. '
'Please run inference and evaluation on the original dataset.')
def __repr__(self):
"""Print the number of instance number."""
dataset_type = 'Test' if self.dataset.test_mode else 'Train'
result = (
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
f'{dataset_type} dataset with total number of samples {len(self)}.'
)
return result
@DATASETS.register_module()
class KFoldDataset:


@@ -15,7 +15,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
from mmcls.apis import multi_gpu_test, single_gpu_test
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.models import build_classifier
-from mmcls.utils import setup_multi_processes
+from mmcls.utils import get_root_logger, setup_multi_processes
def parse_args():
@@ -118,7 +118,6 @@ def main():
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
cfg.model.pretrained = None
-cfg.data.test.test_mode = True
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids[0:1]
@@ -137,7 +136,7 @@ def main():
init_dist(args.launcher, **cfg.dist_params)
# build the dataloader
-dataset = build_dataset(cfg.data.test)
+dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
# the extra round_up data will be removed during gpu/cpu collect
data_loader = build_dataloader(
dataset,
@@ -187,9 +186,13 @@ def main():
rank, _ = get_dist_info()
if rank == 0:
results = {}
+logger = get_root_logger()
if args.metrics:
-eval_results = dataset.evaluate(outputs, args.metrics,
-args.metric_options)
+eval_results = dataset.evaluate(
+results=outputs,
+metric=args.metrics,
+metric_options=args.metric_options,
+logger=logger)
results.update(eval_results)
for k, v in eval_results.items():
if isinstance(v, np.ndarray):
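
For context (not part of the diff), a hedged sketch of the updated evaluation step, wrapped in a helper so it stays self-contained; cfg, outputs, and args are assumed to come from earlier in tools/test.py, and the keyword names mirror the call above.

from mmcls.datasets import build_dataset
from mmcls.utils import get_root_logger


def evaluate_outputs(cfg, outputs, args):
    # test_mode is now injected via default_args instead of mutating
    # cfg.data.test in place (see the removed line above).
    dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
    logger = get_root_logger()
    # On a ConcatDataset with separate_eval=True, the returned keys are
    # prefixed with the sub-dataset index (e.g. '0_...', '1_...').
    return dataset.evaluate(
        results=outputs,
        metric=args.metrics,
        metric_options=args.metric_options,
        logger=logger)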