[Feature] Support Out-of-Distribution datasets like ImageNet-A,R,S,C. (#1342)

* [Feature]: Support ImageNet-A,R,S

* [Feature]: Add doc for ood eval

* [Feature]: Add example config

* [Feature]: Add mCE evaluator

* [Fix]: Fix key error

* [Feature]: Add mCE for ImageNet-C

* [Feature]: Add ImageNet-C example

* [Feature]: Add doc for ImageNet-C ft

* [Fix]: Fix bug

* [Fix]: Fix lint

* [Fix]: Fix suggestion

* [Fix]: Fix codespell

* [Fix]: Fix lint

* [Feature]: Add gen annotation

* [Fix]: Fix lint

* [Fix]: Fix index mask bug
Yuan Liu 2023-03-16 16:30:42 +08:00 committed by GitHub
parent 654a337d3e
commit ce3fa7b3fb
10 changed files with 457 additions and 3 deletions


@@ -23,6 +23,7 @@ Single Label Metric
Accuracy
SingleLabelMetric
ConfusionMatrix
CorruptionError

Multi Label Metric
----------------------


@@ -228,3 +228,75 @@ It's because our training schedule is for a batch size of 128. If using 8 GPUs,
just use `batch_size=16` config in the base config file for every GPU, and the total batch
size will be 128. But if using one GPU, you need to change it to 128 manually to
match the training schedule.
## Evaluate the fine-tuned model on ImageNet variants
It is common practice to evaluate a model fine-tuned on ImageNet-(1K, 21K) on the ImageNet-1K validation set. This set
shares a similar data distribution with the training set, but in the real world, the inference data is more likely to follow
a different distribution. To fully evaluate a model's performance on out-of-distribution data, the research community has
introduced ImageNet-variant datasets whose data distributions differ from that of ImageNet-(1K, 21K).
MMClassification supports evaluating the fine-tuned model on
[ImageNet-Adversarial (A)](https://arxiv.org/abs/1907.07174), [ImageNet-Rendition (R)](https://arxiv.org/abs/2006.16241),
[ImageNet-Corruption (C)](https://arxiv.org/abs/1903.12261), and [ImageNet-Sketch (S)](https://arxiv.org/abs/1905.13549).
You can follow the steps below to have a try:
### Prepare the datasets
You can download these datasets from [OpenDataLab](https://opendatalab.com/) and organize them under the
`data` folder in the following format:
```text
imagenet-a
├── meta
│   └── val.txt
└── val
imagenet-r
├── meta
│   └── val.txt
└── val
imagenet-s
├── meta
│   └── val.txt
└── val
imagenet-c
├── meta
│   └── val.txt
└── val
```
`val.txt` is the annotation file, which should have the same format as that of ImageNet-1K. You can refer to
[prepare_dataset](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html) to generate the
annotation file, or use this [script](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/example_project/ood_eval/generate_imagenet_variant_annotation.py).
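For example, a possible invocation of that script is sketched below. The paths are placeholders for your local layout, and the ImageNet-1K annotation file is assumed to list entries as `<class_folder>/<image> <label>` so that the script can map class folders to label indices:
```bash
# Generate an ImageNet-style annotation file for ImageNet-R (placeholder paths)
python projects/example_project/ood_eval/generate_imagenet_variant_annotation.py \
    --imagenet1k-ann-file data/imagenet/meta/train.txt \
    --imagenet-variant-root data/imagenet-r/val \
    --imagenet-variant-name r \
    --output-file data/imagenet-r/meta/val.txt
```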
### Configure the dataset and test evaluator
Once the dataset is ready, you need to configure the `dataset` and `test_evaluator`. There are two ways to
override the default settings:
#### 1. Change the configuration file directly
Only a few modifications to the config file are needed: change the `data_root` of the test dataloader and pass the
annotation file to the `test_evaluator`.
```python
# You should replace imagenet-x below with imagenet-c, imagenet-r, imagenet-a
# or imagenet-s
test_dataloader = dict(dataset=dict(data_root='data/imagenet-x'))
test_evaluator = dict(ann_file='data/imagenet-x/meta/val.txt')
```
#### 2. Overwrite the default settings from the command line
For example, you can overwrite the default settings by passing `--cfg-options`:
```bash
--cfg-options test_dataloader.dataset.data_root='data/imagenet-x' \
test_evaluator.ann_file='data/imagenet-x/meta/val.txt'
```
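For reference, a complete test command with these overrides might look like the sketch below. `tools/test.py` is the usual MMClassification test entry point; the config and checkpoint paths are placeholders you should replace with your own:
```bash
# Test a fine-tuned model on ImageNet-R with overridden data settings (placeholder paths)
python tools/test.py configs/resnet/resnetv1c50_8xb32_in1k.py \
    path/to/your_finetuned_checkpoint.pth \
    --cfg-options test_dataloader.dataset.data_root='data/imagenet-r' \
    test_evaluator.ann_file='data/imagenet-r/meta/val.txt'
```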
### Start the test
This step is the same as a common test step, so you can follow this [guide](https://mmclassification.readthedocs.io/en/1.x/user_guides/train_test.html)
to evaluate your fine-tuned model on out-of-distribution datasets.
To make it easier, we also provide off-the-shelf config files for [ImageNet-(A, R, S)](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/example_project/ood_eval/vit_ood-eval_toy-example.py) and [ImageNet-C](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/example_project/ood_eval/vit_ood-eval_toy-example_imagnet-c.py), and you can have a try.


@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .corruption_error import CorruptionError
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
from .retrieval import RetrievalRecall
@@ -8,5 +9,5 @@ from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
__all__ = [
    'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
    'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
    'ConfusionMatrix', 'RetrievalRecall'
    'ConfusionMatrix', 'RetrievalRecall', 'CorruptionError'
]


@@ -0,0 +1,165 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union

import torch

from mmcls.registry import METRICS
from .single_label import Accuracy


def _get_ce_alexnet() -> dict:
    """Returns Corruption Error values for AlexNet."""
    ce_alexnet = dict()
    ce_alexnet['gaussian_noise'] = 0.886428
    ce_alexnet['shot_noise'] = 0.894468
    ce_alexnet['impulse_noise'] = 0.922640
    ce_alexnet['defocus_blur'] = 0.819880
    ce_alexnet['glass_blur'] = 0.826268
    ce_alexnet['motion_blur'] = 0.785948
    ce_alexnet['zoom_blur'] = 0.798360
    ce_alexnet['snow'] = 0.866816
    ce_alexnet['frost'] = 0.826572
    ce_alexnet['fog'] = 0.819324
    ce_alexnet['brightness'] = 0.564592
    ce_alexnet['contrast'] = 0.853204
    ce_alexnet['elastic_transform'] = 0.646056
    ce_alexnet['pixelate'] = 0.717840
    ce_alexnet['jpeg_compression'] = 0.606500

    return ce_alexnet


@METRICS.register_module()
class CorruptionError(Accuracy):
    """Mean Corruption Error (mCE) metric.

    The mCE metric is proposed in `Benchmarking Neural Network Robustness to
    Common Corruptions and Perturbations
    <https://arxiv.org/abs/1903.12261>`_.

    Args:
        topk (int | Sequence[int]): If the ground truth label matches one of
            the best **k** predictions, the sample will be regarded as a
            positive prediction. If the parameter is a tuple, all of the
            top-k accuracies will be calculated and outputted together.
            Defaults to 1.
        thrs (Sequence[float | None] | float | None): If a float, predictions
            with a score lower than the threshold will be regarded as negative
            predictions. If None, no threshold is applied. If the parameter is
            a tuple, accuracy based on all thresholds will be calculated and
            outputted together. Defaults to 0.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
        ann_file (str, optional): The path of the annotation file. This
            file will be used when evaluating the fine-tuned model on an OOD
            dataset, e.g. ImageNet-A. Defaults to None.
    """

    def __init__(
        self,
        topk: Union[int, Sequence[int]] = (1, ),
        thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
        collect_device: str = 'cpu',
        prefix: Optional[str] = None,
        ann_file: Optional[str] = None,
    ) -> None:
        super().__init__(
            topk=topk,
            thrs=thrs,
            collect_device=collect_device,
            prefix=prefix,
            ann_file=ann_file)
        self.ce_alexnet = _get_ce_alexnet()

    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.
        The difference between this method and ``process`` in ``Accuracy`` is
        that the ``img_path`` is extracted from the ``data_batch`` and stored
        in ``self.results``.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            pred_label = data_sample['pred_label']
            gt_label = data_sample['gt_label']
            result['img_path'] = data_sample['img_path']
            if 'score' in pred_label:
                result['pred_score'] = pred_label['score']
            else:
                result['pred_label'] = pred_label['label'].cpu()
            result['gt_label'] = gt_label['label'].cpu()

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method.
        metrics = {}

        # extract the corruption category from the image path
        category = [res['img_path'].split('/')[3] for res in results]
        target = [res['gt_label'] for res in results]
        pred = [res['pred_score'] for res in results]

        # group predictions and targets by corruption category
        pred_each_category = {}
        target_each_category = {}
        for c, t, p in zip(category, target, pred):
            if c not in pred_each_category.keys():
                pred_each_category[c] = []
                target_each_category[c] = []
            pred_each_category[c].append(p)
            target_each_category[c].append(t)

        # concat
        pred_each_category = {
            key: torch.stack(pred_each_category[key])
            for key in pred_each_category.keys()
        }
        target_each_category = {
            key: torch.cat(target_each_category[key])
            for key in target_each_category.keys()
        }

        # compute mCE
        mce_for_each_category = []
        for key in pred_each_category.keys():
            if key not in self.ce_alexnet.keys():
                continue
            target_current_category = target_each_category[key]
            pred_current_category = pred_each_category[key]
            try:
                acc = self.calculate(pred_current_category,
                                     target_current_category, self.topk,
                                     self.thrs)
                error = (100 - acc[0][0].item()) / (100. *
                                                    self.ce_alexnet[key])
            except ValueError as e:
                # If the topk is invalid.
                raise ValueError(
                    str(e) + ' Please check the `val_evaluator` and '
                    '`test_evaluator` fields in your config file.')
            mce_for_each_category.append(error)

        metrics['mCE'] = sum(mce_for_each_category) / len(
            mce_for_each_category)

        return metrics


@@ -60,6 +60,26 @@ def _precision_recall_f1_support(pred_positive, gt_positive, average):
    return precision, recall, f1_score, support


def _generate_candidate_indices(ann_file: str = None) -> Optional[list]:
    """Generate index candidates for ImageNet-A, ImageNet-R, ImageNet-S.

    Args:
        ann_file (str, optional): The path of the annotation file. This
            file will be used when evaluating the fine-tuned model on an OOD
            dataset, e.g. ImageNet-A. Defaults to None.

    Returns:
        Optional[list]: Index candidates for ImageNet-A, ImageNet-R and
        ImageNet-S.
    """
    if ann_file is not None:
        with open(ann_file, 'r') as f:
            labels = [int(item.strip().split()[-1]) for item in f.readlines()]
        label_dict = {label: 1 for label in labels}
        return list(label_dict.keys())
    else:
        return None


@METRICS.register_module()
class Accuracy(BaseMetric):
    r"""Accuracy evaluation metric.

@@ -88,6 +108,9 @@ class Accuracy(BaseMetric):
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
        ann_file (str, optional): The path of the annotation file. This
            file will be used when evaluating the fine-tuned model on an OOD
            dataset, e.g. ImageNet-A. Defaults to None.

    Examples:
        >>> import torch
@@ -124,7 +147,8 @@ class Accuracy(BaseMetric):
                 topk: Union[int, Sequence[int]] = (1, ),
                 thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
                 prefix: Optional[str] = None,
                 ann_file: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        if isinstance(topk, int):
@@ -137,6 +161,9 @@ class Accuracy(BaseMetric):
        else:
            self.thrs = tuple(thrs)

        # generate index candidates for ImageNet-A, ImageNet-R, ImageNet-S
        self.index_candidates = _generate_candidate_indices(ann_file)

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

@@ -153,7 +180,15 @@ class Accuracy(BaseMetric):
            pred_label = data_sample['pred_label']
            gt_label = data_sample['gt_label']
            if 'score' in pred_label:
                result['pred_score'] = pred_label['score'].cpu()
                if self.index_candidates is not None:
                    pred_label['score'] = pred_label['score'].cpu()
                    # Since we only compute the topk across the candidate
                    # indices, we need to add 1 to the score of the candidates
                    # to ensure that the candidates are in the topk.
                    pred_label['score'][
                        ..., self.index_candidates] = pred_label['score'][
                            ..., self.index_candidates] + 1.0
                result['pred_score'] = pred_label['score']
            else:
                result['pred_label'] = pred_label['label'].cpu()
            result['gt_label'] = gt_label['label'].cpu()


@@ -0,0 +1,71 @@
## Evaluate the fine-tuned model on ImageNet variants
It is common practice to evaluate a model fine-tuned on ImageNet-(1K, 21K) on the ImageNet-1K validation set. This set
shares a similar data distribution with the training set, but in the real world, the inference data is more likely to follow
a different distribution. To fully evaluate a model's performance on out-of-distribution data, the research community has
introduced ImageNet-variant datasets whose data distributions differ from that of ImageNet-(1K, 21K).
MMClassification supports evaluating the fine-tuned model on
[ImageNet-Adversarial (A)](https://arxiv.org/abs/1907.07174), [ImageNet-Rendition (R)](https://arxiv.org/abs/2006.16241),
[ImageNet-Corruption (C)](https://arxiv.org/abs/1903.12261), and [ImageNet-Sketch (S)](https://arxiv.org/abs/1905.13549).
You can follow the steps below to have a try:
### Prepare the datasets
You can download these datasets from [OpenDataLab](https://opendatalab.com/) and organize them under the
`data` folder in the following format:
```text
imagenet-a
├── meta
│   └── val.txt
└── val
imagenet-r
├── meta
│   └── val.txt
└── val
imagenet-s
├── meta
│   └── val.txt
└── val
imagenet-c
├── meta
│   └── val.txt
└── val
```
`val.txt` is the annotation file, which should have the same format as that of ImageNet-1K. You can refer to
[prepare_dataset](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html) to generate the
annotation file, or use this [script](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/example_project/ood_eval/generate_imagenet_variant_annotation.py).
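For reference, each line of `val.txt` follows the ImageNet-1K annotation style, i.e. a relative image path followed by an integer class index. The file names and labels below are made-up placeholders:
```text
n01443537/sketch_0.JPEG 1
n01443537/sketch_1.JPEG 1
n01484850/art_3.JPEG 2
```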
### Configure the dataset and test evaluator
Once the dataset is ready, you need to configure the `dataset` and `test_evaluator`. There are two ways to
override the default settings:
#### 1. Change the configuration file directly
Only a few modifications to the config file are needed: change the `data_root` of the test dataloader and pass the
annotation file to the `test_evaluator`.
```python
# You should replace imagenet-x below with imagenet-c, imagenet-r, imagenet-a
# or imagenet-s
test_dataloader = dict(dataset=dict(data_root='data/imagenet-x'))
test_evaluator = dict(ann_file='data/imagenet-x/meta/val.txt')
```
#### 2. Overwrite the default settings from the command line
For example, you can overwrite the default settings by passing `--cfg-options`:
```bash
--cfg-options test_dataloader.dataset.data_root='data/imagenet-x' \
test_evaluator.ann_file='data/imagenet-x/meta/val.txt'
```
### Start the test
This step is the same as a common test step, so you can follow this [guide](https://mmclassification.readthedocs.io/en/1.x/user_guides/train_test.html)
to evaluate your fine-tuned model on out-of-distribution datasets.
To make it easier, we also provide off-the-shelf config files for [ImageNet-(A, R, S)](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/example_project/ood_eval/vit_ood-eval_toy-example.py) and [ImageNet-C](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/example_project/ood_eval/vit_ood-eval_toy-example_imagnet-c.py), and you can have a try.
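As a quick sketch, testing with one of these off-the-shelf configs could be as simple as the command below; the checkpoint path is a placeholder for your own fine-tuned weights:
```bash
# Evaluate mCE on ImageNet-C with the provided toy config (placeholder checkpoint path)
python tools/test.py projects/example_project/ood_eval/vit_ood-eval_toy-example_imagnet-c.py \
    path/to/your_finetuned_checkpoint.pth
```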


@@ -0,0 +1,5 @@
_base_ = 'mmcls::resnet/resnetv1c50_8xb32_in1k.py' # can be your own config
# You can replace imagenet-r with imagenet-a or imagenet-s
test_dataloader = dict(dataset=dict(data_root='data/imagenet-r'))
test_evaluator = dict(ann_file='data/imagenet-r/meta/val.txt')


@@ -0,0 +1,4 @@
_base_ = 'mmcls::resnet/resnetv1c50_8xb32_in1k.py' # can be your own config
test_dataloader = dict(dataset=dict(data_root='data/imagenet-c'))
test_evaluator = dict(type='CorruptionError')


@@ -0,0 +1,66 @@
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    '--imagenet1k-ann-file',
    type=str,
    help='path to the ImageNet1k annotation file')
parser.add_argument(
    '--imagenet-variant-root',
    type=str,
    help='the root folder of ImageNet variant')
parser.add_argument(
    '--imagenet-variant-name',
    type=str,
    help='the name of the ImageNet variant')
parser.add_argument(
    '--output-file', type=str, help='path to the output annotation file')

if __name__ == '__main__':
    args = parser.parse_args()

    with open(args.imagenet1k_ann_file, 'r') as f:
        imagenet1k_list = [line.strip().split() for line in f.readlines()]
    imagenet1k_list = [[line[0].split('/')[0], line[1]]
                       for line in imagenet1k_list]
    imagenet1k_label_map = {line[0]: line[1] for line in imagenet1k_list}

    imagenet_variant_images = []
    if args.imagenet_variant_name != 'c':
        # ImageNet variant A, R, S
        imagenet_variant_subfolders = os.listdir(args.imagenet_variant_root)
        imagenet_variant_subfolders = [
            subfolder for subfolder in imagenet_variant_subfolders
            if not subfolder.endswith('.txt')
        ]
        for subfolder in imagenet_variant_subfolders:
            cur_label = imagenet1k_label_map[subfolder]
            cur_subfolder = os.path.join(args.imagenet_variant_root, subfolder)
            cur_subfolder_files = os.listdir(cur_subfolder)
            cur_subfolder_files = [
                os.path.join(subfolder, file) + ' ' + cur_label
                for file in cur_subfolder_files
            ]
            imagenet_variant_images.extend(cur_subfolder_files)
    else:
        # ImageNet variant C
        corruption_categories = os.listdir(args.imagenet_variant_root)
        for category in corruption_categories:
            corruption_levels = os.listdir(
                os.path.join(args.imagenet_variant_root, category))
            for level in corruption_levels:
                imagenet_variant_subfolders = os.listdir(
                    os.path.join(args.imagenet_variant_root, category, level))
                for subfolder in imagenet_variant_subfolders:
                    cur_label = imagenet1k_label_map[subfolder]
                    cur_subfolder = os.path.join(args.imagenet_variant_root,
                                                 category, level, subfolder)
                    cur_subfolder_files = os.listdir(cur_subfolder)
                    cur_subfolder_files = [
                        os.path.join(category, level, subfolder, file) + ' ' +
                        cur_label for file in cur_subfolder_files
                    ]
                    imagenet_variant_images.extend(cur_subfolder_files)

    with open(args.output_file, 'w') as f:
        f.write('\n'.join(imagenet_variant_images))


@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmcls.registry import METRICS


class TestCorruptionError(TestCase):

    def test_compute_metrics(self):
        mCE_metrics = METRICS.build(dict(type='CorruptionError'))
        results = [{
            'pred_score': torch.tensor([0.7, 0.0, 0.3]),
            'gt_label': torch.tensor([0]),
            'img_path': 'a/b/c/gaussian_noise'
        } for i in range(10)]
        metrics = mCE_metrics.compute_metrics(results)
        assert metrics['mCE'] == 0.0

    def test_process(self):
        mCE_metrics = METRICS.build(dict(type='CorruptionError'))
        results = [{
            'pred_label': {
                'label': torch.tensor([0]),
                'score': torch.tensor([0.7, 0.0, 0.3])
            },
            'gt_label': {
                'label': torch.tensor([0])
            },
            'img_path': 'a/b/c/gaussian_noise'
        } for i in range(10)]
        mCE_metrics.process(None, results)
        assert len(mCE_metrics.results) == 10