diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index b266d9f48..062bd9295 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -172,7 +172,7 @@ jobs:
       - name: Install MMCV & OpenCV
         run: |
           pip install opencv-python
-          pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
+          pip install mmcv-full==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
           python -c 'import mmcv; print(mmcv.__version__)'
       - name: Install mmcls dependencies
         run: |
diff --git a/mmcls/datasets/base_dataset.py b/mmcls/datasets/base_dataset.py
index 7a2f31092..7924b4065 100644
--- a/mmcls/datasets/base_dataset.py
+++ b/mmcls/datasets/base_dataset.py
@@ -59,7 +59,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         """Get all ground-truth labels (categories).

         Returns:
-            list[int]: categories for all images.
+            np.ndarray: categories for all images.
         """

         gt_labels = np.array([data['gt_label'] for data in self.data_infos])
diff --git a/mmcls/datasets/builder.py b/mmcls/datasets/builder.py
index bc07bb150..cf9345414 100644
--- a/mmcls/datasets/builder.py
+++ b/mmcls/datasets/builder.py
@@ -29,6 +29,10 @@ def build_dataset(cfg, default_args=None):
                                    KFoldDataset, RepeatDataset)
     if isinstance(cfg, (list, tuple)):
         dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+    elif cfg['type'] == 'ConcatDataset':
+        dataset = ConcatDataset(
+            [build_dataset(c, default_args) for c in cfg['datasets']],
+            separate_eval=cfg.get('separate_eval', True))
     elif cfg['type'] == 'RepeatDataset':
         dataset = RepeatDataset(
             build_dataset(cfg['dataset'], default_args), cfg['times'])
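With this new `build_dataset` branch, a concatenation can be declared directly in the config via `type='ConcatDataset'` instead of passing a list of dataset configs. A minimal sketch of such a config (the dataset type, data prefixes, and empty pipeline are hypothetical placeholders, not part of this change):

```python
# Hypothetical config snippet exercising the new 'ConcatDataset' branch of
# build_dataset. Dataset type, data prefixes, and pipeline are placeholders.
data = dict(
    val=dict(
        type='ConcatDataset',
        datasets=[
            dict(type='ImageNet', data_prefix='data/val_part1', pipeline=[]),
            dict(type='ImageNet', data_prefix='data/val_part2', pipeline=[]),
        ],
        # Evaluate each sub-dataset separately (the default). Set to False
        # to pool all samples and evaluate them as a single dataset.
        separate_eval=True))
```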
diff --git a/mmcls/datasets/dataset_wrappers.py b/mmcls/datasets/dataset_wrappers.py
index 745c8f149..6aef65638 100644
--- a/mmcls/datasets/dataset_wrappers.py
+++ b/mmcls/datasets/dataset_wrappers.py
@@ -4,6 +4,7 @@ import math
 from collections import defaultdict

 import numpy as np
+from mmcv.utils import print_log
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

 from .builder import DATASETS
@@ -18,12 +19,23 @@ class ConcatDataset(_ConcatDataset):

     Args:
         datasets (list[:obj:`Dataset`]): A list of datasets.
+        separate_eval (bool): Whether to evaluate the results
+            separately if it is used as validation dataset.
+            Defaults to True.
     """

-    def __init__(self, datasets):
+    def __init__(self, datasets, separate_eval=True):
         super(ConcatDataset, self).__init__(datasets)
+        self.separate_eval = separate_eval
+
         self.CLASSES = datasets[0].CLASSES

+        if not separate_eval:
+            if len(set([type(ds) for ds in datasets])) != 1:
+                raise NotImplementedError(
+                    'To evaluate a concat dataset non-separately, '
+                    'all the datasets should have the same type.')
+
     def get_cat_ids(self, idx):
         if idx < 0:
             if -idx > len(self):
@@ -37,6 +49,63 @@ class ConcatDataset(_ConcatDataset):
             sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
         return self.datasets[dataset_idx].get_cat_ids(sample_idx)

+    def evaluate(self, results, *args, indices=None, logger=None, **kwargs):
+        """Evaluate the results.
+
+        Args:
+            results (list[list | tuple]): Testing results of the dataset.
+            indices (list, optional): The indices of samples corresponding to
+                the results. It's unavailable on ConcatDataset.
+                Defaults to None.
+            logger (logging.Logger | str, optional): Logger used for printing
+                related information during evaluation. Defaults to None.
+
+        Returns:
+            dict[str, float]: Evaluation results of the total dataset, or of
+                each separate dataset if `self.separate_eval=True`.
+        """
+        if indices is not None:
+            raise NotImplementedError(
+                'Using indices to evaluate specific samples in a '
+                'ConcatDataset is not supported for now.')
+
+        assert len(results) == len(self), \
+            ('Dataset and results have different sizes: '
+             f'{len(self)} vs. {len(results)}')
+
+        # Check whether all the datasets support evaluation
+        for dataset in self.datasets:
+            assert hasattr(dataset, 'evaluate'), \
+                f"{type(dataset)} hasn't implemented the evaluate function."
+
+        if self.separate_eval:
+            total_eval_results = dict()
+            for dataset_idx, dataset in enumerate(self.datasets):
+                start_idx = 0 if dataset_idx == 0 else \
+                    self.cumulative_sizes[dataset_idx - 1]
+                end_idx = self.cumulative_sizes[dataset_idx]
+
+                results_per_dataset = results[start_idx:end_idx]
+                print_log(
+                    f'Evaluating dataset-{dataset_idx} with '
+                    f'{len(results_per_dataset)} images now',
+                    logger=logger)
+
+                eval_results_per_dataset = dataset.evaluate(
+                    results_per_dataset, *args, logger=logger, **kwargs)
+                for k, v in eval_results_per_dataset.items():
+                    total_eval_results.update({f'{dataset_idx}_{k}': v})
+
+            return total_eval_results
+        else:
+            original_data_infos = self.datasets[0].data_infos
+            self.datasets[0].data_infos = sum(
+                [dataset.data_infos for dataset in self.datasets], [])
+            eval_results = self.datasets[0].evaluate(
+                results, logger=logger, **kwargs)
+            self.datasets[0].data_infos = original_data_infos
+            return eval_results
+

 @DATASETS.register_module()
 class RepeatDataset(object):
@@ -68,6 +137,20 @@ class RepeatDataset(object):
     def __len__(self):
         return self.times * self._ori_len

+    def evaluate(self, *args, **kwargs):
+        raise NotImplementedError(
+            'Evaluating results on a repeated dataset is weird. '
+            'Please run inference and evaluation on the original dataset.')
+
+    def __repr__(self):
+        """Print the number of samples."""
+        dataset_type = 'Test' if self.test_mode else 'Train'
+        result = (
+            f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
+            f'{dataset_type} dataset with total number of samples {len(self)}.'
+        )
+        return result
+

 # Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57  # noqa
 @DATASETS.register_module()
@@ -171,6 +254,20 @@ class ClassBalancedDataset(object):
     def __len__(self):
         return len(self.repeat_indices)

+    def evaluate(self, *args, **kwargs):
+        raise NotImplementedError(
+            'Evaluating results on a class-balanced dataset is weird. '
+            'Please run inference and evaluation on the original dataset.')
+
+    def __repr__(self):
+        """Print the number of samples."""
+        dataset_type = 'Test' if self.test_mode else 'Train'
+        result = (
+            f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
+            f'{dataset_type} dataset with total number of samples {len(self)}.'
+        )
+        return result
+

 @DATASETS.register_module()
 class KFoldDataset:
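When `separate_eval=True`, each sub-dataset is evaluated on its own slice of `results` (split at `cumulative_sizes`) and its metrics come back with the sub-dataset index as a key prefix. A rough caller-side sketch, assuming two sub-datasets that both report `accuracy` (the variable names and metric values are illustrative):

```python
# Suppose concat_ds wraps two datasets of 100 samples each; results[0:100]
# is evaluated by dataset 0 and results[100:200] by dataset 1.
eval_results = concat_ds.evaluate(results, metric='accuracy')
# Keys are prefixed with the sub-dataset index, e.g.:
# {'0_accuracy': 76.2, '1_accuracy': 74.8}
```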
diff --git a/tools/test.py b/tools/test.py
index da8c12f37..c294c5e56 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -15,7 +15,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
 from mmcls.apis import multi_gpu_test, single_gpu_test
 from mmcls.datasets import build_dataloader, build_dataset
 from mmcls.models import build_classifier
-from mmcls.utils import setup_multi_processes
+from mmcls.utils import get_root_logger, setup_multi_processes


 def parse_args():
@@ -118,7 +118,6 @@ def main():
     if cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True
     cfg.model.pretrained = None
-    cfg.data.test.test_mode = True

     if args.gpu_ids is not None:
         cfg.gpu_ids = args.gpu_ids[0:1]
@@ -137,7 +136,7 @@ def main():
         init_dist(args.launcher, **cfg.dist_params)

     # build the dataloader
-    dataset = build_dataset(cfg.data.test)
+    dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
     # the extra round_up data will be removed during gpu/cpu collect
     data_loader = build_dataloader(
         dataset,
@@ -187,9 +186,13 @@ def main():
     rank, _ = get_dist_info()
     if rank == 0:
         results = {}
+        logger = get_root_logger()
         if args.metrics:
-            eval_results = dataset.evaluate(outputs, args.metrics,
-                                            args.metric_options)
+            eval_results = dataset.evaluate(
+                results=outputs,
+                metric=args.metrics,
+                metric_options=args.metric_options,
+                logger=logger)
             results.update(eval_results)
             for k, v in eval_results.items():
                 if isinstance(v, np.ndarray):
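The `tools/test.py` change stops mutating `cfg.data.test.test_mode` in place and instead threads `test_mode=True` through `default_args`, so `build_dataset` can forward it recursively to every sub-dataset of a wrapper such as `ConcatDataset`. A condensed before/after sketch of the intent:

```python
# Before: only the top-level test config was flagged, which a
# ConcatDataset config would not propagate to the configs in 'datasets'.
cfg.data.test.test_mode = True
dataset = build_dataset(cfg.data.test)

# After: test_mode travels via default_args, so the wrapper branches in
# build_dataset pass it down when they recurse into each child config.
dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
```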