Mirror of https://github.com/open-mmlab/mmclassification.git (synced 2025-06-03 21:53:55 +08:00)
[Feature] Add `evaluate` function for ConcatDataset. (#650)

* Add `evaluate` function for ConcatDataset.
* Remove newline in log.
* Fix lint.
* Specify mmcv version in Windows CI.
This commit is contained in:
parent 1214df083d
commit 1a28f9ace6

.github/workflows/build.yml (vendored): 2 lines changed
@@ -172,7 +172,7 @@ jobs:
       - name: Install MMCV & OpenCV
         run: |
           pip install opencv-python
-          pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
+          pip install mmcv-full==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
           python -c 'import mmcv; print(mmcv.__version__)'
       - name: Install mmcls dependencies
         run: |
@@ -59,7 +59,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         """Get all ground-truth labels (categories).

         Returns:
-            list[int]: categories for all images.
+            np.ndarray: categories for all images.
         """

         gt_labels = np.array([data['gt_label'] for data in self.data_infos])
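The hunk above only aligns the documented return type with what the label getter (in mmcls, `BaseDataset.get_gt_labels`) already returns. A quick sanity check, assuming `dataset` is any built BaseDataset subclass (the variable name is a placeholder, not from this commit):

import numpy as np

gt_labels = dataset.get_gt_labels()       # helper whose docstring is fixed above
assert isinstance(gt_labels, np.ndarray)  # an ndarray, not a list[int]
assert len(gt_labels) == len(dataset)     # one label per sample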
@@ -29,6 +29,10 @@ def build_dataset(cfg, default_args=None):
                                   KFoldDataset, RepeatDataset)
     if isinstance(cfg, (list, tuple)):
         dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+    elif cfg['type'] == 'ConcatDataset':
+        dataset = ConcatDataset(
+            [build_dataset(c, default_args) for c in cfg['datasets']],
+            separate_eval=cfg.get('separate_eval', True))
     elif cfg['type'] == 'RepeatDataset':
         dataset = RepeatDataset(
             build_dataset(cfg['dataset'], default_args), cfg['times'])
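With this builder change, a config can request a ConcatDataset by type instead of passing a bare list of dataset configs, and can control `separate_eval`. A minimal sketch of such a test-set config, assuming two hypothetical ImageNet-style splits (the paths, annotation files, and `test_pipeline` are placeholders, not part of this commit):

# Hypothetical config fragment; only the 'ConcatDataset' wrapper and the
# 'separate_eval' key come from this commit, the rest is illustrative.
# `test_pipeline` is assumed to be defined earlier in the config file.
data = dict(
    test=dict(
        type='ConcatDataset',
        datasets=[
            dict(type='ImageNet',
                 data_prefix='data/site_a/val',
                 ann_file='data/site_a/meta/val.txt',
                 pipeline=test_pipeline),
            dict(type='ImageNet',
                 data_prefix='data/site_b/val',
                 ann_file='data/site_b/meta/val.txt',
                 pipeline=test_pipeline),
        ],
        # True (default): report metrics per sub-dataset, prefixed by its index.
        # False: merge data_infos and evaluate everything as one dataset.
        separate_eval=True))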
@@ -4,6 +4,7 @@ import math
 from collections import defaultdict

 import numpy as np
+from mmcv.utils import print_log
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

 from .builder import DATASETS
@@ -18,12 +19,23 @@ class ConcatDataset(_ConcatDataset):

     Args:
         datasets (list[:obj:`Dataset`]): A list of datasets.
+        separate_eval (bool): Whether to evaluate the results separately
+            when it is used as a validation dataset.
+            Defaults to True.
     """

-    def __init__(self, datasets):
+    def __init__(self, datasets, separate_eval=True):
         super(ConcatDataset, self).__init__(datasets)
+        self.separate_eval = separate_eval
+
         self.CLASSES = datasets[0].CLASSES

+        if not separate_eval:
+            if len(set([type(ds) for ds in datasets])) != 1:
+                raise NotImplementedError(
+                    'To evaluate a concat dataset non-separately, '
+                    'all the datasets should have the same type')
+
     def get_cat_ids(self, idx):
         if idx < 0:
             if -idx > len(self):
@@ -37,6 +49,63 @@ class ConcatDataset(_ConcatDataset):
         sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
         return self.datasets[dataset_idx].get_cat_ids(sample_idx)

+    def evaluate(self, results, *args, indices=None, logger=None, **kwargs):
+        """Evaluate the results.
+
+        Args:
+            results (list[list | tuple]): Testing results of the dataset.
+            indices (list, optional): The indices of samples corresponding to
+                the results. It's unavailable on ConcatDataset.
+                Defaults to None.
+            logger (logging.Logger | str, optional): Logger used for printing
+                related information during evaluation. Defaults to None.
+
+        Returns:
+            dict[str, float]: Evaluation results of the total dataset, or of
+            each separate dataset if `self.separate_eval=True`.
+        """
+        if indices is not None:
+            raise NotImplementedError(
+                'Using indices to evaluate specific samples in a '
+                'ConcatDataset is not supported for now.')
+
+        assert len(results) == len(self), \
+            ('Dataset and results have different sizes: '
+             f'{len(self)} vs. {len(results)}')
+
+        # Check whether all the datasets support evaluation
+        for dataset in self.datasets:
+            assert hasattr(dataset, 'evaluate'), \
+                f"{type(dataset)} hasn't implemented the evaluate function."
+
+        if self.separate_eval:
+            total_eval_results = dict()
+            for dataset_idx, dataset in enumerate(self.datasets):
+                start_idx = 0 if dataset_idx == 0 else \
+                    self.cumulative_sizes[dataset_idx - 1]
+                end_idx = self.cumulative_sizes[dataset_idx]
+
+                results_per_dataset = results[start_idx:end_idx]
+                print_log(
+                    f'Evaluating dataset-{dataset_idx} with '
+                    f'{len(results_per_dataset)} images now',
+                    logger=logger)
+
+                eval_results_per_dataset = dataset.evaluate(
+                    results_per_dataset, *args, logger=logger, **kwargs)
+                for k, v in eval_results_per_dataset.items():
+                    total_eval_results.update({f'{dataset_idx}_{k}': v})
+
+            return total_eval_results
+        else:
+            original_data_infos = self.datasets[0].data_infos
+            self.datasets[0].data_infos = sum(
+                [dataset.data_infos for dataset in self.datasets], [])
+            eval_results = self.datasets[0].evaluate(
+                results, logger=logger, **kwargs)
+            self.datasets[0].data_infos = original_data_infos
+            return eval_results
+
+
 @DATASETS.register_module()
 class RepeatDataset(object):
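As a rough illustration of how the new method behaves (not part of the commit), the sketch below concatenates two built datasets and evaluates once over the combined results; with `separate_eval=True` each sub-dataset's metrics come back prefixed with its index. `ds_cfg_a`, `ds_cfg_b`, and `outputs` are placeholders:

from mmcls.datasets import ConcatDataset, build_dataset

# ds_cfg_a and ds_cfg_b are assumed to be ordinary dataset config dicts.
ds_a = build_dataset(ds_cfg_a, default_args=dict(test_mode=True))
ds_b = build_dataset(ds_cfg_b, default_args=dict(test_mode=True))

concat = ConcatDataset([ds_a, ds_b], separate_eval=True)

# `outputs` must hold one prediction per sample, ordered as ds_a then ds_b.
eval_results = concat.evaluate(outputs, metric='accuracy')
# With separate_eval=True, keys are prefixed by the sub-dataset index,
# e.g. {'0_accuracy_top-1': ..., '1_accuracy_top-1': ...} for ImageNet-style datasets.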
@@ -68,6 +137,20 @@ class RepeatDataset(object):
     def __len__(self):
         return self.times * self._ori_len

+    def evaluate(self, *args, **kwargs):
+        raise NotImplementedError(
+            'Evaluating results on a repeated dataset is not supported. '
+            'Please run inference and evaluation on the original dataset.')
+
+    def __repr__(self):
+        """Return a string with the dataset type and number of samples."""
+        dataset_type = 'Test' if self.test_mode else 'Train'
+        result = (
+            f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
+            f'{dataset_type} dataset with total number of samples {len(self)}.'
+        )
+        return result
+
+
 # Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57  # noqa
 @DATASETS.register_module()
@@ -171,6 +254,20 @@ class ClassBalancedDataset(object):
     def __len__(self):
         return len(self.repeat_indices)

+    def evaluate(self, *args, **kwargs):
+        raise NotImplementedError(
+            'Evaluating results on a class-balanced dataset is not supported. '
+            'Please run inference and evaluation on the original dataset.')
+
+    def __repr__(self):
+        """Return a string with the dataset type and number of samples."""
+        dataset_type = 'Test' if self.test_mode else 'Train'
+        result = (
+            f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
+            f'{dataset_type} dataset with total number of samples {len(self)}.'
+        )
+        return result
+
+
 @DATASETS.register_module()
 class KFoldDataset:
@@ -15,7 +15,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
 from mmcls.apis import multi_gpu_test, single_gpu_test
 from mmcls.datasets import build_dataloader, build_dataset
 from mmcls.models import build_classifier
-from mmcls.utils import setup_multi_processes
+from mmcls.utils import get_root_logger, setup_multi_processes


 def parse_args():
@@ -118,7 +118,6 @@ def main():
     if cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True
     cfg.model.pretrained = None
-    cfg.data.test.test_mode = True

     if args.gpu_ids is not None:
         cfg.gpu_ids = args.gpu_ids[0:1]
@@ -137,7 +136,7 @@ def main():
         init_dist(args.launcher, **cfg.dist_params)

     # build the dataloader
-    dataset = build_dataset(cfg.data.test)
+    dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
     # the extra round_up data will be removed during gpu/cpu collect
     data_loader = build_dataloader(
         dataset,
@@ -187,9 +186,13 @@ def main():
     rank, _ = get_dist_info()
     if rank == 0:
         results = {}
+        logger = get_root_logger()
         if args.metrics:
-            eval_results = dataset.evaluate(outputs, args.metrics,
-                                            args.metric_options)
+            eval_results = dataset.evaluate(
+                results=outputs,
+                metric=args.metrics,
+                metric_options=args.metric_options,
+                logger=logger)
             results.update(eval_results)
             for k, v in eval_results.items():
                 if isinstance(v, np.ndarray):
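Taken together, the tools/test.py changes push `test_mode=True` into `build_dataset` via `default_args` and switch the `evaluate` call to keyword arguments so a logger can be forwarded. A rough sketch of the resulting flow, assuming a loaded `cfg`, a list of inference `outputs`, and metric settings similar to the script's command-line arguments (all names outside this diff are placeholders):

from mmcls.datasets import build_dataset
from mmcls.utils import get_root_logger

# cfg, outputs, metrics and metric_options are assumed to exist already,
# e.g. produced earlier by a test script like tools/test.py.
dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))

logger = get_root_logger()
eval_results = dataset.evaluate(
    results=outputs,                 # one prediction per sample in the dataset
    metric=metrics,                  # e.g. ['accuracy']
    metric_options=metric_options,
    logger=logger)                   # per-dataset progress is printed via the logger
print(eval_results)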