From feb0814b2f5cf75880b616bb69427f59cfa84c55 Mon Sep 17 00:00:00 2001
From: Wang Xiang
Date: Mon, 3 Jul 2023 11:39:23 +0800
Subject: [PATCH] [Feature] Transfer shape-bias tool from mmselfsup (#1658)

* Transfer shape-bias tool from mmselfsup

* run shape-bias successfully, add CN docs

* fix unit test bug

* add shape_bias to index.rst in docs

* modified mistakes in shape-bias docs
---
 docs/en/index.rst                          |   1 +
 docs/en/useful_tools/shape_bias.md         | 100 ++++++
 docs/zh_CN/index.rst                       |   1 +
 docs/zh_CN/useful_tools/shape_bias.md      |  96 ++++++
 mmpretrain/evaluation/metrics/__init__.py  |   3 +-
 .../evaluation/metrics/shape_bias_label.py | 172 +++++++++++
 .../test_metrics/test_shape_bias_metric.py |  15 +
 tools/analysis_tools/shape_bias.py         | 284 ++++++++++++++++++
 tools/analysis_tools/utils.py              | 277 +++++++++++++++++
 9 files changed, 948 insertions(+), 1 deletion(-)
 create mode 100644 docs/en/useful_tools/shape_bias.md
 create mode 100644 docs/zh_CN/useful_tools/shape_bias.md
 create mode 100644 mmpretrain/evaluation/metrics/shape_bias_label.py
 create mode 100644 tests/test_evaluation/test_metrics/test_shape_bias_metric.py
 create mode 100644 tools/analysis_tools/shape_bias.py
 create mode 100644 tools/analysis_tools/utils.py

diff --git a/docs/en/index.rst b/docs/en/index.rst
index 57f68ad6..d16a32d6 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -104,6 +104,7 @@ We always welcome *PRs* and *Issues* for the betterment of MMPretrain.
    useful_tools/log_result_analysis.md
    useful_tools/complexity_analysis.md
    useful_tools/confusion_matrix.md
+   useful_tools/shape_bias.md

 .. toctree::
    :maxdepth: 1

diff --git a/docs/en/useful_tools/shape_bias.md b/docs/en/useful_tools/shape_bias.md
new file mode 100644
index 00000000..ea4f96c4
--- /dev/null
+++ b/docs/en/useful_tools/shape_bias.md
@@ -0,0 +1,100 @@
## Shape Bias Tool Usage

Shape bias measures how much a model relies on shape, as opposed to texture, to recognize the semantics of an image. For more details, we refer interested readers to this [paper](https://arxiv.org/abs/2106.07411). MMPretrain provides an off-the-shelf toolbox to obtain the shape bias of a classification model. You can follow the steps below:

### Prepare the dataset

First, download the [cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) dataset to the `data` folder, and then unzip it. After that, your `data` folder should have the following structure:

```text
data
├──cue-conflict
|  |──airplane
|  |──bear
|  ...
|  |──truck
```

### Modify the config for classification

We run the shape-bias tool on a ViT-base model with masked autoencoder pretraining. Its config file is `configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py`, and its checkpoint can be downloaded from [this link](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth).
Replace the original test_pipeline, test_dataloader and test_evaluator with the following configurations:

```python
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=256,
        edge='short',
        backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs')
]
test_dataloader = dict(
    pin_memory=True,
    collate_fn=dict(type='default_collate'),
    batch_size=32,
    num_workers=4,
    dataset=dict(
        type='CustomDataset',
        data_root='data/cue-conflict',
        pipeline=test_pipeline,
        _delete_=True),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
    drop_last=False)
test_evaluator = dict(
    type='mmpretrain.ShapeBiasMetric',
    _delete_=True,
    csv_dir='work_dirs/shape_bias',
    model_name='mae')
```

Please note that you should customize `csv_dir` and `model_name` above. In this example, the modified config file is saved as `vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py` in the folder `configs/mae/benchmarks/`.

### Inference your model with the modified config file

Then run inference with your model on the `cue-conflict` dataset, using the modified config file.

```shell
# For PyTorch
bash tools/dist_test.sh $CONFIG $CHECKPOINT $GPUS
```

**Description of all arguments**:

- `$CONFIG`: The path of your modified config file.
- `$CHECKPOINT`: The path or link of the checkpoint file.
- `$GPUS`: The number of GPUs to use.

```shell
# Example
bash tools/dist_test.sh configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth 1
```

After that, you should obtain a csv file in the `csv_dir` folder, named `cue-conflict_model-name_session-1.csv`. Besides this file, you should also download these [csv files](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict) to the `csv_dir`.

### Plot shape bias

Then we can start to plot the shape bias:

```shell
python tools/analysis_tools/shape_bias.py --csv-dir $CSV_DIR --result-dir $RESULT_DIR --colors $RGB --markers o --plotting-names $YOUR_MODEL_NAME --model-names $YOUR_MODEL_NAME
```

**Description of all arguments**:

- `--csv-dir $CSV_DIR`, the directory where the above csv files are saved.
- `--result-dir $RESULT_DIR`, the directory to output the result named `cue-conflict_shape-bias_matrixplot.pdf`.
- `--colors $RGB`, the RGB values, formatted as R G B, e.g. 100 100 100. Multiple RGB triplets can be given if you want to plot the shape bias of several models.
- `--plotting-names $YOUR_MODEL_NAME`, the name shown in the legend of the shape bias figure; you can set it to your model name. To plot several models, pass multiple values.
- `--model-names $YOUR_MODEL_NAME`, should be the same as the `model_name` specified in your config, and can be multiple names if you want to plot the shape bias of several models.

Please note that every three values for `--colors` correspond to one value for `--model-names`.
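
For instance, with the `csv_dir` and `model_name` used earlier in this tutorial, a single-model invocation could look like the command below. The RGB triplet and the plotting name are illustrative choices, not values required by the tool:

```shell
# Example (illustrative color and plotting name)
python tools/analysis_tools/shape_bias.py \
    --csv-dir work_dirs/shape_bias \
    --result-dir work_dirs/shape_bias \
    --colors 31 119 180 \
    --markers o \
    --plotting-names MAE \
    --model-names mae
```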
After all of the above steps, you are expected to obtain the following figure.

*(figure: cue-conflict shape-bias matrixplot)*
diff --git a/docs/zh_CN/index.rst b/docs/zh_CN/index.rst
index 59ffad8e..ca57faac 100644
--- a/docs/zh_CN/index.rst
+++ b/docs/zh_CN/index.rst
@@ -90,6 +90,7 @@ MMPretrain 上手路线
    useful_tools/log_result_analysis.md
    useful_tools/complexity_analysis.md
    useful_tools/confusion_matrix.md
+   useful_tools/shape_bias.md

 .. toctree::
    :maxdepth: 1

diff --git a/docs/zh_CN/useful_tools/shape_bias.md b/docs/zh_CN/useful_tools/shape_bias.md
new file mode 100644
index 00000000..f557197d
--- /dev/null
+++ b/docs/zh_CN/useful_tools/shape_bias.md
@@ -0,0 +1,96 @@
## Shape Bias Tool Usage

Shape bias measures how much a model relies on shape, as opposed to texture, to recognize the semantics of an image. For more details, we refer interested readers to this [paper](https://arxiv.org/abs/2106.07411). MMPretrain provides an off-the-shelf toolbox to obtain the shape bias of a classification model. You can follow the steps below:

### Prepare the dataset

First, download the [cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) dataset to the `data` folder, and then unzip it. After that, your `data` folder should have the following structure:

```text
data
├──cue-conflict
|  |──airplane
|  |──bear
|  ...
|  |──truck
```

### Modify the config for classification

We run the shape-bias tool on a ViT-base model pretrained with MAE. Its config file is `configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py`, and its checkpoint can be downloaded from [this link](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth). Replace the original test_pipeline, test_dataloader and test_evaluator with the following configurations:

```python
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=256,
        edge='short',
        backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs')
]
test_dataloader = dict(
    pin_memory=True,
    collate_fn=dict(type='default_collate'),
    batch_size=32,
    num_workers=4,
    dataset=dict(
        type='CustomDataset',
        data_root='data/cue-conflict',
        pipeline=test_pipeline,
        _delete_=True),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
    drop_last=False)
test_evaluator = dict(
    type='mmpretrain.ShapeBiasMetric',
    _delete_=True,
    csv_dir='work_dirs/shape_bias',
    model_name='mae')
```

Please note that you should customize `csv_dir` and `model_name` above. In this example, the modified config file is saved as `vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py` in the folder `configs/mae/benchmarks/`.

### Inference your model with the modified config file

Then run inference with your model on the `cue-conflict` dataset, using the modified config file.

```shell
# For PyTorch
bash tools/dist_test.sh $CONFIG $CHECKPOINT $GPUS
```

**Description of all arguments**:

- `$CONFIG`: The path of your modified config file.
- `$CHECKPOINT`: The path or link of the checkpoint file.
- `$GPUS`: The number of GPUs to use.

```shell
# Example
bash tools/dist_test.sh configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth 1
```

After that, you should obtain a csv file in the `csv_dir` folder, named `cue-conflict_model-name_session-1.csv`. Besides this file, you should also download these [csv files](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict) to the `csv_dir`.

### Plot shape bias

Then we can start to plot the shape bias:

```shell
python tools/analysis_tools/shape_bias.py --csv-dir $CSV_DIR --result-dir $RESULT_DIR --colors $RGB --markers o --plotting-names $YOUR_MODEL_NAME --model-names $YOUR_MODEL_NAME
```

**Description of all arguments**:

- `--csv-dir $CSV_DIR`, the directory where the above csv files are saved.
- `--result-dir $RESULT_DIR`, the directory to output the result named `cue-conflict_shape-bias_matrixplot.pdf`.
- `--colors $RGB`, the RGB values, formatted as R G B, e.g. 100 100 100. Multiple RGB triplets can be given if you want to plot the shape bias of several models.
- `--plotting-names $YOUR_MODEL_NAME`, the name shown in the legend of the shape bias figure; you can set it to your model name. To plot several models, pass multiple values.
- `--model-names $YOUR_MODEL_NAME`, should be the same as the `model_name` specified in your config, and can be multiple names if you want to plot the shape bias of several models.

Please note that every three values for `--colors` correspond to one value for `--model-names`. After all of the above steps, you are expected to obtain the following figure.
*(figure: cue-conflict shape-bias matrixplot)*
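
For reference, the plotting tool identifies each model by the `subj` column of its csv file. Below is a sketch of what the first rows of the generated `cue-conflict_mae_session-1.csv` might look like; the column order follows the header written by the tool, and the image names are illustrative:

```text
subj,session,trial,rt,object_response,category,condition,imagename
mae,1,1,NaN,dog,bear,NaN,bear1-dog2.png
mae,1,2,NaN,airplane,airplane,NaN,airplane2-clock1.png
```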
diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py
index 7f5a4f36..fd46de78 100644
--- a/mmpretrain/evaluation/metrics/__init__.py
+++ b/mmpretrain/evaluation/metrics/__init__.py
@@ -6,6 +6,7 @@ from .multi_task import MultiTasksMetric
 from .nocaps import NocapsSave
 from .retrieval import RetrievalAveragePrecision, RetrievalRecall
 from .scienceqa import ScienceQAMetric
+from .shape_bias_label import ShapeBiasMetric
 from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
 from .visual_grounding_eval import VisualGroundingMetric
 from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
@@ -16,5 +17,5 @@ __all__ = [
     'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
     'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
     'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
-    'RetrievalAveragePrecision'
+    'RetrievalAveragePrecision', 'ShapeBiasMetric'
 ]
diff --git a/mmpretrain/evaluation/metrics/shape_bias_label.py b/mmpretrain/evaluation/metrics/shape_bias_label.py
new file mode 100644
index 00000000..27c80a36
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/shape_bias_label.py
@@ -0,0 +1,172 @@
# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os
import os.path as osp
from typing import List, Sequence

import numpy as np
import torch
from mmengine.dist.utils import get_rank
from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


@METRICS.register_module()
class ShapeBiasMetric(BaseMetric):
    """Evaluate the model on the ``cue_conflict`` dataset.

    This module evaluates the model on an OOD dataset, cue_conflict, in
    order to measure the shape bias of the model. In addition to computing
    the Top-1 accuracy, this module also generates a csv file recording the
    detailed prediction results, which can then be used to plot the shape
    bias curve.

    Args:
        csv_dir (str): The directory to save the csv file.
        model_name (str): The name of the csv file. Please note that the
            model name should be a unique identifier.
        dataset_name (str): The name of the dataset. Defaults to
            'cue_conflict'.
    """

    # mapping several classes from ImageNet-1K to the same category
    airplane_indices = [404]
    bear_indices = [294, 295, 296, 297]
    bicycle_indices = [444, 671]
    bird_indices = [
        8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83,
        87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129,
        130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
        145
    ]
    boat_indices = [472, 554, 625, 814, 914]
    bottle_indices = [440, 720, 737, 898, 899, 901, 907]
    car_indices = [436, 511, 817]
    cat_indices = [281, 282, 283, 284, 285, 286]
    chair_indices = [423, 559, 765, 857]
    clock_indices = [409, 530, 892]
    dog_indices = [
        152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
        166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
        180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194,
        195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209,
        210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
        224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
        239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254,
        255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268
    ]
    elephant_indices = [385, 386]
    keyboard_indices = [508, 878]
    knife_indices = [499]
    oven_indices = [766]
    truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]

    def __init__(self,
                 csv_dir: str,
                 model_name: str,
                 dataset_name: str = 'cue_conflict',
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.categories = sorted([
            'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock',
            'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car',
            'bird', 'dog'
        ])
        self.csv_dir = csv_dir
        self.model_name = model_name
        self.dataset_name = dataset_name
        if get_rank() == 0:
            self.csv_path = self.create_csv()

    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            if 'pred_score' in data_sample:
                result['pred_score'] = data_sample['pred_score'].cpu()
            else:
                result['pred_label'] = data_sample['pred_label'].cpu()
            result['gt_label'] = data_sample['gt_label'].cpu()
            result['gt_category'] = data_sample['img_path'].split('/')[-2]
            result['img_name'] = data_sample['img_path'].split('/')[-1]

            aggregated_category_probabilities = []
            # get the prediction for each category of the current instance
            for category in self.categories:
                category_indices = getattr(self, f'{category}_indices')
                category_probabilities = torch.gather(
                    result['pred_score'], 0,
                    torch.tensor(category_indices)).mean()
                aggregated_category_probabilities.append(
                    category_probabilities)
            # sort the probabilities in descending order
            pred_indices = torch.stack(aggregated_category_probabilities
                                       ).argsort(descending=True).numpy()
            result['pred_category'] = np.take(self.categories, pred_indices)

            # Save the result to `self.results`.
            self.results.append(result)

    def create_csv(self) -> str:
        """Create a csv file to store the results."""
        session_name = 'session-1'
        csv_path = osp.join(
            self.csv_dir, self.dataset_name + '_' + self.model_name + '_' +
            session_name + '.csv')
        if osp.exists(csv_path):
            os.remove(csv_path)
        directory = osp.dirname(csv_path)
        if not osp.exists(directory):
            os.makedirs(directory, exist_ok=True)
        with open(csv_path, 'w') as f:
            writer = csv.writer(f)
            writer.writerow([
                'subj', 'session', 'trial', 'rt', 'object_response',
                'category', 'condition', 'imagename'
            ])
        return csv_path

    def dump_results_to_csv(self, results: List[dict]) -> None:
        """Dump the results to a csv file.

        Args:
            results (List[dict]): A list of results.
        """
        for i, result in enumerate(results):
            img_name = result['img_name']
            category = result['gt_category']
            condition = 'NaN'
            with open(self.csv_path, 'a') as f:
                writer = csv.writer(f)
                writer.writerow([
                    self.model_name, 1, i + 1, 'NaN',
                    result['pred_category'][0], category, condition, img_name
                ])

    def compute_metrics(self, results: List[dict]) -> dict:
        """Compute the metrics from the results.

        Args:
            results (List[dict]): A list of results.

        Returns:
            dict: A dict of metrics.
        """
        if get_rank() == 0:
            self.dump_results_to_csv(results)
        metrics = dict()
        metrics['accuracy/top1'] = np.mean([
            result['pred_category'][0] == result['gt_category']
            for result in results
        ])

        return metrics
diff --git a/tests/test_evaluation/test_metrics/test_shape_bias_metric.py b/tests/test_evaluation/test_metrics/test_shape_bias_metric.py
new file mode 100644
index 00000000..d57ace89
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_shape_bias_metric.py
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmpretrain.evaluation import ShapeBiasMetric


def test_shape_bias_metric():
    data_sample = dict()
    data_sample['pred_score'] = torch.rand(1000, )
    data_sample['pred_label'] = torch.tensor(1)
    data_sample['gt_label'] = torch.tensor(1)
    data_sample['img_path'] = 'tests/airplane/test.JPEG'
    evaluator = ShapeBiasMetric(
        csv_dir='tests/data', dataset_name='test', model_name='test')
    evaluator.process(None, [data_sample])
diff --git a/tools/analysis_tools/shape_bias.py b/tools/analysis_tools/shape_bias.py
new file mode 100644
index 00000000..52e9fe69
--- /dev/null
+++ b/tools/analysis_tools/shape_bias.py
@@ -0,0 +1,284 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
import argparse
import os
import os.path as osp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mmengine.logging import MMLogger
from utils import FormatStrFormatter, ShapeBias

# global default boundary settings for thin gray transparent
# boundaries to avoid not being able to see the difference
# between two partially overlapping datapoints of the same color:
PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)
PLOTTING_EDGE_WIDTH = 0.02
ICONS_DIR = osp.join(
    osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons')

parser = argparse.ArgumentParser()
parser.add_argument('--csv-dir', type=str, help='directory of csv files')
parser.add_argument(
    '--result-dir', type=str, help='directory to save plotting results')
parser.add_argument('--model-names', nargs='+', default=[], help='model name')
parser.add_argument(
    '--colors',
    nargs='+',
    type=float,
    default=[],
    help=  # noqa
    'the colors for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
)
parser.add_argument(
    '--markers',
    nargs='+',
    type=str,
    default=[],
    help=  # noqa
    'the markers for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
)
parser.add_argument(
    '--plotting-names',
    nargs='+',
    default=[],
    help=  # noqa
    'the plotting names for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
)
parser.add_argument(
    '--delete-icons',
    action='store_true',
    help='whether to delete the icons after plotting')

humans = [
    'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
    'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
]

icon_names = [
    'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
    'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png',
    'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
    'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
    'knife.png', 'response_icons_vertical.png', 'truck.png'
]


def read_csvs(csv_dir: str) -> pd.DataFrame:
    """Reads all csv files in a directory and returns a single dataframe.

    Args:
        csv_dir (str): directory of csv files.

    Returns:
        pd.DataFrame: dataframe containing all csv files
    """
    df = pd.DataFrame()
    for csv in os.listdir(csv_dir):
        if csv.endswith('.csv'):
            cur_df = pd.read_csv(osp.join(csv_dir, csv))
            cur_df.columns = [c.lower() for c in cur_df.columns]
            # DataFrame.append was removed in pandas 2.0
            df = pd.concat([df, cur_df], ignore_index=True)
    df.condition = df.condition.astype(str)
    return df


def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
    """Plots a matrixplot of shape bias.

    Args:
        args (argparse.Namespace): arguments.
        analysis (ShapeBias): shape bias analysis. Defaults to ShapeBias().
    """
    mpl.rcParams['font.family'] = ['serif']
    mpl.rcParams['font.serif'] = ['Times New Roman']

    plt.figure(figsize=(9, 7))
    df = read_csvs(args.csv_dir)

    fontsize = 15
    ticklength = 10
    markersize = 250
    label_size = 20

    classes = df['category'].unique()
    num_classes = len(classes)

    # plot setup
+ ax = plt.gca() + + ax.set_xlim([0, 1]) + ax.set_ylim([-.5, num_classes - 0.5]) + + # secondary reversed x axis + ax_top = ax.secondary_xaxis( + 'top', functions=(lambda x: 1 - x, lambda x: 1 - x)) + + # labels, ticks + plt.tick_params( + axis='y', which='both', left=False, right=False, labelleft=False) + ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size) + ax.set_xlabel( + "Fraction of 'texture' decisions", fontsize=label_size, labelpad=25) + ax_top.set_xlabel( + "Fraction of 'shape' decisions", fontsize=label_size, labelpad=25) + ax.xaxis.set_major_formatter(FormatStrFormatter('%g')) + ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g')) + ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1)) + ax_top.set_ticks(np.arange(0, 1.1, 0.1)) + ax.tick_params( + axis='both', which='major', labelsize=fontsize, length=ticklength) + ax_top.tick_params( + axis='both', which='major', labelsize=fontsize, length=ticklength) + + # arrows on x axes + plt.arrow( + x=0, + y=-1.75, + dx=1, + dy=0, + fc='black', + head_width=0.4, + head_length=0.03, + clip_on=False, + length_includes_head=True, + overhang=0.5) + plt.arrow( + x=1, + y=num_classes + 0.75, + dx=-1, + dy=0, + fc='black', + head_width=0.4, + head_length=0.03, + clip_on=False, + length_includes_head=True, + overhang=0.5) + + # icons besides y axis + # determine order of icons + df_selection = df.loc[(df['subj'].isin(humans))] + class_avgs = [] + for cl in classes: + df_class_selection = df_selection.query("category == '{}'".format(cl)) + class_avgs.append(1 - analysis.analysis( + df=df_class_selection)['shape-bias']) + sorted_indices = np.argsort(class_avgs) + classes = classes[sorted_indices] + + # icon placement is calculated in axis coordinates + WIDTH = 1 / num_classes + # placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH) + XPOS = -1.25 * WIDTH + YPOS = -0.5 + HEIGHT = 1 + MARGINX = 1 / 10 * WIDTH # vertical whitespace between icons + MARGINY = 1 / 10 * HEIGHT # horizontal whitespace between icons + + left = XPOS + MARGINX + right = XPOS + WIDTH - MARGINX + + for i in range(num_classes): + bottom = i + MARGINY + YPOS + top = (i + 1) - MARGINY + YPOS + iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i])) + plt.imshow( + plt.imread(iconpath), + extent=[left, right, bottom, top], + aspect='auto', + clip_on=False) + + # plot horizontal intersection lines + for i in range(num_classes - 1): + plt.plot([0, 1], [i + .5, i + .5], + c='gray', + linestyle='dotted', + alpha=0.4) + + # plot average shapebias + scatter points + for i in range(len(args.model_names)): + df_selection = df.loc[(df['subj'].isin(args.model_names[i]))] + result_df = analysis.analysis(df=df_selection) + avg = 1 - result_df['shape-bias'] + ax.plot([avg, avg], [-1, num_classes], color=args.colors[i]) + class_avgs = [] + for cl in classes: + df_class_selection = df_selection.query( + "category == '{}'".format(cl)) + class_avgs.append(1 - analysis.analysis( + df=df_class_selection)['shape-bias']) + + ax.scatter( + class_avgs, + classes, + color=args.colors[i], + marker=args.markers[i], + label=args.plotting_names[i], + s=markersize, + clip_on=False, + edgecolors=PLOTTING_EDGE_COLOR, + linewidths=PLOTTING_EDGE_WIDTH, + zorder=3) + plt.legend(frameon=True, labelspacing=1, loc=9) + + figure_path = osp.join(args.result_dir, + 'cue-conflict_shape-bias_matrixplot.pdf') + fig.savefig(figure_path, bbox_inches='tight') + plt.close() + + +def check_icons() -> bool: + """Check if icons are present, if not download them.""" + if not osp.exists(ICONS_DIR): + 
        return False
    for icon_name in icon_names:
        if not osp.exists(osp.join(ICONS_DIR, icon_name)):
            return False
    return True


if __name__ == '__main__':

    if not check_icons():
        root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons'  # noqa: E501
        os.makedirs(ICONS_DIR, exist_ok=True)
        MMLogger.get_current_instance().info(
            f'Downloading icons to {ICONS_DIR}')
        for icon_name in icon_names:
            url = osp.join(root_url, icon_name)
            os.system('wget -O {} {}'.format(
                osp.join(ICONS_DIR, icon_name), url))

    args = parser.parse_args()
    assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
        must be 3 times the number of models. Every three colors are the RGB \
        values for one model.'

    # preprocess colors
    args.colors = [c / 255. for c in args.colors]
    colors = []
    for i in range(len(args.model_names)):
        colors.append(args.colors[3 * i:3 * i + 3])
    args.colors = colors
    args.colors.append([165 / 255., 30 / 255., 55 / 255.])  # human color

    # if plotting names are not specified, use model names
    if len(args.plotting_names) == 0:
        args.plotting_names = args.model_names

    # preprocess markers
    args.markers.append('D')  # human marker

    # preprocess model names
    args.model_names = [[m] for m in args.model_names]
    args.model_names.append(humans)

    # preprocess plotting names
    args.plotting_names.append('Humans')

    plot_shape_bias_matrixplot(args)
    if args.delete_icons:
        os.system('rm -rf {}'.format(ICONS_DIR))
diff --git a/tools/analysis_tools/utils.py b/tools/analysis_tools/utils.py
new file mode 100644
index 00000000..184cb32a
--- /dev/null
+++ b/tools/analysis_tools/utils.py
@@ -0,0 +1,277 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
from typing import Any, Dict, List, Optional

import matplotlib as mpl
import pandas as pd
from matplotlib import _api
from matplotlib import transforms as mtransforms


class _DummyAxis:
    """Define the minimal interface for a dummy axis.

    Args:
        minpos (float): The minimum positive value for the axis. Defaults to 0.
    """
    __name__ = 'dummy'

    # Once the deprecation elapses, replace dataLim and viewLim by plain
    # _view_interval and _data_interval private tuples.
    dataLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_data_interval() and set_data_interval()')
    viewLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_view_interval() and set_view_interval()')

    def __init__(self, minpos: float = 0) -> None:
        self._dataLim = mtransforms.Bbox.unit()
        self._viewLim = mtransforms.Bbox.unit()
        self._minpos = minpos

    def get_view_interval(self) -> Dict:
        """Return the view interval as a tuple (*vmin*, *vmax*)."""
        return self._viewLim.intervalx

    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self._viewLim.intervalx = vmin, vmax

    def get_minpos(self) -> float:
        """Return the minimum positive value for the axis."""
        return self._minpos

    def get_data_interval(self) -> Dict:
        """Return the data interval as a tuple (*vmin*, *vmax*)."""
        return self._dataLim.intervalx

    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self._dataLim.intervalx = vmin, vmax

    def get_tick_space(self) -> int:
        """Return the number of ticks to use."""
        # Just use the long-standing default of nbins==9
        return 9


class TickHelper:
    """A helper class for ticks and tick labels."""
    axis = None

    def set_axis(self, axis: Any) -> None:
        """Set the axis instance."""
        self.axis = axis

    def create_dummy_axis(self, **kwargs) -> None:
        """Create a dummy axis if no axis is set."""
        if self.axis is None:
            self.axis = _DummyAxis(**kwargs)

    @_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self.axis.set_view_interval(vmin, vmax)

    @_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self.axis.set_data_interval(vmin, vmax)

    @_api.deprecated(
        '3.5',
        alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
    def set_bounds(self, vmin: float, vmax: float) -> None:
        """Set the view and data interval to (*vmin*, *vmax*)."""
        self.set_view_interval(vmin, vmax)
        self.set_data_interval(vmin, vmax)


class Formatter(TickHelper):
    """Create a string based on a tick value and location."""
    # some classes want to see all the locs to help format
    # individual ones
    locs = []

    def __call__(self, x: str, pos: Optional[Any] = None) -> str:
        """Return the format for tick value *x* at position pos.

        ``pos=None`` indicates an unspecified location.

        This method must be overridden in the derived class.

        Args:
            x (str): The tick value.
            pos (Optional[Any]): The tick position. Defaults to None.
        """
        raise NotImplementedError('Derived must override')

    def format_ticks(self, values: pd.Series) -> List[str]:
        """Return the tick labels for all the ticks at once.

        Args:
            values (pd.Series): The tick values.

        Returns:
            List[str]: The tick labels.
        """
        self.set_locs(values)
        return [self(value, i) for i, value in enumerate(values)]

    def format_data(self, value: Any) -> str:
        """Return the full string representation of the value with the
        position unspecified.

        Args:
            value (Any): The tick value.

        Returns:
            str: The full string representation of the value.
+ """ + return self.__call__(value) + + def format_data_short(self, value: Any) -> str: + """Return a short string version of the tick value. + + Defaults to the position-independent long value. + + Args: + value (Any): The tick value. + + Returns: + str: The short string representation of the value. + """ + return self.format_data(value) + + def get_offset(self) -> str: + """Return the offset string.""" + return '' + + def set_locs(self, locs: List[Any]) -> None: + """Set the locations of the ticks. + + This method is called before computing the tick labels because some + formatters need to know all tick locations to do so. + """ + self.locs = locs + + @staticmethod + def fix_minus(s: str) -> str: + """Some classes may want to replace a hyphen for minus with the proper + Unicode symbol (U+2212) for typographical correctness. + + This is a + helper method to perform such a replacement when it is enabled via + :rc:`axes.unicode_minus`. + + Args: + s (str): The string to replace the hyphen with the Unicode symbol. + """ + return (s.replace('-', '\N{MINUS SIGN}') + if mpl.rcParams['axes.unicode_minus'] else s) + + def _set_locator(self, locator: Any) -> None: + """Subclasses may want to override this to set a locator.""" + pass + + +class FormatStrFormatter(Formatter): + """Use an old-style ('%' operator) format string to format the tick. + + The format string should have a single variable format (%) in it. + It will be applied to the value (not the position) of the tick. + + Negative numeric values will use a dash, not a Unicode minus; use mathtext + to get a Unicode minus by wrapping the format specifier with $ (e.g. + "$%g$"). + + Args: + fmt (str): Format string. + """ + + def __init__(self, fmt: str) -> None: + self.fmt = fmt + + def __call__(self, x: str, pos: Optional[Any]) -> str: + """Return the formatted label string. + + Only the value *x* is formatted. The position is ignored. + + Args: + x (str): The value to format. + pos (Any): The position of the tick. Ignored. + """ + return self.fmt % x + + +class ShapeBias: + """Compute the shape bias of a model. + + Reference: `ImageNet-trained CNNs are biased towards texture; + increasing shape bias improves accuracy and robustness + `_. + """ + num_input_models = 1 + + def __init__(self) -> None: + super().__init__() + self.plotting_name = 'shape-bias' + + @staticmethod + def _check_dataframe(df: pd.DataFrame) -> None: + """Check that the dataframe is valid.""" + assert len(df) > 0, 'empty dataframe' + + def analysis(self, df: pd.DataFrame) -> Dict[str, float]: + """Compute the shape bias of a model. + + Args: + df (pd.DataFrame): The dataframe containing the data. + + Returns: + Dict[str, float]: The shape bias. + """ + self._check_dataframe(df) + + df = df.copy() + df['correct_texture'] = df['imagename'].apply( + self.get_texture_category) + df['correct_shape'] = df['category'] + + # remove those rows where shape = texture, i.e. 
no cue conflict present + df2 = df.loc[df.correct_shape != df.correct_texture] + fraction_correct_shape = len( + df2.loc[df2.object_response == df2.correct_shape]) / len(df) + fraction_correct_texture = len( + df2.loc[df2.object_response == df2.correct_texture]) / len(df) + shape_bias = fraction_correct_shape / ( + fraction_correct_shape + fraction_correct_texture) + + result_dict = { + 'fraction-correct-shape': fraction_correct_shape, + 'fraction-correct-texture': fraction_correct_texture, + 'shape-bias': shape_bias + } + return result_dict + + def get_texture_category(self, imagename: str) -> str: + """Return texture category from imagename. + + e.g. 'XXX_dog10-bird2.png' -> 'bird ' + + Args: + imagename (str): Name of the image. + + Returns: + str: Texture category. + """ + assert type(imagename) is str + + # remove unnecessary words + a = imagename.split('_')[-1] + # remove .png etc. + b = a.split('.')[0] + # get texture category (last word) + c = b.split('-')[-1] + # remove number, e.g. 'bird2' -> 'bird' + d = ''.join([i for i in c if not i.isdigit()]) + return d
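

# Illustrative sketch (not part of the upstream model-vs-human code): a
# minimal sanity check of `ShapeBias.analysis` on a toy dataframe. The
# column names follow the csv header written by `ShapeBiasMetric`, and the
# rows below are made up.
if __name__ == '__main__':
    demo_df = pd.DataFrame({
        'imagename': ['subj_dog10-bird2.png', 'subj_cat3-dog1.png'],
        'category': ['dog', 'cat'],  # ground-truth shape categories
        'object_response': ['dog', 'dog'],  # model predictions
    })
    # one shape-correct and one texture-correct response -> shape bias 0.5
    print(ShapeBias().analysis(demo_df))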