[Feature] Transfer shape-bias tool from mmselfsup (#1658)

* Transfer shape-bias tool from mmselfsup

* run shape-bias successfully, add CN docs

* fix unit test bug

* add shape_bias to index.rst in docs

* modified mistakes in shape-bias docs
pull/1689/head
Wang Xiang 2023-07-03 11:39:23 +08:00 committed by GitHub
parent 00030e3f7d
commit feb0814b2f
9 changed files with 948 additions and 1 deletions


@ -104,6 +104,7 @@ We always welcome *PRs* and *Issues* for the betterment of MMPretrain.
useful_tools/log_result_analysis.md
useful_tools/complexity_analysis.md
useful_tools/confusion_matrix.md
useful_tools/shape_bias.md
.. toctree::
:maxdepth: 1


@ -0,0 +1,100 @@
## Shape Bias Tool Usage
Shape bias measures how much a model relies on shape, as opposed to texture, to recognize the semantics of an image. For more details,
we refer interested readers to this [paper](https://arxiv.org/abs/2106.07411). MMPretrain provides an off-the-shelf tool to
obtain the shape bias of a classification model. You can follow the steps below:
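As a rough illustration of the quantity being measured, the sketch below computes shape bias from a handful of made-up trials: among the cue-conflict images where the prediction matches either the shape or the texture category, it is the fraction that match the shape. The function name `shape_bias` and the trial tuples are hypothetical and only mirror the ratio used by the plotting tool; this is not the tool's own code.

```python
from typing import List, Tuple


def shape_bias(trials: List[Tuple[str, str, str]]) -> float:
    """Return the fraction of shape decisions among correct decisions.

    Each trial is (shape_category, texture_category, predicted_category).
    """
    # Keep only trials where shape and texture actually disagree.
    conflicts = [t for t in trials if t[0] != t[1]]
    shape_hits = sum(1 for shape, _, pred in conflicts if pred == shape)
    texture_hits = sum(1 for _, texture, pred in conflicts if pred == texture)
    if shape_hits + texture_hits == 0:
        raise ValueError('no trial matched either the shape or the texture')
    return shape_hits / (shape_hits + texture_hits)


# Hypothetical trials, for illustration only.
trials = [
    ('cat', 'elephant', 'cat'),       # shape decision
    ('cat', 'elephant', 'elephant'),  # texture decision
    ('dog', 'bird', 'dog'),           # shape decision
    ('car', 'clock', 'boat'),         # neither cue, ignored by the ratio
    ('bear', 'bear', 'bear'),         # no cue conflict, dropped
]
print(shape_bias(trials))  # 2 shape decisions out of 3 correct ones
```

A value close to 1 means the model mostly follows shape, while a value close to 0 means it mostly follows texture.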
### Prepare the dataset
First, download the [cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) dataset to the `data` folder
and extract it. After that, your `data` folder should have the following structure:
```text
data
├──cue-conflict
|   |──airplane
|   |──bear
|   ...
|   |──truck
```
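Before moving on, it can help to confirm that the extracted folder looks as expected. The following is a minimal sketch, assuming the archive unpacks into one sub-folder per category and that the 16 category names match those used by `ShapeBiasMetric`; the helper `missing_categories` is a hypothetical name, not part of MMPretrain.

```python
import os
from typing import List

# The 16 category folders expected under data/cue-conflict (assumption:
# one folder per category, named as in ShapeBiasMetric).
CATEGORIES = sorted([
    'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock', 'oven',
    'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car', 'bird', 'dog'
])


def missing_categories(root: str) -> List[str]:
    """Return the expected category folders that are absent under root."""
    present = {
        name
        for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    }
    return [c for c in CATEGORIES if c not in present]


if __name__ == '__main__':
    root = 'data/cue-conflict'
    if os.path.isdir(root):
        print('missing:', missing_categories(root))
    else:
        print(f'{root} not found; download and extract the dataset first')
```

An empty list means every expected category folder is in place.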
### Modify the config for classification
We run the shape-bias tool on a ViT-base model with masked autoencoder pretraining. Its config file is `configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py`, and its checkpoint can be downloaded from [this link](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth). Replace the original `test_pipeline`, `test_dataloader` and `test_evaluator` with the following configurations:
```python
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=256,
        edge='short',
        backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs')
]

test_dataloader = dict(
    pin_memory=True,
    collate_fn=dict(type='default_collate'),
    batch_size=32,
    num_workers=4,
    dataset=dict(
        type='CustomDataset',
        data_root='data/cue-conflict',
        pipeline=test_pipeline,
        _delete_=True),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
    drop_last=False)

test_evaluator = dict(
    type='mmpretrain.ShapeBiasMetric',
    _delete_=True,
    csv_dir='work_dirs/shape_bias',
    model_name='mae')
```
Please note that you should adapt `csv_dir` and `model_name` above to your own setup. In this example, the modified config file is saved as `vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py` in the folder `configs/mae/benchmarks/`.
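For context on what `ShapeBiasMetric` does with the model output: it maps the 1000 ImageNet-1K class scores onto the 16 cue-conflict categories by averaging the scores of the classes that belong to each category, then ranks the categories. The sketch below illustrates that idea in plain Python on three categories only; `rank_categories` is a hypothetical helper and the scores are made up, while the index groups are copied from the metric.

```python
from typing import Dict, List

# A small subset of the ImageNet-1K index groups defined in ShapeBiasMetric.
CATEGORY_INDICES: Dict[str, List[int]] = {
    'airplane': [404],
    'elephant': [385, 386],
    'knife': [499],
}


def rank_categories(scores: List[float]) -> List[str]:
    """Rank categories by the mean score of their ImageNet-1K classes."""
    means = {
        name: sum(scores[i] for i in indices) / len(indices)
        for name, indices in CATEGORY_INDICES.items()
    }
    return sorted(means, key=means.get, reverse=True)


# Made-up scores for a single image, for illustration only.
scores = [0.0] * 1000
scores[385], scores[386] = 0.30, 0.10  # mean 0.20 for 'elephant'
scores[404] = 0.15                     # mean 0.15 for 'airplane'
scores[499] = 0.05                     # mean 0.05 for 'knife'
print(rank_categories(scores))
```

The first entry of the ranking is the category that ends up in the csv file as the model's response.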
### Run inference on your model with the modified config file
Then run inference on the `cue-conflict` dataset with your modified config file.
```shell
# For PyTorch
bash tools/dist_test.sh $CONFIG $CHECKPOINT $GPUS
```
**Description of all arguments**:
- `$CONFIG`: The path of your modified config file.
- `$CHECKPOINT`: The path or link of the checkpoint file.
- `$GPUS`: The number of GPUs to use (`1` in the example below).
```shell
# Example
bash tools/dist_test.sh configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth 1
```
After that, you should obtain a csv file in the `csv_dir` folder, named `<dataset_name>_<model_name>_session-1.csv`. With the config above and the metric's default `dataset_name='cue_conflict'`, that is `cue_conflict_mae_session-1.csv`. Besides this file, you should also download these [csv files](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict) to the
`csv_dir`.
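To make the file format concrete, here is a small sketch of the rows that end up in the generated csv file. The column names and the `NaN` placeholders follow `ShapeBiasMetric`; the helper `make_row`, the model name and the image name are hypothetical values chosen for illustration.

```python
import csv
import io

# Column layout of the csv file written by ShapeBiasMetric.
HEADER = [
    'subj', 'session', 'trial', 'rt', 'object_response', 'category',
    'condition', 'imagename'
]


def make_row(model_name, trial, predicted, shape_category, img_name):
    """Build one csv row; response time and condition are not recorded."""
    return [
        model_name, 1, trial, 'NaN', predicted, shape_category, 'NaN',
        img_name
    ]


# Write to an in-memory buffer instead of a real file.
buffer = io.StringIO()
writer = csv.writer(buffer, lineterminator='\n')
writer.writerow(HEADER)
writer.writerow(make_row('mae', 1, 'cat', 'cat', 'cat1-elephant2.png'))
print(buffer.getvalue())
```

The `subj` column is what the plotting script later matches against `--model-names`, which is why `model_name` should be unique.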
### Plot shape bias
Then we can start to plot the shape bias:
```shell
python tools/analysis_tools/shape_bias.py --csv-dir $CSV_DIR --result-dir $RESULT_DIR --colors $RGB --markers o --plotting-names $YOUR_MODEL_NAME --model-names $YOUR_MODEL_NAME
```
**Description of all arguments**:
- `--csv-dir $CSV_DIR`: the directory that holds the csv files, i.e. the `csv_dir` set in your config.
- `--result-dir $RESULT_DIR`: the directory in which the result, named `cue-conflict_shape-bias_matrixplot.pdf`, is saved.
- `--colors $RGB`: the RGB values (0 to 255), formatted as R G B, e.g. 100 100 100. Pass several triples if you want to plot the shape bias of several models.
- `--markers o`: the marker used for each model, in the same order as `--model-names`.
- `--plotting-names $YOUR_MODEL_NAME`: the name shown in the legend of the shape bias figure; you can set it to your model name. If you want to plot several models, `--plotting-names` can take multiple values.
- `--model-names $YOUR_MODEL_NAME`: should be the same name as specified in your config, and can be multiple names if you want to plot the shape bias of several models.
Please note that every three values for `--colors` correspond to one value for `--model-names`. After all of the above steps, you should obtain the following figure.
<div align="center">
<img src="https://github.com/open-mmlab/mmpretrain/assets/42371271/dc608d06-43eb-4860-bb70-486ed2a3f927" width="500" />
</div>
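The pairing between `--colors` and `--model-names` can be sketched as follows: the flat list of values is scaled from the 0 to 255 range to 0 to 1 and cut into one RGB triple per model. The function name `group_colors` is hypothetical; the logic mirrors the preprocessing at the end of `tools/analysis_tools/shape_bias.py`.

```python
from typing import List


def group_colors(values: List[float], num_models: int) -> List[List[float]]:
    """Scale 0-255 values to 0-1 and group them into RGB triples."""
    if len(values) != 3 * num_models:
        raise ValueError(
            'Number of colors must be 3 times the number of models.')
    scaled = [v / 255.0 for v in values]
    return [scaled[3 * i:3 * i + 3] for i in range(num_models)]


# Two models: the first plotted in red, the second in blue.
print(group_colors([255, 0, 0, 0, 0, 255], 2))
```

Passing a number of color values that is not three times the number of models raises an error, just as the script's own assertion does.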


@ -90,6 +90,7 @@ MMPretrain getting-started roadmap
useful_tools/log_result_analysis.md
useful_tools/complexity_analysis.md
useful_tools/confusion_matrix.md
useful_tools/shape_bias.md
.. toctree::
:maxdepth: 1


@ -0,0 +1,96 @@
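## Shape Bias Tool Usage
Shape bias measures how much a model relies on shape, as opposed to texture, to recognize the semantics of an image. For more details, we refer interested readers to this [paper](https://arxiv.org/abs/2106.07411). MMPretrain provides an off-the-shelf tool to obtain the shape bias of a classification model. You can follow the steps below:
### Prepare the dataset
First, download the [cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) dataset to the `data` folder and extract it. After that, your `data` folder should have the following structure: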
```text
data
├──cue-conflict
|   |──airplane
|   |──bear
|   ...
|   |──truck
```
### Modify the config for classification
We run the shape-bias tool on a ViT-base model pretrained with MAE. Its config file is `configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py`, and its checkpoint can be downloaded from [this link](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth). Replace the `test_pipeline`, `test_dataloader` and `test_evaluator` in the original config with the following configurations:
```python
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=256,
        edge='short',
        backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs')
]

test_dataloader = dict(
    pin_memory=True,
    collate_fn=dict(type='default_collate'),
    batch_size=32,
    num_workers=4,
    dataset=dict(
        type='CustomDataset',
        data_root='data/cue-conflict',
        pipeline=test_pipeline,
        _delete_=True),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
    drop_last=False)

test_evaluator = dict(
    type='mmpretrain.ShapeBiasMetric',
    _delete_=True,
    csv_dir='work_dirs/shape_bias',
    model_name='mae')
```
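Please note that you should adapt `csv_dir` and `model_name` above to your own setup. In this example, the modified config file is saved as `vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py` in the folder `configs/mae/benchmarks/`.
### Run inference on your model with the modified config file
Then run inference on the `cue-conflict` dataset with your modified config file.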
```shell
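# For PyTorch
bash tools/dist_test.sh $CONFIG $CHECKPOINT $GPUS
```
**Description of all arguments**:
- `$CONFIG`: The path of your modified config file.
- `$CHECKPOINT`: The path or link of the checkpoint file.
- `$GPUS`: The number of GPUs to use (`1` in the example below).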
```shell
# Example
bash tools/dist_test.sh configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth 1
```
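After that, you should obtain a csv file in the `csv_dir` folder, named `<dataset_name>_<model_name>_session-1.csv`. With the config above and the metric's default `dataset_name='cue_conflict'`, that is `cue_conflict_mae_session-1.csv`. Besides this file, you should also download these [csv files](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict) to the `csv_dir`.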
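The `imagename` column of these csv files is what tells the plotting tool which texture an image carries. The sketch below shows how a texture category can be recovered from an image name, assuming names of the form `XXX_dog10-bird2.png` (shape first, texture last) as in the docstring of `ShapeBias.get_texture_category`; the standalone function `texture_category` is a simplified, hypothetical rewrite of that method.

```python
def texture_category(imagename: str) -> str:
    """Extract the texture category from a cue-conflict image name."""
    # Drop any prefix before the last underscore and the file extension.
    stem = imagename.split('_')[-1].split('.')[0]
    # The texture cue is the part after the last hyphen, e.g. 'bird2'.
    texture = stem.split('-')[-1]
    # Strip the trailing instance number, e.g. 'bird2' -> 'bird'.
    return ''.join(ch for ch in texture if not ch.isdigit())


print(texture_category('XXX_dog10-bird2.png'))  # bird
```

The shape category, by contrast, comes from the `category` column, which the metric fills with the name of the folder that contains the image.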
### Plot shape bias
Then we can start to plot the shape bias:
```shell
python tools/analysis_tools/shape_bias.py --csv-dir $CSV_DIR --result-dir $RESULT_DIR --colors $RGB --markers o --plotting-names $YOUR_MODEL_NAME --model-names $YOUR_MODEL_NAME
```
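**Description of all arguments**:
- `--csv-dir $CSV_DIR`: the directory that holds the csv files, i.e. the `csv_dir` set in your config.
- `--result-dir $RESULT_DIR`: the directory in which the result, named `cue-conflict_shape-bias_matrixplot.pdf`, is saved.
- `--colors $RGB`: the RGB values (0 to 255), formatted as R G B, e.g. 100 100 100. Pass several triples if you want to plot the shape bias of several models.
- `--markers o`: the marker used for each model, in the same order as `--model-names`.
- `--plotting-names $YOUR_MODEL_NAME`: the name shown in the legend of the shape bias figure; you can set it to your model name. If you want to plot several models, `--plotting-names` can take multiple values.
- `--model-names $YOUR_MODEL_NAME`: should be the same name as specified in your config, and can be multiple names if you want to plot the shape bias of several models.
Please note that every three values for `--colors` correspond to one value for `--model-names`. After all of the above steps, you should obtain the following figure.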
<div align="center">
<img src="https://github.com/open-mmlab/mmpretrain/assets/42371271/dc608d06-43eb-4860-bb70-486ed2a3f927" width="500" />
</div>


@ -6,6 +6,7 @@ from .multi_task import MultiTasksMetric
from .nocaps import NocapsSave
from .retrieval import RetrievalAveragePrecision, RetrievalRecall
from .scienceqa import ScienceQAMetric
from .shape_bias_label import ShapeBiasMetric
from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
from .visual_grounding_eval import VisualGroundingMetric
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
@ -16,5 +17,5 @@ __all__ = [
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
'RetrievalAveragePrecision'
'RetrievalAveragePrecision', 'ShapeBiasMetric'
]


@ -0,0 +1,172 @@
# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os
import os.path as osp
from typing import List, Sequence
import numpy as np
import torch
from mmengine.dist.utils import get_rank
from mmengine.evaluator import BaseMetric
from mmpretrain.registry import METRICS
@METRICS.register_module()
class ShapeBiasMetric(BaseMetric):
    """Evaluate the model on the ``cue_conflict`` dataset.

    This module evaluates the model on an OOD dataset, cue_conflict, in
    order to measure the shape bias of the model. In addition to computing
    the Top-1 accuracy, this module also generates a csv file to record the
    detailed prediction results, so that this csv file can be used to
    generate the shape bias curve.

    Args:
        csv_dir (str): The directory to save the csv file.
        model_name (str): The name of the model, used in the csv file name
            and in the ``subj`` column. It should be a unique identifier.
        dataset_name (str): The name of the dataset. Default: 'cue_conflict'.
    """
# mapping several classes from ImageNet-1K to the same category
airplane_indices = [404]
bear_indices = [294, 295, 296, 297]
bicycle_indices = [444, 671]
bird_indices = [
8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83,
87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129,
130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
145
]
boat_indices = [472, 554, 625, 814, 914]
bottle_indices = [440, 720, 737, 898, 899, 901, 907]
car_indices = [436, 511, 817]
cat_indices = [281, 282, 283, 284, 285, 286]
chair_indices = [423, 559, 765, 857]
clock_indices = [409, 530, 892]
dog_indices = [
152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194,
195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209,
210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254,
255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268
]
elephant_indices = [385, 386]
keyboard_indices = [508, 878]
knife_indices = [499]
oven_indices = [766]
truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]
def __init__(self,
csv_dir: str,
model_name: str,
dataset_name: str = 'cue_conflict',
**kwargs) -> None:
super().__init__(**kwargs)
self.categories = sorted([
'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock',
'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car',
'bird', 'dog'
])
self.csv_dir = csv_dir
self.model_name = model_name
self.dataset_name = dataset_name
if get_rank() == 0:
self.csv_path = self.create_csv()
    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            if 'pred_score' in data_sample:
                result['pred_score'] = data_sample['pred_score'].cpu()
            else:
                result['pred_label'] = data_sample['pred_label'].cpu()
            result['gt_label'] = data_sample['gt_label'].cpu()
            result['gt_category'] = data_sample['img_path'].split('/')[-2]
            result['img_name'] = data_sample['img_path'].split('/')[-1]

            aggregated_category_probabilities = []
            # get the prediction for each category of current instance
            # NOTE: this aggregation reads ``pred_score``, so the model must
            # output per-class scores for this metric to work.
            for category in self.categories:
                category_indices = getattr(self, f'{category}_indices')
                category_probabilities = torch.gather(
                    result['pred_score'], 0,
                    torch.tensor(category_indices)).mean()
                aggregated_category_probabilities.append(
                    category_probabilities)
            # sort the probabilities in descending order
            pred_indices = torch.stack(aggregated_category_probabilities
                                       ).argsort(descending=True).numpy()
            result['pred_category'] = np.take(self.categories, pred_indices)

            # Save the result to `self.results`.
            self.results.append(result)
    def create_csv(self) -> str:
        """Create a csv file to store the results."""
        session_name = 'session-1'
        csv_path = osp.join(
            self.csv_dir, self.dataset_name + '_' + self.model_name + '_' +
            session_name + '.csv')
        if osp.exists(csv_path):
            os.remove(csv_path)
        directory = osp.dirname(csv_path)
        if not osp.exists(directory):
            os.makedirs(directory, exist_ok=True)
        # ``newline=''`` lets the csv module control line endings itself.
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                'subj', 'session', 'trial', 'rt', 'object_response',
                'category', 'condition', 'imagename'
            ])
        return csv_path

    def dump_results_to_csv(self, results: List[dict]) -> None:
        """Dump the results to a csv file.

        Args:
            results (List[dict]): A list of results.
        """
        # Open the file once and append one row per result.
        with open(self.csv_path, 'a', newline='') as f:
            writer = csv.writer(f)
            for i, result in enumerate(results):
                img_name = result['img_name']
                category = result['gt_category']
                condition = 'NaN'
                writer.writerow([
                    self.model_name, 1, i + 1, 'NaN',
                    result['pred_category'][0], category, condition, img_name
                ])
def compute_metrics(self, results: List[dict]) -> dict:
"""Compute the metrics from the results.
Args:
results (List[dict]): A list of results.
Returns:
dict: A dict of metrics.
"""
if get_rank() == 0:
self.dump_results_to_csv(results)
metrics = dict()
metrics['accuracy/top1'] = np.mean([
result['pred_category'][0] == result['gt_category']
for result in results
])
return metrics


@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmpretrain.evaluation import ShapeBiasMetric
def test_shape_bias_metric():
    data_sample = dict()
    data_sample['pred_score'] = torch.rand(1000, )
    data_sample['pred_label'] = torch.tensor(1)
    data_sample['gt_label'] = torch.tensor(1)
    data_sample['img_path'] = 'tests/airplane/test.JPEG'
    evaluator = ShapeBiasMetric(
        csv_dir='tests/data', dataset_name='test', model_name='test')
    evaluator.process(None, [data_sample])
    # the ground-truth category is taken from the parent folder name
    assert evaluator.results[0]['gt_category'] == 'airplane'


@ -0,0 +1,284 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
import argparse
import os
import os.path as osp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mmengine.logging import MMLogger
from utils import FormatStrFormatter, ShapeBias
# global default boundary settings for thin gray transparent
# boundaries to avoid not being able to see the difference
# between two partially overlapping datapoints of the same color:
PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)
PLOTTING_EDGE_WIDTH = 0.02
ICONS_DIR = osp.join(
osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons')
parser = argparse.ArgumentParser()
parser.add_argument('--csv-dir', type=str, help='directory of csv files')
parser.add_argument(
'--result-dir', type=str, help='directory to save plotting results')
parser.add_argument('--model-names', nargs='+', default=[], help='model name')
parser.add_argument(
'--colors',
nargs='+',
type=float,
default=[],
help= # noqa
'the colors for the plots of each model, and they should be in the same order as model_names' # noqa: E501
)
parser.add_argument(
'--markers',
nargs='+',
type=str,
default=[],
help= # noqa
'the markers for the plots of each model, and they should be in the same order as model_names' # noqa: E501
)
parser.add_argument(
'--plotting-names',
nargs='+',
default=[],
help= # noqa
'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501
)
parser.add_argument(
'--delete-icons',
action='store_true',
help='whether to delete the icons after plotting')
humans = [
'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
]
icon_names = [
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png',
'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
'knife.png', 'response_icons_vertical.png', 'truck.png'
]
def read_csvs(csv_dir: str) -> pd.DataFrame:
    """Reads all csv files in a directory and returns a single dataframe.

    Args:
        csv_dir (str): directory of csv files.

    Returns:
        pd.DataFrame: dataframe containing all csv files
    """
    dfs = []
    for csv_name in os.listdir(csv_dir):
        if csv_name.endswith('.csv'):
            cur_df = pd.read_csv(osp.join(csv_dir, csv_name))
            cur_df.columns = [c.lower() for c in cur_df.columns]
            dfs.append(cur_df)
    # ``DataFrame.append`` was removed in pandas 2.0, so collect the frames
    # in a list and concatenate them once.
    df = pd.concat(dfs)
    df.condition = df.condition.astype(str)
    return df
def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
"""Plots a matrixplot of shape bias.
Args:
args (argparse.Namespace): arguments.
analysis (ShapeBias): shape bias analysis. Defaults to ShapeBias().
"""
mpl.rcParams['font.family'] = ['serif']
mpl.rcParams['font.serif'] = ['Times New Roman']
plt.figure(figsize=(9, 7))
df = read_csvs(args.csv_dir)
fontsize = 15
ticklength = 10
markersize = 250
label_size = 20
classes = df['category'].unique()
num_classes = len(classes)
# plot setup
fig = plt.figure(1, figsize=(12, 12), dpi=300.)
ax = plt.gca()
ax.set_xlim([0, 1])
ax.set_ylim([-.5, num_classes - 0.5])
# secondary reversed x axis
ax_top = ax.secondary_xaxis(
'top', functions=(lambda x: 1 - x, lambda x: 1 - x))
# labels, ticks
plt.tick_params(
axis='y', which='both', left=False, right=False, labelleft=False)
ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size)
ax.set_xlabel(
"Fraction of 'texture' decisions", fontsize=label_size, labelpad=25)
ax_top.set_xlabel(
"Fraction of 'shape' decisions", fontsize=label_size, labelpad=25)
ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g'))
ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1))
ax_top.set_ticks(np.arange(0, 1.1, 0.1))
ax.tick_params(
axis='both', which='major', labelsize=fontsize, length=ticklength)
ax_top.tick_params(
axis='both', which='major', labelsize=fontsize, length=ticklength)
# arrows on x axes
plt.arrow(
x=0,
y=-1.75,
dx=1,
dy=0,
fc='black',
head_width=0.4,
head_length=0.03,
clip_on=False,
length_includes_head=True,
overhang=0.5)
plt.arrow(
x=1,
y=num_classes + 0.75,
dx=-1,
dy=0,
fc='black',
head_width=0.4,
head_length=0.03,
clip_on=False,
length_includes_head=True,
overhang=0.5)
# icons besides y axis
# determine order of icons
df_selection = df.loc[(df['subj'].isin(humans))]
class_avgs = []
for cl in classes:
df_class_selection = df_selection.query("category == '{}'".format(cl))
class_avgs.append(1 - analysis.analysis(
df=df_class_selection)['shape-bias'])
sorted_indices = np.argsort(class_avgs)
classes = classes[sorted_indices]
# icon placement is calculated in axis coordinates
WIDTH = 1 / num_classes
# placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH)
XPOS = -1.25 * WIDTH
YPOS = -0.5
HEIGHT = 1
MARGINX = 1 / 10 * WIDTH # horizontal whitespace around icons
MARGINY = 1 / 10 * HEIGHT # vertical whitespace between icons
left = XPOS + MARGINX
right = XPOS + WIDTH - MARGINX
for i in range(num_classes):
bottom = i + MARGINY + YPOS
top = (i + 1) - MARGINY + YPOS
iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i]))
plt.imshow(
plt.imread(iconpath),
extent=[left, right, bottom, top],
aspect='auto',
clip_on=False)
# plot horizontal intersection lines
for i in range(num_classes - 1):
plt.plot([0, 1], [i + .5, i + .5],
c='gray',
linestyle='dotted',
alpha=0.4)
# plot average shapebias + scatter points
for i in range(len(args.model_names)):
df_selection = df.loc[(df['subj'].isin(args.model_names[i]))]
result_df = analysis.analysis(df=df_selection)
avg = 1 - result_df['shape-bias']
ax.plot([avg, avg], [-1, num_classes], color=args.colors[i])
class_avgs = []
for cl in classes:
df_class_selection = df_selection.query(
"category == '{}'".format(cl))
class_avgs.append(1 - analysis.analysis(
df=df_class_selection)['shape-bias'])
ax.scatter(
class_avgs,
classes,
color=args.colors[i],
marker=args.markers[i],
label=args.plotting_names[i],
s=markersize,
clip_on=False,
edgecolors=PLOTTING_EDGE_COLOR,
linewidths=PLOTTING_EDGE_WIDTH,
zorder=3)
plt.legend(frameon=True, labelspacing=1, loc=9)
figure_path = osp.join(args.result_dir,
'cue-conflict_shape-bias_matrixplot.pdf')
fig.savefig(figure_path, bbox_inches='tight')
plt.close()
def check_icons() -> bool:
    """Check whether all icons are present in ``ICONS_DIR``.

    Returns:
        bool: True if every icon exists, False otherwise. Downloading
            missing icons is left to the caller.
    """
    if not osp.exists(ICONS_DIR):
        return False
    for icon_name in icon_names:
        if not osp.exists(osp.join(ICONS_DIR, icon_name)):
            return False
    return True
if __name__ == '__main__':
if not check_icons():
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
os.makedirs(ICONS_DIR, exist_ok=True)
MMLogger.get_current_instance().info(
f'Downloading icons to {ICONS_DIR}')
for icon_name in icon_names:
url = osp.join(root_url, icon_name)
os.system('wget -O {} {}'.format(
osp.join(ICONS_DIR, icon_name), url))
args = parser.parse_args()
assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
must be 3 times the number of models. Every three colors are the RGB \
values for one model.'
# preprocess colors
args.colors = [c / 255. for c in args.colors]
colors = []
for i in range(len(args.model_names)):
colors.append(args.colors[3 * i:3 * i + 3])
args.colors = colors
args.colors.append([165 / 255., 30 / 255., 55 / 255.]) # human color
# if plotting names are not specified, use model names
if len(args.plotting_names) == 0:
args.plotting_names = args.model_names
# preprocess markers
args.markers.append('D') # human marker
# preprocess model names
args.model_names = [[m] for m in args.model_names]
args.model_names.append(humans)
# preprocess plotting names
args.plotting_names.append('Humans')
plot_shape_bias_matrixplot(args)
if args.delete_icons:
os.system('rm -rf {}'.format(ICONS_DIR))


@ -0,0 +1,277 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
from typing import Any, Dict, List, Optional
import matplotlib as mpl
import pandas as pd
from matplotlib import _api
from matplotlib import transforms as mtransforms
class _DummyAxis:
"""Define the minimal interface for a dummy axis.
Args:
minpos (float): The minimum positive value for the axis. Defaults to 0.
"""
__name__ = 'dummy'
# Once the deprecation elapses, replace dataLim and viewLim by plain
# _view_interval and _data_interval private tuples.
dataLim = _api.deprecate_privatize_attribute(
'3.6', alternative='get_data_interval() and set_data_interval()')
viewLim = _api.deprecate_privatize_attribute(
'3.6', alternative='get_view_interval() and set_view_interval()')
def __init__(self, minpos: float = 0) -> None:
self._dataLim = mtransforms.Bbox.unit()
self._viewLim = mtransforms.Bbox.unit()
self._minpos = minpos
def get_view_interval(self) -> Dict:
"""Return the view interval as a tuple (*vmin*, *vmax*)."""
return self._viewLim.intervalx
def set_view_interval(self, vmin: float, vmax: float) -> None:
"""Set the view interval to (*vmin*, *vmax*)."""
self._viewLim.intervalx = vmin, vmax
def get_minpos(self) -> float:
"""Return the minimum positive value for the axis."""
return self._minpos
def get_data_interval(self) -> Dict:
"""Return the data interval as a tuple (*vmin*, *vmax*)."""
return self._dataLim.intervalx
def set_data_interval(self, vmin: float, vmax: float) -> None:
"""Set the data interval to (*vmin*, *vmax*)."""
self._dataLim.intervalx = vmin, vmax
def get_tick_space(self) -> int:
"""Return the number of ticks to use."""
# Just use the long-standing default of nbins==9
return 9
class TickHelper:
"""A helper class for ticks and tick labels."""
axis = None
def set_axis(self, axis: Any) -> None:
"""Set the axis instance."""
self.axis = axis
def create_dummy_axis(self, **kwargs) -> None:
"""Create a dummy axis if no axis is set."""
if self.axis is None:
self.axis = _DummyAxis(**kwargs)
@_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
def set_view_interval(self, vmin: float, vmax: float) -> None:
"""Set the view interval to (*vmin*, *vmax*)."""
self.axis.set_view_interval(vmin, vmax)
@_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
def set_data_interval(self, vmin: float, vmax: float) -> None:
"""Set the data interval to (*vmin*, *vmax*)."""
self.axis.set_data_interval(vmin, vmax)
@_api.deprecated(
'3.5',
alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
def set_bounds(self, vmin: float, vmax: float) -> None:
"""Set the view and data interval to (*vmin*, *vmax*)."""
self.set_view_interval(vmin, vmax)
self.set_data_interval(vmin, vmax)
class Formatter(TickHelper):
"""Create a string based on a tick value and location."""
# some classes want to see all the locs to help format
# individual ones
locs = []
def __call__(self, x: str, pos: Optional[Any] = None) -> str:
"""Return the format for tick value *x* at position pos.
``pos=None`` indicates an unspecified location.
This method must be overridden in the derived class.
Args:
x (str): The tick value.
pos (Optional[Any]): The tick position. Defaults to None.
"""
raise NotImplementedError('Derived must override')
def format_ticks(self, values: pd.Series) -> List[str]:
"""Return the tick labels for all the ticks at once.
Args:
values (pd.Series): The tick values.
Returns:
List[str]: The tick labels.
"""
self.set_locs(values)
return [self(value, i) for i, value in enumerate(values)]
def format_data(self, value: Any) -> str:
"""Return the full string representation of the value with the position
unspecified.
Args:
value (Any): The tick value.
Returns:
str: The full string representation of the value.
"""
return self.__call__(value)
def format_data_short(self, value: Any) -> str:
"""Return a short string version of the tick value.
Defaults to the position-independent long value.
Args:
value (Any): The tick value.
Returns:
str: The short string representation of the value.
"""
return self.format_data(value)
def get_offset(self) -> str:
"""Return the offset string."""
return ''
def set_locs(self, locs: List[Any]) -> None:
"""Set the locations of the ticks.
This method is called before computing the tick labels because some
formatters need to know all tick locations to do so.
"""
self.locs = locs
@staticmethod
def fix_minus(s: str) -> str:
"""Some classes may want to replace a hyphen for minus with the proper
Unicode symbol (U+2212) for typographical correctness.
This is a
helper method to perform such a replacement when it is enabled via
:rc:`axes.unicode_minus`.
Args:
s (str): The string to replace the hyphen with the Unicode symbol.
"""
return (s.replace('-', '\N{MINUS SIGN}')
if mpl.rcParams['axes.unicode_minus'] else s)
def _set_locator(self, locator: Any) -> None:
"""Subclasses may want to override this to set a locator."""
pass
class FormatStrFormatter(Formatter):
"""Use an old-style ('%' operator) format string to format the tick.
The format string should have a single variable format (%) in it.
It will be applied to the value (not the position) of the tick.
Negative numeric values will use a dash, not a Unicode minus; use mathtext
to get a Unicode minus by wrapping the format specifier with $ (e.g.
"$%g$").
Args:
fmt (str): Format string.
"""
def __init__(self, fmt: str) -> None:
self.fmt = fmt
def __call__(self, x: str, pos: Optional[Any]) -> str:
"""Return the formatted label string.
Only the value *x* is formatted. The position is ignored.
Args:
x (str): The value to format.
pos (Any): The position of the tick. Ignored.
"""
return self.fmt % x
class ShapeBias:
"""Compute the shape bias of a model.
Reference: `ImageNet-trained CNNs are biased towards texture;
increasing shape bias improves accuracy and robustness
<https://arxiv.org/abs/1811.12231>`_.
"""
num_input_models = 1
def __init__(self) -> None:
super().__init__()
self.plotting_name = 'shape-bias'
@staticmethod
def _check_dataframe(df: pd.DataFrame) -> None:
"""Check that the dataframe is valid."""
assert len(df) > 0, 'empty dataframe'
def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
"""Compute the shape bias of a model.
Args:
df (pd.DataFrame): The dataframe containing the data.
Returns:
Dict[str, float]: The shape bias.
"""
self._check_dataframe(df)
df = df.copy()
df['correct_texture'] = df['imagename'].apply(
self.get_texture_category)
df['correct_shape'] = df['category']
# remove those rows where shape = texture, i.e. no cue conflict present
df2 = df.loc[df.correct_shape != df.correct_texture]
fraction_correct_shape = len(
df2.loc[df2.object_response == df2.correct_shape]) / len(df)
fraction_correct_texture = len(
df2.loc[df2.object_response == df2.correct_texture]) / len(df)
shape_bias = fraction_correct_shape / (
fraction_correct_shape + fraction_correct_texture)
result_dict = {
'fraction-correct-shape': fraction_correct_shape,
'fraction-correct-texture': fraction_correct_texture,
'shape-bias': shape_bias
}
return result_dict
    def get_texture_category(self, imagename: str) -> str:
        """Return texture category from imagename.

        e.g. 'XXX_dog10-bird2.png' -> 'bird'

        Args:
            imagename (str): Name of the image.

        Returns:
            str: Texture category.
        """
        assert isinstance(imagename, str)
        # remove unnecessary words
        a = imagename.split('_')[-1]
        # remove .png etc.
        b = a.split('.')[0]
        # get texture category (last word)
        c = b.split('-')[-1]
        # remove number, e.g. 'bird2' -> 'bird'
        d = ''.join([i for i in c if not i.isdigit()])
        return d