diff --git a/docs/en/api/evaluation.rst b/docs/en/api/evaluation.rst
index 80f9c006..0c0d1feb 100644
--- a/docs/en/api/evaluation.rst
+++ b/docs/en/api/evaluation.rst
@@ -23,6 +23,7 @@ Single Label Metric
    Accuracy
    SingleLabelMetric
    ConfusionMatrix
+   CorruptionError

 Multi Label Metric
 ----------------------
diff --git a/docs/en/user_guides/finetune.md b/docs/en/user_guides/finetune.md
index 967f7938..97d44aff 100644
--- a/docs/en/user_guides/finetune.md
+++ b/docs/en/user_guides/finetune.md
@@ -228,3 +228,75 @@ It's because our training schedule is for a batch size of 128. If using 8 GPUs, just use
 `batch_size=16` config in the base config file for every GPU, and the total batch
 size will be 128. But if using one GPU, you need to change it to 128 manually to
 match the training schedule.
+
+## Evaluate the fine-tuned model on ImageNet variants
+
+It's common practice to evaluate a model fine-tuned on ImageNet-(1K, 21K) on the ImageNet-1K validation set. This set
+shares a similar data distribution with the training set, but in the real world, inference data is more likely to
+follow a different distribution. To evaluate a model's performance on out-of-distribution data more thoroughly, the
+research community has introduced several ImageNet-variant datasets whose data distributions differ from that of
+ImageNet-(1K, 21K). MMClassification supports evaluating the fine-tuned model on
+[ImageNet-Adversarial (A)](https://arxiv.org/abs/1907.07174), [ImageNet-Rendition (R)](https://arxiv.org/abs/2006.16241),
+[ImageNet-Corruption (C)](https://arxiv.org/abs/1903.12261), and [ImageNet-Sketch (S)](https://arxiv.org/abs/1905.13549).
+Follow the steps below to have a try:
+
+### Prepare the datasets
+
+You can download these datasets from [OpenDataLab](https://opendatalab.com/) and organize them under the
+`data` folder in the following format:
+
+```text
+imagenet-a
+├── meta
+│   └── val.txt
+└── val/
+imagenet-r
+├── meta
+│   └── val.txt
+└── val/
+imagenet-s
+├── meta
+│   └── val.txt
+└── val/
+imagenet-c
+├── meta
+│   └── val.txt
+└── val/
+```
+
+`val.txt` is the annotation file, which should follow the same format as the ImageNet-1K annotation file. You can
+refer to [prepare_dataset](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html) to
+generate the annotation file, or use this
+[script](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/ood_eval/generate_imagenet_variant_annotation.py).
+
+### Configure the dataset and test evaluator
+
+Once the dataset is ready, you need to configure the `dataset` and `test_evaluator`. There are two ways to override
+the default settings:
+
+#### 1. Change the configuration file directly
+
+Only two modifications are needed: change the `data_root` of the test dataloader, and pass the annotation file to the
+`test_evaluator`.
+
+```python
+# Replace imagenet-x below with imagenet-a, imagenet-r, imagenet-s
+# or imagenet-c
+test_dataloader = dict(dataset=dict(data_root='data/imagenet-x'))
+test_evaluator = dict(ann_file='data/imagenet-x/meta/val.txt')
+```
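+
+If you prefer a standalone file, the same two overrides can live in a tiny config that inherits from your own
+fine-tuning config, mirroring the toy configs provided under `projects/ood_eval/config/`. A minimal sketch (the
+`_base_` path is a placeholder for your own config):
+
+```python
+_base_ = 'path/to/your_finetune_config.py'  # replace with your own config
+
+# Replace imagenet-x with imagenet-a, imagenet-r or imagenet-s
+test_dataloader = dict(dataset=dict(data_root='data/imagenet-x'))
+test_evaluator = dict(ann_file='data/imagenet-x/meta/val.txt')
+```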
+
+#### 2. Overwrite the default settings from the command line
+
+For example, you can overwrite the default settings by appending `--cfg-options` to your test command:
+
+```bash
+--cfg-options test_dataloader.dataset.data_root='data/imagenet-x' \
+    test_evaluator.ann_file='data/imagenet-x/meta/val.txt'
+```
+
+### Start test
+
+This is the common test step; you can follow this [guide](https://mmclassification.readthedocs.io/en/1.x/user_guides/train_test.html)
+to evaluate your fine-tuned model on the out-of-distribution datasets.
+
+To make it easier, we also provide off-the-shelf config files for
+[ImageNet-A/R/S](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/ood_eval/config/vit_ood-eval_toy-example.py) and
+[ImageNet-C](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/ood_eval/config/vit_ood-eval_toy-example_imagenet-c.py),
+and you can have a try.
diff --git a/mmcls/evaluation/metrics/__init__.py b/mmcls/evaluation/metrics/__init__.py
index 25fed724..9c6fcd96 100644
--- a/mmcls/evaluation/metrics/__init__.py
+++ b/mmcls/evaluation/metrics/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .corruption_error import CorruptionError
 from .multi_label import AveragePrecision, MultiLabelMetric
 from .multi_task import MultiTasksMetric
 from .retrieval import RetrievalRecall
@@ -8,5 +9,5 @@ from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
 __all__ = [
     'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
     'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
-    'ConfusionMatrix', 'RetrievalRecall'
+    'ConfusionMatrix', 'RetrievalRecall', 'CorruptionError'
 ]
diff --git a/mmcls/evaluation/metrics/corruption_error.py b/mmcls/evaluation/metrics/corruption_error.py
new file mode 100644
index 00000000..58516133
--- /dev/null
+++ b/mmcls/evaluation/metrics/corruption_error.py
@@ -0,0 +1,165 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Union
+
+import torch
+
+from mmcls.registry import METRICS
+from .single_label import Accuracy
+
+
+def _get_ce_alexnet() -> dict:
+    """Return AlexNet's top-1 error on each ImageNet-C corruption type."""
+    return {
+        'gaussian_noise': 0.886428,
+        'shot_noise': 0.894468,
+        'impulse_noise': 0.922640,
+        'defocus_blur': 0.819880,
+        'glass_blur': 0.826268,
+        'motion_blur': 0.785948,
+        'zoom_blur': 0.798360,
+        'snow': 0.866816,
+        'frost': 0.826572,
+        'fog': 0.819324,
+        'brightness': 0.564592,
+        'contrast': 0.853204,
+        'elastic_transform': 0.646056,
+        'pixelate': 0.717840,
+        'jpeg_compression': 0.606500,
+    }
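+
+
+# For a corruption type ``c``, the Corruption Error (CE) of a model is its
+# top-1 error on that corruption divided by AlexNet's error on the same
+# corruption:
+#
+#     CE_c = (1 - Acc_model(c)) / (1 - Acc_AlexNet(c))
+#
+# mCE is the mean of CE_c over the corruption types listed above (see
+# ``compute_metrics`` below).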
+@METRICS.register_module()
+class CorruptionError(Accuracy):
+    """Mean Corruption Error (mCE) metric.
+
+    The mCE metric is proposed in `Benchmarking Neural Network Robustness to
+    Common Corruptions and Perturbations
+    <https://arxiv.org/abs/1903.12261>`_.
+
+    Args:
+        topk (int | Sequence[int]): If the ground truth label matches one of
+            the best **k** predictions, the sample will be regarded as a
+            positive prediction. If the parameter is a tuple, the accuracy
+            for each top-k will be calculated and output together.
+            Defaults to (1, ).
+        thrs (Sequence[float | None] | float | None): If a float, predictions
+            with a score lower than the threshold will be regarded as
+            negative. If None, no threshold is applied. If the parameter is a
+            tuple, the accuracy for each threshold will be calculated and
+            output together. Defaults to 0.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+        ann_file (str, optional): The path of the annotation file. This
+            file will be used when evaluating the fine-tuned model on OOD
+            datasets, e.g. ImageNet-A. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        topk: Union[int, Sequence[int]] = (1, ),
+        thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
+        collect_device: str = 'cpu',
+        prefix: Optional[str] = None,
+        ann_file: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            topk=topk,
+            thrs=thrs,
+            collect_device=collect_device,
+            prefix=prefix,
+            ann_file=ann_file)
+        self.ce_alexnet = _get_ce_alexnet()
+
+    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+        The difference from ``Accuracy.process`` is that the ``img_path`` of
+        each data sample is also stored in ``self.results``, so that the
+        corruption type can be recovered later.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for data_sample in data_samples:
+            result = dict()
+            pred_label = data_sample['pred_label']
+            gt_label = data_sample['gt_label']
+            result['img_path'] = data_sample['img_path']
+            if 'score' in pred_label:
+                result['pred_score'] = pred_label['score'].cpu()
+            else:
+                result['pred_label'] = pred_label['label'].cpu()
+            result['gt_label'] = gt_label['label'].cpu()
+            # Save the result to `self.results`.
+            self.results.append(result)
+
+    def compute_metrics(self, results: List) -> dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+        # NOTICE: don't access `self.results` from this method.
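+        # (``results`` is the full list collected from all ranks during
+        # distributed evaluation, while ``self.results`` only holds the
+        # current process' share.)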
+        metrics = {}
+
+        # Extract the corruption type from the image path. With the layout
+        # data/imagenet-c/val/<corruption>/<severity>/... described in the
+        # guide, the corruption type is the fourth path component.
+        category = [res['img_path'].split('/')[3] for res in results]
+        target = [res['gt_label'] for res in results]
+        pred = [res['pred_score'] for res in results]
+
+        # Group the predictions and targets by corruption type.
+        pred_each_category = {}
+        target_each_category = {}
+        for c, t, p in zip(category, target, pred):
+            if c not in pred_each_category:
+                pred_each_category[c] = []
+                target_each_category[c] = []
+            pred_each_category[c].append(p)
+            target_each_category[c].append(t)
+
+        # Concatenate the per-sample tensors of each corruption type.
+        pred_each_category = {
+            key: torch.stack(value)
+            for key, value in pred_each_category.items()
+        }
+        target_each_category = {
+            key: torch.cat(value)
+            for key, value in target_each_category.items()
+        }
+
+        # Compute the mean Corruption Error over the known corruption types.
+        mce_for_each_category = []
+        for key, pred_current_category in pred_each_category.items():
+            if key not in self.ce_alexnet:
+                continue
+            target_current_category = target_each_category[key]
+            try:
+                acc = self.calculate(pred_current_category,
+                                     target_current_category, self.topk,
+                                     self.thrs)
+                error = (100 - acc[0][0].item()) / (100. *
+                                                    self.ce_alexnet[key])
+            except ValueError as e:
+                # If the topk is invalid.
+                raise ValueError(
+                    str(e) + ' Please check the `val_evaluator` and '
+                    '`test_evaluator` fields in your config file.')
+            mce_for_each_category.append(error)
+
+        metrics['mCE'] = sum(mce_for_each_category) / len(
+            mce_for_each_category)
+
+        return metrics
diff --git a/mmcls/evaluation/metrics/single_label.py b/mmcls/evaluation/metrics/single_label.py
index eda31eb4..2611eed4 100644
--- a/mmcls/evaluation/metrics/single_label.py
+++ b/mmcls/evaluation/metrics/single_label.py
@@ -60,6 +60,26 @@ def _precision_recall_f1_support(pred_positive, gt_positive, average):
     return precision, recall, f1_score, support
 
 
+def _generate_candidate_indices(ann_file: Optional[str] = None
+                                ) -> Optional[list]:
+    """Generate candidate class indices for ImageNet-A, ImageNet-R and
+    ImageNet-S.
+
+    Args:
+        ann_file (str, optional): The path of the annotation file. This
+            file will be used when evaluating the fine-tuned model on OOD
+            datasets, e.g. ImageNet-A. Defaults to None.
+
+    Returns:
+        Optional[list]: The candidate class indices, or None if ``ann_file``
+        is not given.
+    """
+    if ann_file is None:
+        return None
+    with open(ann_file, 'r') as f:
+        labels = [int(line.strip().split()[-1]) for line in f]
+    # Deduplicate while preserving order (dicts keep insertion order).
+    return list(dict.fromkeys(labels))
+
+
 @METRICS.register_module()
 class Accuracy(BaseMetric):
     r"""Accuracy evaluation metric.
@@ -88,6 +108,9 @@ class Accuracy(BaseMetric):
             names to disambiguate homonymous metrics of different evaluators.
             If prefix is not provided in the argument, self.default_prefix
             will be used instead. Defaults to None.
+        ann_file (str, optional): The path of the annotation file. This
+            file will be used when evaluating the fine-tuned model on OOD
+            datasets, e.g. ImageNet-A. Defaults to None.
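+
+    .. note::
+        When ``ann_file`` is given, only the class indices that appear in it
+        are treated as candidate classes, and top-k accuracy is effectively
+        computed over these candidates only (assuming the prediction scores
+        lie in ``[0, 1]``, e.g. after softmax).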
 
     Examples:
         >>> import torch
@@ -124,7 +147,8 @@ class Accuracy(BaseMetric):
                  topk: Union[int, Sequence[int]] = (1, ),
                  thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
                  collect_device: str = 'cpu',
-                 prefix: Optional[str] = None) -> None:
+                 prefix: Optional[str] = None,
+                 ann_file: Optional[str] = None) -> None:
         super().__init__(collect_device=collect_device, prefix=prefix)
 
         if isinstance(topk, int):
@@ -137,6 +161,9 @@ class Accuracy(BaseMetric):
         else:
             self.thrs = tuple(thrs)
 
+        # Generate candidate class indices for ImageNet-A, ImageNet-R and
+        # ImageNet-S.
+        self.index_candidates = _generate_candidate_indices(ann_file)
+
     def process(self, data_batch, data_samples: Sequence[dict]):
         """Process one batch of data samples.
 
@@ -153,7 +180,15 @@ class Accuracy(BaseMetric):
             pred_label = data_sample['pred_label']
             gt_label = data_sample['gt_label']
             if 'score' in pred_label:
-                result['pred_score'] = pred_label['score'].cpu()
+                pred_label['score'] = pred_label['score'].cpu()
+                if self.index_candidates is not None:
+                    # Only the candidate classes should compete for the
+                    # top-k, so add 1 to the candidates' scores to guarantee
+                    # they rank above all non-candidates (the scores are
+                    # assumed to lie in [0, 1]).
+                    pred_label['score'][
+                        ..., self.index_candidates] = pred_label['score'][
+                            ..., self.index_candidates] + 1.0
+                result['pred_score'] = pred_label['score']
             else:
                 result['pred_label'] = pred_label['label'].cpu()
             result['gt_label'] = gt_label['label'].cpu()
diff --git a/projects/ood_eval/README.md b/projects/ood_eval/README.md
new file mode 100644
index 00000000..01d5792f
--- /dev/null
+++ b/projects/ood_eval/README.md
@@ -0,0 +1,71 @@
+## Evaluate the fine-tuned model on ImageNet variants
+
+It's common practice to evaluate a model fine-tuned on ImageNet-(1K, 21K) on the ImageNet-1K validation set. This set
+shares a similar data distribution with the training set, but in the real world, inference data is more likely to
+follow a different distribution. To evaluate a model's performance on out-of-distribution data more thoroughly, the
+research community has introduced several ImageNet-variant datasets whose data distributions differ from that of
+ImageNet-(1K, 21K). MMClassification supports evaluating the fine-tuned model on
+[ImageNet-Adversarial (A)](https://arxiv.org/abs/1907.07174), [ImageNet-Rendition (R)](https://arxiv.org/abs/2006.16241),
+[ImageNet-Corruption (C)](https://arxiv.org/abs/1903.12261), and [ImageNet-Sketch (S)](https://arxiv.org/abs/1905.13549).
+Follow the steps below to have a try:
+
+### Prepare the datasets
+
+You can download these datasets from [OpenDataLab](https://opendatalab.com/) and organize them under the
+`data` folder in the following format:
+
+```text
+imagenet-a
+├── meta
+│   └── val.txt
+└── val/
+imagenet-r
+├── meta
+│   └── val.txt
+└── val/
+imagenet-s
+├── meta
+│   └── val.txt
+└── val/
+imagenet-c
+├── meta
+│   └── val.txt
+└── val/
+```
+
+`val.txt` is the annotation file, which should follow the same format as the ImageNet-1K annotation file. You can
+refer to [prepare_dataset](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html) to
+generate the annotation file, or use this
+[script](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/ood_eval/generate_imagenet_variant_annotation.py).
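+
+Each line of `val.txt` pairs an image path (relative to `val/`) with a class index. For ImageNet-A/R/S the path
+starts with the WordNet ID folder, while for ImageNet-C it also includes the corruption type and severity level.
+The file names below are only illustrative:
+
+```text
+n01440764/ILSVRC2012_val_00000293.JPEG 0
+n01440764/ILSVRC2012_val_00002138.JPEG 0
+```
+
+Assuming the layout above, the script can be invoked as follows. The paths are examples;
+`--imagenet1k-ann-file` expects an ImageNet-1K annotation file whose image paths start with the WordNet ID folder
+(e.g. a `train.txt`):
+
+```bash
+python projects/ood_eval/generate_imagenet_variant_annotation.py \
+    --imagenet1k-ann-file data/imagenet/meta/train.txt \
+    --imagenet-variant-root data/imagenet-r/val \
+    --imagenet-variant-name r \
+    --output-file data/imagenet-r/meta/val.txt
+```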
+
+### Configure the dataset and test evaluator
+
+Once the dataset is ready, you need to configure the `dataset` and `test_evaluator`. There are two ways to override
+the default settings:
+
+#### 1. Change the configuration file directly
+
+Only two modifications are needed: change the `data_root` of the test dataloader, and pass the annotation file to the
+`test_evaluator`.
+
+```python
+# Replace imagenet-x below with imagenet-a, imagenet-r, imagenet-s
+# or imagenet-c
+test_dataloader = dict(dataset=dict(data_root='data/imagenet-x'))
+test_evaluator = dict(ann_file='data/imagenet-x/meta/val.txt')
+```
+
+#### 2. Overwrite the default settings from the command line
+
+For example, you can overwrite the default settings by appending `--cfg-options` to your test command:
+
+```bash
+--cfg-options test_dataloader.dataset.data_root='data/imagenet-x' \
+    test_evaluator.ann_file='data/imagenet-x/meta/val.txt'
+```
+
+### Start test
+
+This is the common test step; you can follow this [guide](https://mmclassification.readthedocs.io/en/1.x/user_guides/train_test.html)
+to evaluate your fine-tuned model on the out-of-distribution datasets.
+
+To make it easier, we also provide off-the-shelf config files for
+[ImageNet-A/R/S](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/ood_eval/config/vit_ood-eval_toy-example.py) and
+[ImageNet-C](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/ood_eval/config/vit_ood-eval_toy-example_imagenet-c.py),
+and you can have a try.
diff --git a/projects/ood_eval/config/vit_ood-eval_toy-example.py b/projects/ood_eval/config/vit_ood-eval_toy-example.py
new file mode 100644
index 00000000..dc928316
--- /dev/null
+++ b/projects/ood_eval/config/vit_ood-eval_toy-example.py
@@ -0,0 +1,5 @@
+_base_ = 'mmcls::resnet/resnetv1c50_8xb32_in1k.py'  # can be your own config
+
+# You can replace imagenet-r with imagenet-a or imagenet-s
+test_dataloader = dict(dataset=dict(data_root='data/imagenet-r'))
+test_evaluator = dict(ann_file='data/imagenet-r/meta/val.txt')
diff --git a/projects/ood_eval/config/vit_ood-eval_toy-example_imagenet-c.py b/projects/ood_eval/config/vit_ood-eval_toy-example_imagenet-c.py
new file mode 100644
index 00000000..9f829eae
--- /dev/null
+++ b/projects/ood_eval/config/vit_ood-eval_toy-example_imagenet-c.py
@@ -0,0 +1,4 @@
+_base_ = 'mmcls::resnet/resnetv1c50_8xb32_in1k.py'  # can be your own config
+
+test_dataloader = dict(dataset=dict(data_root='data/imagenet-c'))
+test_evaluator = dict(type='CorruptionError')
diff --git a/projects/ood_eval/generate_imagenet_variant_annotation.py b/projects/ood_eval/generate_imagenet_variant_annotation.py
new file mode 100644
index 00000000..f48bf679
--- /dev/null
+++ b/projects/ood_eval/generate_imagenet_variant_annotation.py
@@ -0,0 +1,66 @@
+import argparse
+import os
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    '--imagenet1k-ann-file',
+    type=str,
+    help='path to the ImageNet-1K annotation file')
+parser.add_argument(
+    '--imagenet-variant-root',
+    type=str,
+    help='the root folder of the ImageNet variant')
+parser.add_argument(
+    '--imagenet-variant-name',
+    type=str,
+    help='the name of the ImageNet variant (a, r, s or c)')
+parser.add_argument(
+    '--output-file', type=str, help='path to the output annotation file')
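+
+# Expected directory layouts (assumed from the loops below):
+#   ImageNet-A/R/S: <variant-root>/<wnid>/<image>
+#   ImageNet-C:     <variant-root>/<corruption>/<severity>/<wnid>/<image>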
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    # Build a WordNet ID -> class index map from the ImageNet-1K annotation
+    # file, whose image paths are expected to start with the WordNet ID
+    # folder.
+    with open(args.imagenet1k_ann_file, 'r') as f:
+        imagenet1k_list = [line.strip().split() for line in f.readlines()]
+    imagenet1k_list = [[line[0].split('/')[0], line[1]]
+                       for line in imagenet1k_list]
+    imagenet1k_label_map = {line[0]: line[1] for line in imagenet1k_list}
+
+    imagenet_variant_images = []
+    if args.imagenet_variant_name != 'c':
+        # ImageNet variants A, R and S
+        imagenet_variant_subfolders = os.listdir(args.imagenet_variant_root)
+        imagenet_variant_subfolders = [
+            subfolder for subfolder in imagenet_variant_subfolders
+            if not subfolder.endswith('.txt')
+        ]
+        for subfolder in imagenet_variant_subfolders:
+            cur_label = imagenet1k_label_map[subfolder]
+            cur_subfolder = os.path.join(args.imagenet_variant_root, subfolder)
+            cur_subfolder_files = os.listdir(cur_subfolder)
+            cur_subfolder_files = [
+                os.path.join(subfolder, file) + ' ' + cur_label
+                for file in cur_subfolder_files
+            ]
+            imagenet_variant_images.extend(cur_subfolder_files)
+    else:
+        # ImageNet variant C
+        corruption_categories = os.listdir(args.imagenet_variant_root)
+        for category in corruption_categories:
+            corruption_levels = os.listdir(
+                os.path.join(args.imagenet_variant_root, category))
+            for level in corruption_levels:
+                imagenet_variant_subfolders = os.listdir(
+                    os.path.join(args.imagenet_variant_root, category, level))
+                for subfolder in imagenet_variant_subfolders:
+                    cur_label = imagenet1k_label_map[subfolder]
+                    cur_subfolder = os.path.join(args.imagenet_variant_root,
+                                                 category, level, subfolder)
+                    cur_subfolder_files = os.listdir(cur_subfolder)
+                    cur_subfolder_files = [
+                        os.path.join(category, level, subfolder, file) + ' ' +
+                        cur_label for file in cur_subfolder_files
+                    ]
+                    imagenet_variant_images.extend(cur_subfolder_files)
+
+    with open(args.output_file, 'w') as f:
+        f.write('\n'.join(imagenet_variant_images))
diff --git a/tests/test_evaluation/test_metrics/test_corruption_error.py b/tests/test_evaluation/test_metrics/test_corruption_error.py
new file mode 100644
index 00000000..00fedc77
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_corruption_error.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmcls.registry import METRICS
+
+
+class TestCorruptionError(TestCase):
+
+    def test_compute_metrics(self):
+        mCE_metrics = METRICS.build(dict(type='CorruptionError'))
+        # Every prediction is correct, so the error on `gaussian_noise`, and
+        # hence mCE, should be zero.
+        results = [{
+            'pred_score': torch.tensor([0.7, 0.0, 0.3]),
+            'gt_label': torch.tensor([0]),
+            'img_path': 'a/b/c/gaussian_noise'
+        } for _ in range(10)]
+        metrics = mCE_metrics.compute_metrics(results)
+        assert metrics['mCE'] == 0.0
+
+    def test_process(self):
+        mCE_metrics = METRICS.build(dict(type='CorruptionError'))
+        results = [{
+            'pred_label': {
+                'label': torch.tensor([0]),
+                'score': torch.tensor([0.7, 0.0, 0.3])
+            },
+            'gt_label': {
+                'label': torch.tensor([0])
+            },
+            'img_path': 'a/b/c/gaussian_noise'
+        } for _ in range(10)]
+        mCE_metrics.process(None, results)
+        assert len(mCE_metrics.results) == 10
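+        # A minimal end-to-end sketch of the two-step metric API: the stored
+        # results can be fed straight into `compute_metrics`, and since every
+        # prediction above is correct, mCE should again be zero.
+        metrics = mCE_metrics.compute_metrics(mCE_metrics.results)
+        assert metrics['mCE'] == 0.0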