[Feature] Transfer shape-bias tool from mmselfsup (#1658)
* Transfer shape-bias tool from mmselfsup * run shape-bias successfully, add CN docs * fix unit test bug * add shape_bias to index.rst in docs * modified mistakes in shape-bias docspull/1689/head
parent
00030e3f7d
commit
feb0814b2f
|
@ -104,6 +104,7 @@ We always welcome *PRs* and *Issues* for the betterment of MMPretrain.
|
|||
useful_tools/log_result_analysis.md
|
||||
useful_tools/complexity_analysis.md
|
||||
useful_tools/confusion_matrix.md
|
||||
useful_tools/shape_bias.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
## Shape Bias Tool Usage
|
||||
|
||||
Shape bias measures how a model relies the shapes, compared to texture, to sense the semantics in images. For more details,
|
||||
we recommend interested readers to this [paper](https://arxiv.org/abs/2106.07411). MMPretrain provide an off-the-shelf toolbox to
|
||||
obtain the shape bias of a classification model. You can following these steps below:
|
||||
|
||||
### Prepare the dataset
|
||||
|
||||
First you should download the [cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) to `data` folder,
|
||||
and then unzip this dataset. After that, you `data` folder should have the following structure:
|
||||
|
||||
```text
|
||||
data
|
||||
├──cue-conflict
|
||||
| |──airplane
|
||||
| |──bear
|
||||
| ...
|
||||
| |── truck
|
||||
```
|
||||
|
||||
### Modify the config for classification
|
||||
|
||||
We run the shape-bias tool on a ViT-base model with masked autoencoder pretraining. Its config file is `configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py`, and its checkpoint is downloaded from [this link](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth). Replace the original test_pipeline, test_dataloader and test_evaluation with the following configurations:
|
||||
|
||||
```python
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=256,
|
||||
edge='short',
|
||||
backend='pillow'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
test_dataloader = dict(
|
||||
pin_memory=True,
|
||||
collate_fn=dict(type='default_collate'),
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type='CustomDataset',
|
||||
data_root='data/cue-conflict',
|
||||
pipeline=test_pipeline,
|
||||
_delete_=True),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
drop_last=False)
|
||||
test_evaluator = dict(
|
||||
type='mmpretrain.ShapeBiasMetric',
|
||||
_delete_=True,
|
||||
csv_dir='work_dirs/shape_bias',
|
||||
model_name='mae')
|
||||
```
|
||||
|
||||
Please note you should make custom modifications to the `csv_dir` and `model_name` above. I renamed my modified sample config file as `vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py` in the folder `configs/mae/benchmarks/`.
|
||||
|
||||
### Inference your model with above modified config file
|
||||
|
||||
Then you should inferece your model on the `cue-conflict` dataset with the your modified config file.
|
||||
|
||||
```shell
|
||||
# For PyTorch
|
||||
bash tools/dist_test.sh $CONFIG $CHECKPOINT
|
||||
```
|
||||
|
||||
**Description of all arguments**:
|
||||
|
||||
- `$CONFIG`: The path of your modified config file.
|
||||
- `$CHECKPOINT`: The path or link of the checkpoint file.
|
||||
|
||||
```shell
|
||||
# Example
|
||||
bash tools/dist_test.sh configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth 1
|
||||
```
|
||||
|
||||
After that, you should obtain a csv file in `csv_dir` folder, named `cue-conflict_model-name_session-1.csv`. Besides this file, you should also download these [csv files](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict) to the
|
||||
`csv_dir`.
|
||||
|
||||
### Plot shape bias
|
||||
|
||||
Then we can start to plot the shape bias:
|
||||
|
||||
```shell
|
||||
python tools/analysis_tools/shape_bias.py --csv-dir $CSV_DIR --result-dir $RESULT_DIR --colors $RGB --markers o --plotting-names $YOUR_MODEL_NAME --model-names $YOUR_MODEL_NAME
|
||||
```
|
||||
|
||||
**Description of all arguments**:
|
||||
|
||||
- `--csv-dir $CSV_DIR`, the same directory to save these csv files.
|
||||
- `--result-dir $RESULT_DIR`, the directory to output the result named `cue-conflict_shape-bias_matrixplot.pdf`.
|
||||
- `--colors $RGB`, should be the RGB values, formatted in R G B, e.g. 100 100 100, and can be multiple RGB values, if you want to plot the shape bias of several models.
|
||||
- `--plotting-names $YOUR_MODEL_NAME`, the name of the legend in the shape bias figure, and you can set it as your model name. If you want to plot several models, plotting_names can be multiple values.
|
||||
- `model-names $YOUR_MODEL_NAME`, should be the same name specified in your config, and can be multiple names if you want to plot the shape bias of several models.
|
||||
|
||||
Please note, every three values for `--colors` corresponds to one value for `--model-names`. After all of above steps, you are expected to obtain the following figure.
|
||||
|
||||
<div align="center">
|
||||
<img src="https://github.com/open-mmlab/mmpretrain/assets/42371271/dc608d06-43eb-4860-bb70-486ed2a3f927" width="500" />
|
||||
</div>
|
|
@ -90,6 +90,7 @@ MMPretrain 上手路线
|
|||
useful_tools/log_result_analysis.md
|
||||
useful_tools/complexity_analysis.md
|
||||
useful_tools/confusion_matrix.md
|
||||
useful_tools/shape_bias.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
## 形状偏差(Shape Bias)工具用法
|
||||
|
||||
形状偏差(shape bias)衡量模型与纹理相比,如何依赖形状来感知图像中的语义。关于更多细节,我们向感兴趣的读者推荐这篇[论文](https://arxiv.org/abs/2106.07411) 。MMPretrain提供现成的工具箱来获得分类模型的形状偏差。您可以按照以下步骤操作:
|
||||
|
||||
### 准备数据集
|
||||
|
||||
首先你应该下载[cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) 到`data`文件夹,然后解压缩这个数据集。之后,你的`data`文件夹应具有一下结构:
|
||||
|
||||
```text
|
||||
data
|
||||
├──cue-conflict
|
||||
| |──airplane
|
||||
| |──bear
|
||||
| ...
|
||||
| |── truck
|
||||
```
|
||||
|
||||
### 修改分类配置
|
||||
|
||||
我们在使用MAE预训练的ViT-base模型上运行形状偏移工具。它的配置文件为`configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py`,它的检查点可从[此链接](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth) 下载。将原始配置中的test_pipeline, test_dataloader和test_evaluation替换为以下配置:
|
||||
|
||||
```python
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=256,
|
||||
edge='short',
|
||||
backend='pillow'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
test_dataloader = dict(
|
||||
pin_memory=True,
|
||||
collate_fn=dict(type='default_collate'),
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type='CustomDataset',
|
||||
data_root='data/cue-conflict',
|
||||
pipeline=test_pipeline,
|
||||
_delete_=True),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
drop_last=False)
|
||||
test_evaluator = dict(
|
||||
type='mmpretrain.ShapeBiasMetric',
|
||||
_delete_=True,
|
||||
csv_dir='work_dirs/shape_bias',
|
||||
model_name='mae')
|
||||
```
|
||||
|
||||
请注意,你应该对上面的`csv_dir`和`model_name`进行自定义修改。我把修改后的示例配置文件重命名为`configs/mae/benchmarks/`文件夹中的`vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py`文件。
|
||||
|
||||
### 用上面修改后的配置文件在你的模型上做推断
|
||||
|
||||
然后,你应该使用修改后的配置文件在`cue-conflict`数据集上推断你的模型。
|
||||
|
||||
```shell
|
||||
# For PyTorch
|
||||
bash tools/dist_test.sh $CONFIG $CHECKPOINT
|
||||
```
|
||||
|
||||
**所有参数的说明**:
|
||||
|
||||
- `$CONFIG`: 修改后的配置文件的路径。
|
||||
- `$CHECKPOINT`: 检查点文件的路径或链接。
|
||||
|
||||
```shell
|
||||
# Example
|
||||
bash tools/dist_test.sh configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth 1
|
||||
```
|
||||
|
||||
之后,你应该在`csv_dir`文件夹中获得一个名为`cue-conflict_model-name_session-1.csv`的csv文件。除了这个文件以外,你还应该下载这些[csv文件](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict) 到`csv_dir`。
|
||||
|
||||
### 绘制形状偏差图
|
||||
|
||||
然后我们可以开始绘制形状偏差图:
|
||||
|
||||
```shell
|
||||
python tools/analysis_tools/shape_bias.py --csv-dir $CSV_DIR --result-dir $RESULT_DIR --colors $RGB --markers o --plotting-names $YOUR_MODEL_NAME --model-names $YOUR_MODEL_NAME
|
||||
```
|
||||
|
||||
**所有参数的说明**:
|
||||
|
||||
- `--csv-dir $CSV_DIR`, 与保存这些csv文件的目录相同。
|
||||
- `--result-dir $RESULT_DIR`, 输出名为`cue-conflict_shape-bias_matrixplot.pdf`的结果的目录。
|
||||
- `--colors $RGB`, 应该是RGB值,格式为R G B,例如100 100 100,如果你想绘制几个模型的形状偏差,可以是多个RGB值。
|
||||
- `--plotting-names $YOUR_MODEL_NAME`, 形状偏移图中图例的名称,您可以将其设置为模型名称。如果要绘制多个模型,plotting_names可以是多个值。
|
||||
- `model-names $YOUR_MODEL_NAME`, 应与配置中指定的名称相同,如果要绘制多个模型的形状偏差,则可以是多个名称。
|
||||
|
||||
请注意,`--colors`的每三个值对应于`--model-names`的一个值。完成以上所有步骤后,你将获得下图。
|
||||
|
||||
<div align="center">
|
||||
<img src="https://github.com/open-mmlab/mmpretrain/assets/42371271/dc608d06-43eb-4860-bb70-486ed2a3f927" width="500" />
|
||||
</div>
|
|
@ -6,6 +6,7 @@ from .multi_task import MultiTasksMetric
|
|||
from .nocaps import NocapsSave
|
||||
from .retrieval import RetrievalAveragePrecision, RetrievalRecall
|
||||
from .scienceqa import ScienceQAMetric
|
||||
from .shape_bias_label import ShapeBiasMetric
|
||||
from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
|
||||
from .visual_grounding_eval import VisualGroundingMetric
|
||||
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
|
||||
|
@ -16,5 +17,5 @@ __all__ = [
|
|||
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
|
||||
'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
|
||||
'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
|
||||
'RetrievalAveragePrecision'
|
||||
'RetrievalAveragePrecision', 'ShapeBiasMetric'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,172 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import csv
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import List, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.dist.utils import get_rank
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
from mmpretrain.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class ShapeBiasMetric(BaseMetric):
|
||||
"""Evaluate the model on ``cue_conflict`` dataset.
|
||||
|
||||
This module will evaluate the model on an OOD dataset, cue_conflict, in
|
||||
order to measure the shape bias of the model. In addition to compuate the
|
||||
Top-1 accuracy, this module also generate a csv file to record the
|
||||
detailed prediction results, such that this csv file can be used to
|
||||
generate the shape bias curve.
|
||||
|
||||
Args:
|
||||
csv_dir (str): The directory to save the csv file.
|
||||
model_name (str): The name of the csv file. Please note that the
|
||||
model name should be an unique identifier.
|
||||
dataset_name (str): The name of the dataset. Default: 'cue_conflict'.
|
||||
"""
|
||||
|
||||
# mapping several classes from ImageNet-1K to the same category
|
||||
airplane_indices = [404]
|
||||
bear_indices = [294, 295, 296, 297]
|
||||
bicycle_indices = [444, 671]
|
||||
bird_indices = [
|
||||
8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83,
|
||||
87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129,
|
||||
130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
|
||||
145
|
||||
]
|
||||
boat_indices = [472, 554, 625, 814, 914]
|
||||
bottle_indices = [440, 720, 737, 898, 899, 901, 907]
|
||||
car_indices = [436, 511, 817]
|
||||
cat_indices = [281, 282, 283, 284, 285, 286]
|
||||
chair_indices = [423, 559, 765, 857]
|
||||
clock_indices = [409, 530, 892]
|
||||
dog_indices = [
|
||||
152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
|
||||
166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
|
||||
180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194,
|
||||
195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209,
|
||||
210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
|
||||
224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
|
||||
239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254,
|
||||
255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268
|
||||
]
|
||||
elephant_indices = [385, 386]
|
||||
keyboard_indices = [508, 878]
|
||||
knife_indices = [499]
|
||||
oven_indices = [766]
|
||||
truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]
|
||||
|
||||
def __init__(self,
|
||||
csv_dir: str,
|
||||
model_name: str,
|
||||
dataset_name: str = 'cue_conflict',
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.categories = sorted([
|
||||
'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock',
|
||||
'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car',
|
||||
'bird', 'dog'
|
||||
])
|
||||
self.csv_dir = csv_dir
|
||||
self.model_name = model_name
|
||||
self.dataset_name = dataset_name
|
||||
if get_rank() == 0:
|
||||
self.csv_path = self.create_csv()
|
||||
|
||||
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
|
||||
"""Process one batch of data samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch: A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
if 'pred_score' in data_sample:
|
||||
result['pred_score'] = data_sample['pred_score'].cpu()
|
||||
else:
|
||||
result['pred_label'] = data_sample['pred_label'].cpu()
|
||||
result['gt_label'] = data_sample['gt_label'].cpu()
|
||||
result['gt_category'] = data_sample['img_path'].split('/')[-2]
|
||||
result['img_name'] = data_sample['img_path'].split('/')[-1]
|
||||
|
||||
aggregated_category_probabilities = []
|
||||
# get the prediction for each category of current instance
|
||||
for category in self.categories:
|
||||
category_indices = getattr(self, f'{category}_indices')
|
||||
category_probabilities = torch.gather(
|
||||
result['pred_score'], 0,
|
||||
torch.tensor(category_indices)).mean()
|
||||
aggregated_category_probabilities.append(
|
||||
category_probabilities)
|
||||
# sort the probabilities in descending order
|
||||
pred_indices = torch.stack(aggregated_category_probabilities
|
||||
).argsort(descending=True).numpy()
|
||||
result['pred_category'] = np.take(self.categories, pred_indices)
|
||||
|
||||
# Save the result to `self.results`.
|
||||
self.results.append(result)
|
||||
|
||||
def create_csv(self) -> str:
|
||||
"""Create a csv file to store the results."""
|
||||
session_name = 'session-1'
|
||||
csv_path = osp.join(
|
||||
self.csv_dir, self.dataset_name + '_' + self.model_name + '_' +
|
||||
session_name + '.csv')
|
||||
if osp.exists(csv_path):
|
||||
os.remove(csv_path)
|
||||
directory = osp.dirname(csv_path)
|
||||
if not osp.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
with open(csv_path, 'w') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([
|
||||
'subj', 'session', 'trial', 'rt', 'object_response',
|
||||
'category', 'condition', 'imagename'
|
||||
])
|
||||
return csv_path
|
||||
|
||||
def dump_results_to_csv(self, results: List[dict]) -> None:
|
||||
"""Dump the results to a csv file.
|
||||
|
||||
Args:
|
||||
results (List[dict]): A list of results.
|
||||
"""
|
||||
for i, result in enumerate(results):
|
||||
img_name = result['img_name']
|
||||
category = result['gt_category']
|
||||
condition = 'NaN'
|
||||
with open(self.csv_path, 'a') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([
|
||||
self.model_name, 1, i + 1, 'NaN',
|
||||
result['pred_category'][0], category, condition, img_name
|
||||
])
|
||||
|
||||
def compute_metrics(self, results: List[dict]) -> dict:
|
||||
"""Compute the metrics from the results.
|
||||
|
||||
Args:
|
||||
results (List[dict]): A list of results.
|
||||
|
||||
Returns:
|
||||
dict: A dict of metrics.
|
||||
"""
|
||||
if get_rank() == 0:
|
||||
self.dump_results_to_csv(results)
|
||||
metrics = dict()
|
||||
metrics['accuracy/top1'] = np.mean([
|
||||
result['pred_category'][0] == result['gt_category']
|
||||
for result in results
|
||||
])
|
||||
|
||||
return metrics
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmpretrain.evaluation import ShapeBiasMetric
|
||||
|
||||
|
||||
def test_shape_bias_metric():
|
||||
data_sample = dict()
|
||||
data_sample['pred_score'] = torch.rand(1000, )
|
||||
data_sample['pred_label'] = torch.tensor(1)
|
||||
data_sample['gt_label'] = torch.tensor(1)
|
||||
data_sample['img_path'] = 'tests/airplane/test.JPEG'
|
||||
evaluator = ShapeBiasMetric(
|
||||
csv_dir='tests/data', dataset_name='test', model_name='test')
|
||||
evaluator.process(None, [data_sample])
|
|
@ -0,0 +1,284 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/bethgelab/model-vs-human
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from mmengine.logging import MMLogger
|
||||
from utils import FormatStrFormatter, ShapeBias
|
||||
|
||||
# global default boundary settings for thin gray transparent
|
||||
# boundaries to avoid not being able to see the difference
|
||||
# between two partially overlapping datapoints of the same color:
|
||||
PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)
|
||||
PLOTTING_EDGE_WIDTH = 0.02
|
||||
ICONS_DIR = osp.join(
|
||||
osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--csv-dir', type=str, help='directory of csv files')
|
||||
parser.add_argument(
|
||||
'--result-dir', type=str, help='directory to save plotting results')
|
||||
parser.add_argument('--model-names', nargs='+', default=[], help='model name')
|
||||
parser.add_argument(
|
||||
'--colors',
|
||||
nargs='+',
|
||||
type=float,
|
||||
default=[],
|
||||
help= # noqa
|
||||
'the colors for the plots of each model, and they should be in the same order as model_names' # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
'--markers',
|
||||
nargs='+',
|
||||
type=str,
|
||||
default=[],
|
||||
help= # noqa
|
||||
'the markers for the plots of each model, and they should be in the same order as model_names' # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
'--plotting-names',
|
||||
nargs='+',
|
||||
default=[],
|
||||
help= # noqa
|
||||
'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
'--delete-icons',
|
||||
action='store_true',
|
||||
help='whether to delete the icons after plotting')
|
||||
|
||||
humans = [
|
||||
'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
|
||||
'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
|
||||
]
|
||||
|
||||
icon_names = [
|
||||
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
|
||||
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png',
|
||||
'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
|
||||
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
|
||||
'knife.png', 'response_icons_vertical.png', 'truck.png'
|
||||
]
|
||||
|
||||
|
||||
def read_csvs(csv_dir: str) -> pd.DataFrame:
|
||||
"""Reads all csv files in a directory and returns a single dataframe.
|
||||
|
||||
Args:
|
||||
csv_dir (str): directory of csv files.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: dataframe containing all csv files
|
||||
"""
|
||||
df = pd.DataFrame()
|
||||
for csv in os.listdir(csv_dir):
|
||||
if csv.endswith('.csv'):
|
||||
cur_df = pd.read_csv(osp.join(csv_dir, csv))
|
||||
cur_df.columns = [c.lower() for c in cur_df.columns]
|
||||
df = df.append(cur_df)
|
||||
df.condition = df.condition.astype(str)
|
||||
return df
|
||||
|
||||
|
||||
def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
|
||||
"""Plots a matrixplot of shape bias.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): arguments.
|
||||
analysis (ShapeBias): shape bias analysis. Defaults to ShapeBias().
|
||||
"""
|
||||
mpl.rcParams['font.family'] = ['serif']
|
||||
mpl.rcParams['font.serif'] = ['Times New Roman']
|
||||
|
||||
plt.figure(figsize=(9, 7))
|
||||
df = read_csvs(args.csv_dir)
|
||||
|
||||
fontsize = 15
|
||||
ticklength = 10
|
||||
markersize = 250
|
||||
label_size = 20
|
||||
|
||||
classes = df['category'].unique()
|
||||
num_classes = len(classes)
|
||||
|
||||
# plot setup
|
||||
fig = plt.figure(1, figsize=(12, 12), dpi=300.)
|
||||
ax = plt.gca()
|
||||
|
||||
ax.set_xlim([0, 1])
|
||||
ax.set_ylim([-.5, num_classes - 0.5])
|
||||
|
||||
# secondary reversed x axis
|
||||
ax_top = ax.secondary_xaxis(
|
||||
'top', functions=(lambda x: 1 - x, lambda x: 1 - x))
|
||||
|
||||
# labels, ticks
|
||||
plt.tick_params(
|
||||
axis='y', which='both', left=False, right=False, labelleft=False)
|
||||
ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size)
|
||||
ax.set_xlabel(
|
||||
"Fraction of 'texture' decisions", fontsize=label_size, labelpad=25)
|
||||
ax_top.set_xlabel(
|
||||
"Fraction of 'shape' decisions", fontsize=label_size, labelpad=25)
|
||||
ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
|
||||
ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g'))
|
||||
ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1))
|
||||
ax_top.set_ticks(np.arange(0, 1.1, 0.1))
|
||||
ax.tick_params(
|
||||
axis='both', which='major', labelsize=fontsize, length=ticklength)
|
||||
ax_top.tick_params(
|
||||
axis='both', which='major', labelsize=fontsize, length=ticklength)
|
||||
|
||||
# arrows on x axes
|
||||
plt.arrow(
|
||||
x=0,
|
||||
y=-1.75,
|
||||
dx=1,
|
||||
dy=0,
|
||||
fc='black',
|
||||
head_width=0.4,
|
||||
head_length=0.03,
|
||||
clip_on=False,
|
||||
length_includes_head=True,
|
||||
overhang=0.5)
|
||||
plt.arrow(
|
||||
x=1,
|
||||
y=num_classes + 0.75,
|
||||
dx=-1,
|
||||
dy=0,
|
||||
fc='black',
|
||||
head_width=0.4,
|
||||
head_length=0.03,
|
||||
clip_on=False,
|
||||
length_includes_head=True,
|
||||
overhang=0.5)
|
||||
|
||||
# icons besides y axis
|
||||
# determine order of icons
|
||||
df_selection = df.loc[(df['subj'].isin(humans))]
|
||||
class_avgs = []
|
||||
for cl in classes:
|
||||
df_class_selection = df_selection.query("category == '{}'".format(cl))
|
||||
class_avgs.append(1 - analysis.analysis(
|
||||
df=df_class_selection)['shape-bias'])
|
||||
sorted_indices = np.argsort(class_avgs)
|
||||
classes = classes[sorted_indices]
|
||||
|
||||
# icon placement is calculated in axis coordinates
|
||||
WIDTH = 1 / num_classes
|
||||
# placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH)
|
||||
XPOS = -1.25 * WIDTH
|
||||
YPOS = -0.5
|
||||
HEIGHT = 1
|
||||
MARGINX = 1 / 10 * WIDTH # vertical whitespace between icons
|
||||
MARGINY = 1 / 10 * HEIGHT # horizontal whitespace between icons
|
||||
|
||||
left = XPOS + MARGINX
|
||||
right = XPOS + WIDTH - MARGINX
|
||||
|
||||
for i in range(num_classes):
|
||||
bottom = i + MARGINY + YPOS
|
||||
top = (i + 1) - MARGINY + YPOS
|
||||
iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i]))
|
||||
plt.imshow(
|
||||
plt.imread(iconpath),
|
||||
extent=[left, right, bottom, top],
|
||||
aspect='auto',
|
||||
clip_on=False)
|
||||
|
||||
# plot horizontal intersection lines
|
||||
for i in range(num_classes - 1):
|
||||
plt.plot([0, 1], [i + .5, i + .5],
|
||||
c='gray',
|
||||
linestyle='dotted',
|
||||
alpha=0.4)
|
||||
|
||||
# plot average shapebias + scatter points
|
||||
for i in range(len(args.model_names)):
|
||||
df_selection = df.loc[(df['subj'].isin(args.model_names[i]))]
|
||||
result_df = analysis.analysis(df=df_selection)
|
||||
avg = 1 - result_df['shape-bias']
|
||||
ax.plot([avg, avg], [-1, num_classes], color=args.colors[i])
|
||||
class_avgs = []
|
||||
for cl in classes:
|
||||
df_class_selection = df_selection.query(
|
||||
"category == '{}'".format(cl))
|
||||
class_avgs.append(1 - analysis.analysis(
|
||||
df=df_class_selection)['shape-bias'])
|
||||
|
||||
ax.scatter(
|
||||
class_avgs,
|
||||
classes,
|
||||
color=args.colors[i],
|
||||
marker=args.markers[i],
|
||||
label=args.plotting_names[i],
|
||||
s=markersize,
|
||||
clip_on=False,
|
||||
edgecolors=PLOTTING_EDGE_COLOR,
|
||||
linewidths=PLOTTING_EDGE_WIDTH,
|
||||
zorder=3)
|
||||
plt.legend(frameon=True, labelspacing=1, loc=9)
|
||||
|
||||
figure_path = osp.join(args.result_dir,
|
||||
'cue-conflict_shape-bias_matrixplot.pdf')
|
||||
fig.savefig(figure_path, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
|
||||
def check_icons() -> bool:
|
||||
"""Check if icons are present, if not download them."""
|
||||
if not osp.exists(ICONS_DIR):
|
||||
return False
|
||||
for icon_name in icon_names:
|
||||
if not osp.exists(osp.join(ICONS_DIR, icon_name)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if not check_icons():
|
||||
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
|
||||
os.makedirs(ICONS_DIR, exist_ok=True)
|
||||
MMLogger.get_current_instance().info(
|
||||
f'Downloading icons to {ICONS_DIR}')
|
||||
for icon_name in icon_names:
|
||||
url = osp.join(root_url, icon_name)
|
||||
os.system('wget -O {} {}'.format(
|
||||
osp.join(ICONS_DIR, icon_name), url))
|
||||
|
||||
args = parser.parse_args()
|
||||
assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
|
||||
must be 3 times the number of models. Every three colors are the RGB \
|
||||
values for one model.'
|
||||
|
||||
# preprocess colors
|
||||
args.colors = [c / 255. for c in args.colors]
|
||||
colors = []
|
||||
for i in range(len(args.model_names)):
|
||||
colors.append(args.colors[3 * i:3 * i + 3])
|
||||
args.colors = colors
|
||||
args.colors.append([165 / 255., 30 / 255., 55 / 255.]) # human color
|
||||
|
||||
# if plotting names are not specified, use model names
|
||||
if len(args.plotting_names) == 0:
|
||||
args.plotting_names = args.model_names
|
||||
|
||||
# preprocess markers
|
||||
args.markers.append('D') # human marker
|
||||
|
||||
# preprocess model names
|
||||
args.model_names = [[m] for m in args.model_names]
|
||||
args.model_names.append(humans)
|
||||
|
||||
# preprocess plotting names
|
||||
args.plotting_names.append('Humans')
|
||||
|
||||
plot_shape_bias_matrixplot(args)
|
||||
if args.delete_icons:
|
||||
os.system('rm -rf {}'.format(ICONS_DIR))
|
|
@ -0,0 +1,277 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/bethgelab/model-vs-human
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import matplotlib as mpl
|
||||
import pandas as pd
|
||||
from matplotlib import _api
|
||||
from matplotlib import transforms as mtransforms
|
||||
|
||||
|
||||
class _DummyAxis:
|
||||
"""Define the minimal interface for a dummy axis.
|
||||
|
||||
Args:
|
||||
minpos (float): The minimum positive value for the axis. Defaults to 0.
|
||||
"""
|
||||
__name__ = 'dummy'
|
||||
|
||||
# Once the deprecation elapses, replace dataLim and viewLim by plain
|
||||
# _view_interval and _data_interval private tuples.
|
||||
dataLim = _api.deprecate_privatize_attribute(
|
||||
'3.6', alternative='get_data_interval() and set_data_interval()')
|
||||
viewLim = _api.deprecate_privatize_attribute(
|
||||
'3.6', alternative='get_view_interval() and set_view_interval()')
|
||||
|
||||
def __init__(self, minpos: float = 0) -> None:
|
||||
self._dataLim = mtransforms.Bbox.unit()
|
||||
self._viewLim = mtransforms.Bbox.unit()
|
||||
self._minpos = minpos
|
||||
|
||||
def get_view_interval(self) -> Dict:
|
||||
"""Return the view interval as a tuple (*vmin*, *vmax*)."""
|
||||
return self._viewLim.intervalx
|
||||
|
||||
def set_view_interval(self, vmin: float, vmax: float) -> None:
|
||||
"""Set the view interval to (*vmin*, *vmax*)."""
|
||||
self._viewLim.intervalx = vmin, vmax
|
||||
|
||||
def get_minpos(self) -> float:
|
||||
"""Return the minimum positive value for the axis."""
|
||||
return self._minpos
|
||||
|
||||
def get_data_interval(self) -> Dict:
|
||||
"""Return the data interval as a tuple (*vmin*, *vmax*)."""
|
||||
return self._dataLim.intervalx
|
||||
|
||||
def set_data_interval(self, vmin: float, vmax: float) -> None:
|
||||
"""Set the data interval to (*vmin*, *vmax*)."""
|
||||
self._dataLim.intervalx = vmin, vmax
|
||||
|
||||
def get_tick_space(self) -> int:
|
||||
"""Return the number of ticks to use."""
|
||||
# Just use the long-standing default of nbins==9
|
||||
return 9
|
||||
|
||||
|
||||
class TickHelper:
|
||||
"""A helper class for ticks and tick labels."""
|
||||
axis = None
|
||||
|
||||
def set_axis(self, axis: Any) -> None:
|
||||
"""Set the axis instance."""
|
||||
self.axis = axis
|
||||
|
||||
def create_dummy_axis(self, **kwargs) -> None:
|
||||
"""Create a dummy axis if no axis is set."""
|
||||
if self.axis is None:
|
||||
self.axis = _DummyAxis(**kwargs)
|
||||
|
||||
@_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
|
||||
def set_view_interval(self, vmin: float, vmax: float) -> None:
|
||||
"""Set the view interval to (*vmin*, *vmax*)."""
|
||||
self.axis.set_view_interval(vmin, vmax)
|
||||
|
||||
@_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
|
||||
def set_data_interval(self, vmin: float, vmax: float) -> None:
|
||||
"""Set the data interval to (*vmin*, *vmax*)."""
|
||||
self.axis.set_data_interval(vmin, vmax)
|
||||
|
||||
@_api.deprecated(
|
||||
'3.5',
|
||||
alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
|
||||
def set_bounds(self, vmin: float, vmax: float) -> None:
|
||||
"""Set the view and data interval to (*vmin*, *vmax*)."""
|
||||
self.set_view_interval(vmin, vmax)
|
||||
self.set_data_interval(vmin, vmax)
|
||||
|
||||
|
||||
class Formatter(TickHelper):
|
||||
"""Create a string based on a tick value and location."""
|
||||
# some classes want to see all the locs to help format
|
||||
# individual ones
|
||||
locs = []
|
||||
|
||||
def __call__(self, x: str, pos: Optional[Any] = None) -> str:
|
||||
"""Return the format for tick value *x* at position pos.
|
||||
|
||||
``pos=None`` indicates an unspecified location.
|
||||
|
||||
This method must be overridden in the derived class.
|
||||
|
||||
Args:
|
||||
x (str): The tick value.
|
||||
pos (Optional[Any]): The tick position. Defaults to None.
|
||||
"""
|
||||
raise NotImplementedError('Derived must override')
|
||||
|
||||
def format_ticks(self, values: pd.Series) -> List[str]:
|
||||
"""Return the tick labels for all the ticks at once.
|
||||
|
||||
Args:
|
||||
values (pd.Series): The tick values.
|
||||
|
||||
Returns:
|
||||
List[str]: The tick labels.
|
||||
"""
|
||||
self.set_locs(values)
|
||||
return [self(value, i) for i, value in enumerate(values)]
|
||||
|
||||
def format_data(self, value: Any) -> str:
|
||||
"""Return the full string representation of the value with the position
|
||||
unspecified.
|
||||
|
||||
Args:
|
||||
value (Any): The tick value.
|
||||
|
||||
Returns:
|
||||
str: The full string representation of the value.
|
||||
"""
|
||||
return self.__call__(value)
|
||||
|
||||
def format_data_short(self, value: Any) -> str:
|
||||
"""Return a short string version of the tick value.
|
||||
|
||||
Defaults to the position-independent long value.
|
||||
|
||||
Args:
|
||||
value (Any): The tick value.
|
||||
|
||||
Returns:
|
||||
str: The short string representation of the value.
|
||||
"""
|
||||
return self.format_data(value)
|
||||
|
||||
def get_offset(self) -> str:
|
||||
"""Return the offset string."""
|
||||
return ''
|
||||
|
||||
def set_locs(self, locs: List[Any]) -> None:
|
||||
"""Set the locations of the ticks.
|
||||
|
||||
This method is called before computing the tick labels because some
|
||||
formatters need to know all tick locations to do so.
|
||||
"""
|
||||
self.locs = locs
|
||||
|
||||
@staticmethod
|
||||
def fix_minus(s: str) -> str:
|
||||
"""Some classes may want to replace a hyphen for minus with the proper
|
||||
Unicode symbol (U+2212) for typographical correctness.
|
||||
|
||||
This is a
|
||||
helper method to perform such a replacement when it is enabled via
|
||||
:rc:`axes.unicode_minus`.
|
||||
|
||||
Args:
|
||||
s (str): The string to replace the hyphen with the Unicode symbol.
|
||||
"""
|
||||
return (s.replace('-', '\N{MINUS SIGN}')
|
||||
if mpl.rcParams['axes.unicode_minus'] else s)
|
||||
|
||||
def _set_locator(self, locator: Any) -> None:
|
||||
"""Subclasses may want to override this to set a locator."""
|
||||
pass
|
||||
|
||||
|
||||
class FormatStrFormatter(Formatter):
|
||||
"""Use an old-style ('%' operator) format string to format the tick.
|
||||
|
||||
The format string should have a single variable format (%) in it.
|
||||
It will be applied to the value (not the position) of the tick.
|
||||
|
||||
Negative numeric values will use a dash, not a Unicode minus; use mathtext
|
||||
to get a Unicode minus by wrapping the format specifier with $ (e.g.
|
||||
"$%g$").
|
||||
|
||||
Args:
|
||||
fmt (str): Format string.
|
||||
"""
|
||||
|
||||
def __init__(self, fmt: str) -> None:
|
||||
self.fmt = fmt
|
||||
|
||||
def __call__(self, x: str, pos: Optional[Any]) -> str:
|
||||
"""Return the formatted label string.
|
||||
|
||||
Only the value *x* is formatted. The position is ignored.
|
||||
|
||||
Args:
|
||||
x (str): The value to format.
|
||||
pos (Any): The position of the tick. Ignored.
|
||||
"""
|
||||
return self.fmt % x
|
||||
|
||||
|
||||
class ShapeBias:
|
||||
"""Compute the shape bias of a model.
|
||||
|
||||
Reference: `ImageNet-trained CNNs are biased towards texture;
|
||||
increasing shape bias improves accuracy and robustness
|
||||
<https://arxiv.org/abs/1811.12231>`_.
|
||||
"""
|
||||
num_input_models = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.plotting_name = 'shape-bias'
|
||||
|
||||
@staticmethod
|
||||
def _check_dataframe(df: pd.DataFrame) -> None:
|
||||
"""Check that the dataframe is valid."""
|
||||
assert len(df) > 0, 'empty dataframe'
|
||||
|
||||
def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
|
||||
"""Compute the shape bias of a model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The dataframe containing the data.
|
||||
|
||||
Returns:
|
||||
Dict[str, float]: The shape bias.
|
||||
"""
|
||||
self._check_dataframe(df)
|
||||
|
||||
df = df.copy()
|
||||
df['correct_texture'] = df['imagename'].apply(
|
||||
self.get_texture_category)
|
||||
df['correct_shape'] = df['category']
|
||||
|
||||
# remove those rows where shape = texture, i.e. no cue conflict present
|
||||
df2 = df.loc[df.correct_shape != df.correct_texture]
|
||||
fraction_correct_shape = len(
|
||||
df2.loc[df2.object_response == df2.correct_shape]) / len(df)
|
||||
fraction_correct_texture = len(
|
||||
df2.loc[df2.object_response == df2.correct_texture]) / len(df)
|
||||
shape_bias = fraction_correct_shape / (
|
||||
fraction_correct_shape + fraction_correct_texture)
|
||||
|
||||
result_dict = {
|
||||
'fraction-correct-shape': fraction_correct_shape,
|
||||
'fraction-correct-texture': fraction_correct_texture,
|
||||
'shape-bias': shape_bias
|
||||
}
|
||||
return result_dict
|
||||
|
||||
def get_texture_category(self, imagename: str) -> str:
|
||||
"""Return texture category from imagename.
|
||||
|
||||
e.g. 'XXX_dog10-bird2.png' -> 'bird '
|
||||
|
||||
Args:
|
||||
imagename (str): Name of the image.
|
||||
|
||||
Returns:
|
||||
str: Texture category.
|
||||
"""
|
||||
assert type(imagename) is str
|
||||
|
||||
# remove unnecessary words
|
||||
a = imagename.split('_')[-1]
|
||||
# remove .png etc.
|
||||
b = a.split('.')[0]
|
||||
# get texture category (last word)
|
||||
c = b.split('-')[-1]
|
||||
# remove number, e.g. 'bird2' -> 'bird'
|
||||
d = ''.join([i for i in c if not i.isdigit()])
|
||||
return d
|
Loading…
Reference in New Issue