[Feature] Add `evaluate` function for ConcatDataset. (#650)
* Add `evaluate` function for ConcatDataset * Remove newline in log. * Fix lint * Specify mmcv version in Windows CIpull/717/head
parent
1214df083d
commit
1a28f9ace6
|
@ -172,7 +172,7 @@ jobs:
|
|||
- name: Install MMCV & OpenCV
|
||||
run: |
|
||||
pip install opencv-python
|
||||
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
|
||||
pip install mmcv-full==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Install mmcls dependencies
|
||||
run: |
|
||||
|
|
|
@ -59,7 +59,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
"""Get all ground-truth labels (categories).
|
||||
|
||||
Returns:
|
||||
list[int]: categories for all images.
|
||||
np.ndarray: categories for all images.
|
||||
"""
|
||||
|
||||
gt_labels = np.array([data['gt_label'] for data in self.data_infos])
|
||||
|
|
|
@ -29,6 +29,10 @@ def build_dataset(cfg, default_args=None):
|
|||
KFoldDataset, RepeatDataset)
|
||||
if isinstance(cfg, (list, tuple)):
|
||||
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
|
||||
elif cfg['type'] == 'ConcatDataset':
|
||||
dataset = ConcatDataset(
|
||||
[build_dataset(c, default_args) for c in cfg['datasets']],
|
||||
separate_eval=cfg.get('separate_eval', True))
|
||||
elif cfg['type'] == 'RepeatDataset':
|
||||
dataset = RepeatDataset(
|
||||
build_dataset(cfg['dataset'], default_args), cfg['times'])
|
||||
|
|
|
@ -4,6 +4,7 @@ import math
|
|||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from mmcv.utils import print_log
|
||||
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
|
||||
|
||||
from .builder import DATASETS
|
||||
|
@ -18,12 +19,23 @@ class ConcatDataset(_ConcatDataset):
|
|||
|
||||
Args:
|
||||
datasets (list[:obj:`Dataset`]): A list of datasets.
|
||||
separate_eval (bool): Whether to evaluate the results
|
||||
separately if it is used as validation dataset.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, datasets):
|
||||
def __init__(self, datasets, separate_eval=True):
|
||||
super(ConcatDataset, self).__init__(datasets)
|
||||
self.separate_eval = separate_eval
|
||||
|
||||
self.CLASSES = datasets[0].CLASSES
|
||||
|
||||
if not separate_eval:
|
||||
if len(set([type(ds) for ds in datasets])) != 1:
|
||||
raise NotImplementedError(
|
||||
'To evaluate a concat dataset non-separately, '
|
||||
'all the datasets should have same types')
|
||||
|
||||
def get_cat_ids(self, idx):
|
||||
if idx < 0:
|
||||
if -idx > len(self):
|
||||
|
@ -37,6 +49,63 @@ class ConcatDataset(_ConcatDataset):
|
|||
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
||||
return self.datasets[dataset_idx].get_cat_ids(sample_idx)
|
||||
|
||||
def evaluate(self, results, *args, indices=None, logger=None, **kwargs):
|
||||
"""Evaluate the results.
|
||||
|
||||
Args:
|
||||
results (list[list | tuple]): Testing results of the dataset.
|
||||
indices (list, optional): The indices of samples corresponding to
|
||||
the results. It's unavailable on ConcatDataset.
|
||||
Defaults to None.
|
||||
logger (logging.Logger | str, optional): Logger used for printing
|
||||
related information during evaluation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict[str: float]: AP results of the total dataset or each separate
|
||||
dataset if `self.separate_eval=True`.
|
||||
"""
|
||||
if indices is not None:
|
||||
raise NotImplementedError(
|
||||
'Use indices to evaluate speific samples in a ConcatDataset '
|
||||
'is not supported by now.')
|
||||
|
||||
assert len(results) == len(self), \
|
||||
('Dataset and results have different sizes: '
|
||||
f'{len(self)} v.s. {len(results)}')
|
||||
|
||||
# Check whether all the datasets support evaluation
|
||||
for dataset in self.datasets:
|
||||
assert hasattr(dataset, 'evaluate'), \
|
||||
f"{type(dataset)} haven't implemented the evaluate function."
|
||||
|
||||
if self.separate_eval:
|
||||
total_eval_results = dict()
|
||||
for dataset_idx, dataset in enumerate(self.datasets):
|
||||
start_idx = 0 if dataset_idx == 0 else \
|
||||
self.cumulative_sizes[dataset_idx-1]
|
||||
end_idx = self.cumulative_sizes[dataset_idx]
|
||||
|
||||
results_per_dataset = results[start_idx:end_idx]
|
||||
print_log(
|
||||
f'Evaluateing dataset-{dataset_idx} with '
|
||||
f'{len(results_per_dataset)} images now',
|
||||
logger=logger)
|
||||
|
||||
eval_results_per_dataset = dataset.evaluate(
|
||||
results_per_dataset, *args, logger=logger, **kwargs)
|
||||
for k, v in eval_results_per_dataset.items():
|
||||
total_eval_results.update({f'{dataset_idx}_{k}': v})
|
||||
|
||||
return total_eval_results
|
||||
else:
|
||||
original_data_infos = self.datasets[0].data_infos
|
||||
self.datasets[0].data_infos = sum(
|
||||
[dataset.data_infos for dataset in self.datasets], [])
|
||||
eval_results = self.datasets[0].evaluate(
|
||||
results, logger=logger, **kwargs)
|
||||
self.datasets[0].data_infos = original_data_infos
|
||||
return eval_results
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class RepeatDataset(object):
|
||||
|
@ -68,6 +137,20 @@ class RepeatDataset(object):
|
|||
def __len__(self):
|
||||
return self.times * self._ori_len
|
||||
|
||||
def evaluate(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
'evaluate results on a repeated dataset is weird. '
|
||||
'Please inference and evaluate on the original dataset.')
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the number of instance number."""
|
||||
dataset_type = 'Test' if self.test_mode else 'Train'
|
||||
result = (
|
||||
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
|
||||
f'{dataset_type} dataset with total number of samples {len(self)}.'
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
|
||||
@DATASETS.register_module()
|
||||
|
@ -171,6 +254,20 @@ class ClassBalancedDataset(object):
|
|||
def __len__(self):
|
||||
return len(self.repeat_indices)
|
||||
|
||||
def evaluate(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
'evaluate results on a class-balanced dataset is weird. '
|
||||
'Please inference and evaluate on the original dataset.')
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the number of instance number."""
|
||||
dataset_type = 'Test' if self.test_mode else 'Train'
|
||||
result = (
|
||||
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
|
||||
f'{dataset_type} dataset with total number of samples {len(self)}.'
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class KFoldDataset:
|
||||
|
|
|
@ -15,7 +15,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
|
|||
from mmcls.apis import multi_gpu_test, single_gpu_test
|
||||
from mmcls.datasets import build_dataloader, build_dataset
|
||||
from mmcls.models import build_classifier
|
||||
from mmcls.utils import setup_multi_processes
|
||||
from mmcls.utils import get_root_logger, setup_multi_processes
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -118,7 +118,6 @@ def main():
|
|||
if cfg.get('cudnn_benchmark', False):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
cfg.model.pretrained = None
|
||||
cfg.data.test.test_mode = True
|
||||
|
||||
if args.gpu_ids is not None:
|
||||
cfg.gpu_ids = args.gpu_ids[0:1]
|
||||
|
@ -137,7 +136,7 @@ def main():
|
|||
init_dist(args.launcher, **cfg.dist_params)
|
||||
|
||||
# build the dataloader
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
|
||||
# the extra round_up data will be removed during gpu/cpu collect
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
|
@ -187,9 +186,13 @@ def main():
|
|||
rank, _ = get_dist_info()
|
||||
if rank == 0:
|
||||
results = {}
|
||||
logger = get_root_logger()
|
||||
if args.metrics:
|
||||
eval_results = dataset.evaluate(outputs, args.metrics,
|
||||
args.metric_options)
|
||||
eval_results = dataset.evaluate(
|
||||
results=outputs,
|
||||
metric=args.metrics,
|
||||
metric_options=args.metric_options,
|
||||
logger=logger)
|
||||
results.update(eval_results)
|
||||
for k, v in eval_results.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
|
|
Loading…
Reference in New Issue