From 304e81650acca248ca1129b27cf3cb1546e3bc1e Mon Sep 17 00:00:00 2001
From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Date: Fri, 30 Dec 2022 22:11:18 +0800
Subject: [PATCH] [Feature]: Shape bias (#635)
* [Feature]: Add shape bias vis
* [Fix]: Fix lint
* [Feature]: Add shape bias metrics
* [Fix]: Fix lint
* [Fix]: Delete redundant code
* [Feature]: Add shape bias doc
* [Fix]: Fix lint
* [Feature]: add UT
* [Fix]: Fix lint
* [Fix]: Fix typo
* [Fix]: Fix typo
* [Fix]: Fix args param style
* [Feature]: Download pic automatically
---
docs/en/user_guides/visualization.md | 86 ++++++
mmselfsup/evaluation/__init__.py | 1 +
mmselfsup/evaluation/metrics/__init__.py | 4 +
.../evaluation/metrics/shape_bias_label.py | 171 +++++++++++
requirements/optional.txt | 1 +
tests/data/test_test_session-1.csv | 1 +
.../test_metrics/test_shape_bias_metric.py | 15 +
tools/analysis_tools/utils.py | 277 ++++++++++++++++++
tools/analysis_tools/visualize_shape_bias.py | 264 +++++++++++++++++
9 files changed, 820 insertions(+)
create mode 100644 mmselfsup/evaluation/metrics/__init__.py
create mode 100644 mmselfsup/evaluation/metrics/shape_bias_label.py
create mode 100644 tests/data/test_test_session-1.csv
create mode 100644 tests/test_evaluation/test_metrics/test_shape_bias_metric.py
create mode 100644 tools/analysis_tools/utils.py
create mode 100644 tools/analysis_tools/visualize_shape_bias.py
diff --git a/docs/en/user_guides/visualization.md b/docs/en/user_guides/visualization.md
index 3f6e710a..f6163635 100644
--- a/docs/en/user_guides/visualization.md
+++ b/docs/en/user_guides/visualization.md
@@ -12,6 +12,11 @@ Visualization can give an intuitive interpretation of the performance of the mod
- [Visualize Datasets](#visualize-datasets)
- [Visualize t-SNE](#visualize-t-sne)
- [Visualize Low-level Feature Reconstruction](#visualize-low-level-feature-reconstruction)
+ - [Visualize Shape Bias](#visualize-shape-bias)
+ - [Prepare the dataset](#prepare-the-dataset)
+ - [Modify the config for classification](#modify-the-config-for-classification)
+  - [Inference with the modified config file](#inference-with-the-modified-config-file)
+ - [Plot shape bias](#plot-shape-bias)
@@ -205,3 +210,84 @@ Results of MaskFeat:
+
+## Visualize Shape Bias
+
+Shape bias measures how much a model relies on shape, rather than texture, to recognize the semantics in images. For more details,
+we refer interested readers to this [paper](https://arxiv.org/abs/2106.07411). MMSelfSup provides an off-the-shelf toolbox to
+obtain the shape bias of a classification model. You can follow the steps below:
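+
+Concretely, shape bias is the fraction of cue-conflict decisions that follow the shape category rather than the texture category (see `ShapeBias.analysis` in `tools/analysis_tools/utils.py`):
+
+$$\text{shape bias} = \frac{\text{fraction-correct-shape}}{\text{fraction-correct-shape} + \text{fraction-correct-texture}}$$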
+
+### Prepare the dataset
+
+First you should download the [cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) dataset to the `data` folder,
+and then unzip it (the Python sketch below scripts both steps). After that, your `data` folder should have the following structure:
+
+```text
+data
+├── cue-conflict
+│   ├── airplane
+│   ├── bear
+│   ├── ...
+│   └── truck
+```
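+
+If you prefer to script this step, the following is a minimal Python sketch (standard library only) that downloads and unpacks the dataset; the URL is the one linked above, and the paths assume you run it from the repository root:
+
+```python
+import os
+import tarfile
+import urllib.request
+
+url = ('https://github.com/bethgelab/model-vs-human/releases/'
+       'download/v0.1/cue-conflict.tar.gz')
+archive = 'data/cue-conflict.tar.gz'
+
+os.makedirs('data', exist_ok=True)
+# download the archive into the data folder
+urllib.request.urlretrieve(url, archive)
+# extract it, producing data/cue-conflict/<category>/...
+with tarfile.open(archive, 'r:gz') as tar:
+    tar.extractall('data')
+```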
+
+### Modify the config for classification
+
+Replace the original `test_dataloader` and `test_evaluator` with the following configurations:
+
+```python
+test_dataloader = dict(
+ dataset=dict(
+ type='CustomDataset',
+ data_root='data/cue-conflict',
+ _delete_=True),
+ drop_last=False)
+test_evaluator = dict(
+ type='mmselfsup.ShapeBiasMetric',
+ _delete_=True,
+ csv_dir='directory/to/save/the/csv/file',
+ model_name='your_model_name')
+```
+
+Please note that you should set `csv_dir` and `model_name` to your own values.
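+
+To double-check that `_delete_=True` replaced, rather than merged with, the settings inherited from the base config, you can load the modified config with MMEngine and inspect the resolved values. This is a minimal sketch, and the config path is a placeholder for your own file:
+
+```python
+from mmengine.config import Config
+
+# placeholder path: point this at your modified config file
+cfg = Config.fromfile('configs/your_model/your_modified_config.py')
+
+# both should show the cue-conflict settings above, with no leftover
+# keys from the inherited base config
+print(cfg.test_dataloader)
+print(cfg.test_evaluator)
+```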
+
+### Inference with the modified config file
+
+Then you should run inference with your model on the `cue-conflict` dataset, using the modified config file.
+
+```shell
+# For Slurm
+GPUS_PER_NODE=1 GPUS=1 bash tools/benchmarks/classification/mim_slurm_test.sh $partition $config $checkpoint
+```
+
+```shell
+# For PyTorch
+GPUS=1 bash tools/benchmarks/classification/mim_dist_test.sh $config $checkpoint
+```
+
+After that, you should obtain a csv file named `cue-conflict_model-name_session-1.csv`. Besides this file, you should
+also download these [csv files](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict), which
+include the human subjects' responses, to the same `csv_dir` (see the sketch below).
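+
+A minimal Python sketch for fetching the ten human-subject files is shown below. The `cue-conflict_subject-XX_session-1.csv` naming pattern is an assumption inferred from the `subj` values used by the plotting script, so please verify the exact filenames in the linked directory:
+
+```python
+import os
+import urllib.request
+
+csv_dir = 'directory/to/save/the/csv/file'  # the same csv_dir as in your config
+base_url = ('https://raw.githubusercontent.com/bethgelab/model-vs-human/'
+            'master/raw-data/cue-conflict/')
+
+os.makedirs(csv_dir, exist_ok=True)
+for i in range(1, 11):
+    # assumed naming pattern; verify against the repository listing
+    name = f'cue-conflict_subject-{i:02d}_session-1.csv'
+    urllib.request.urlretrieve(base_url + name, os.path.join(csv_dir, name))
+```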
+
+### Plot shape bias
+
+Then we can start to plot the shape bias:
+
+```shell
+python tools/analysis_tools/visualize_shape_bias.py --csv-dir $CSV_DIR --result-dir $CSV_DIR --colors $RGB --markers o --plotting-names $YOUR_MODEL_NAME --model-names $YOUR_MODEL_NAME
+```
+
+- `--csv-dir`, the directory where the csv files above are saved
+- `--colors`, the RGB values, formatted as R G B, e.g. 100 100 100; multiple RGB triples can be given if you want
+  to plot the shape bias of several models (see the sketch after this list)
+- `--plotting-names`, the legend names in the shape bias figure; you can set them to your model names. If you want
+  to plot several models, `--plotting-names` takes multiple values
+- `--model-names`, should be the same names specified in your configs, and can be multiple names if you want to plot the
+  shape bias of several models
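+
+The following minimal sketch mirrors how `visualize_shape_bias.py` groups the flat `--colors` list into one RGB triple per model; the two model names and color values here are made up for illustration:
+
+```python
+# flat list, as passed on the command line for two models
+model_names = ['model_a', 'model_b']  # hypothetical names
+colors = [100, 100, 100, 200, 50, 50]
+
+# normalize to [0, 1] and group every three values into one RGB triple
+colors = [c / 255. for c in colors]
+rgb_per_model = [colors[3 * i:3 * i + 3] for i in range(len(model_names))]
+print(rgb_per_model)  # [[0.392..., 0.392..., 0.392...], [0.784..., 0.196..., 0.196...]]
+```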
+
+Please note that every three values for `--colors` correspond to one value for `--model-names`, as the sketch above
+illustrates. After all of the above steps, you are expected to obtain the following figure.
+
+
+<!-- figure: cue-conflict shape-bias matrixplot -->
+
diff --git a/mmselfsup/evaluation/__init__.py b/mmselfsup/evaluation/__init__.py
index 228f58c2..f70dc226 100644
--- a/mmselfsup/evaluation/__init__.py
+++ b/mmselfsup/evaluation/__init__.py
@@ -1,2 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .functional import * # noqa: F401,F403
+from .metrics import * # noqa: F401,F403
diff --git a/mmselfsup/evaluation/metrics/__init__.py b/mmselfsup/evaluation/metrics/__init__.py
new file mode 100644
index 00000000..5b18dcd8
--- /dev/null
+++ b/mmselfsup/evaluation/metrics/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .shape_bias_label import ShapeBiasMetric
+
+__all__ = ['ShapeBiasMetric']
diff --git a/mmselfsup/evaluation/metrics/shape_bias_label.py b/mmselfsup/evaluation/metrics/shape_bias_label.py
new file mode 100644
index 00000000..c136e04a
--- /dev/null
+++ b/mmselfsup/evaluation/metrics/shape_bias_label.py
@@ -0,0 +1,171 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import csv
+import os
+import os.path as osp
+from typing import List, Sequence
+
+import numpy as np
+import torch
+from mmengine.evaluator import BaseMetric
+
+from mmselfsup.registry import METRICS
+
+
+@METRICS.register_module()
+class ShapeBiasMetric(BaseMetric):
+ """Evaluate the model on ``cue_conflict`` dataset.
+
+ This module will evaluate the model on an OOD dataset, cue_conflict, in
+    order to measure the shape bias of the model. In addition to computing the
+    Top-1 accuracy, this module also generates a csv file to record the
+ detailed prediction results, such that this csv file can be used to
+ generate the shape bias curve.
+
+ Args:
+ csv_dir (str): The directory to save the csv file.
+ model_name (str): The name of the csv file. Please note that the
+            model name should be a unique identifier.
+        dataset_name (str): The name of the dataset. Defaults to 'cue_conflict'.
+ """
+
+ # mapping several classes from ImageNet-1K to the same category
+ airplane_indices = [404]
+ bear_indices = [294, 295, 296, 297]
+ bicycle_indices = [444, 671]
+ bird_indices = [
+ 8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83,
+ 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129,
+ 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
+ 145
+ ]
+ boat_indices = [472, 554, 625, 814, 914]
+ bottle_indices = [440, 720, 737, 898, 899, 901, 907]
+ car_indices = [436, 511, 817]
+ cat_indices = [281, 282, 283, 284, 285, 286]
+ chair_indices = [423, 559, 765, 857]
+ clock_indices = [409, 530, 892]
+ dog_indices = [
+ 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
+ 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
+ 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194,
+ 195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209,
+ 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
+ 224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
+ 239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254,
+ 255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268
+ ]
+ elephant_indices = [385, 386]
+ keyboard_indices = [508, 878]
+ knife_indices = [499]
+ oven_indices = [766]
+ truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]
+
+ def __init__(self,
+ csv_dir: str,
+ model_name: str,
+ dataset_name: str = 'cue_conflict',
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self.categories = sorted([
+ 'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock',
+ 'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car',
+ 'bird', 'dog'
+ ])
+ self.csv_dir = csv_dir
+ self.model_name = model_name
+ self.dataset_name = dataset_name
+ self.csv_path = self.create_csv()
+
+ def process(self, data_batch, data_samples: Sequence[dict]) -> None:
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for data_sample in data_samples:
+ result = dict()
+ pred_label = data_sample['pred_label']
+ gt_label = data_sample['gt_label']
+ if 'score' in pred_label:
+ result['pred_score'] = pred_label['score'].cpu()
+ else:
+ result['pred_label'] = pred_label['label'].cpu()
+ result['gt_label'] = gt_label['label'].cpu()
+ result['gt_category'] = data_sample['img_path'].split('/')[-2]
+ result['img_name'] = data_sample['img_path'].split('/')[-1]
+
+ aggregated_category_probabilities = []
+ # get the prediction for each category of current instance
+ for category in self.categories:
+ category_indices = getattr(self, f'{category}_indices')
+ category_probabilities = torch.gather(
+ result['pred_score'], 0,
+ torch.tensor(category_indices)).mean()
+ aggregated_category_probabilities.append(
+ category_probabilities)
+ # sort the probabilities in descending order
+ pred_indices = torch.stack(aggregated_category_probabilities
+ ).argsort(descending=True).numpy()
+ result['pred_category'] = np.take(self.categories, pred_indices)
+
+ # Save the result to `self.results`.
+ self.results.append(result)
+
+ def create_csv(self) -> str:
+ """Create a csv file to store the results."""
+ session_name = 'session-1'
+ csv_path = osp.join(
+ self.csv_dir, self.dataset_name + '_' + self.model_name + '_' +
+ session_name + '.csv')
+ if osp.exists(csv_path):
+ os.remove(csv_path)
+ directory = osp.dirname(csv_path)
+ if not osp.exists(directory):
+ os.makedirs(directory)
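+        # the header below mirrors the column schema of the
+        # model-vs-human raw-data csv files, so model and human results
+        # can be consumed by the same plotting code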
+ with open(csv_path, 'w') as f:
+ writer = csv.writer(f)
+ writer.writerow([
+ 'subj', 'session', 'trial', 'rt', 'object_response',
+ 'category', 'condition', 'imagename'
+ ])
+ return csv_path
+
+ def dump_results_to_csv(self, results: List[dict]) -> None:
+ """Dump the results to a csv file.
+
+ Args:
+ results (List[dict]): A list of results.
+ """
+        # open the csv file once and append one row per result
+        with open(self.csv_path, 'a') as f:
+            writer = csv.writer(f)
+            for i, result in enumerate(results):
+                img_name = result['img_name']
+                category = result['gt_category']
+                condition = 'NaN'
+                writer.writerow([
+                    self.model_name, 1, i + 1, 'NaN',
+                    result['pred_category'][0], category, condition, img_name
+                ])
+
+ def compute_metrics(self, results: List[dict]) -> dict:
+ """Compute the metrics from the results.
+
+ Args:
+ results (List[dict]): A list of results.
+
+ Returns:
+ dict: A dict of metrics.
+ """
+ self.dump_results_to_csv(results)
+ metrics = dict()
+ metrics['accuracy/top1'] = np.mean([
+ result['pred_category'][0] == result['gt_category']
+ for result in results
+ ])
+
+ return metrics
diff --git a/requirements/optional.txt b/requirements/optional.txt
index f749cfd9..d9690686 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -1 +1,2 @@
faiss-gpu==1.7.2
+pandas
diff --git a/tests/data/test_test_session-1.csv b/tests/data/test_test_session-1.csv
new file mode 100644
index 00000000..2c6a5158
--- /dev/null
+++ b/tests/data/test_test_session-1.csv
@@ -0,0 +1 @@
+subj,session,trial,rt,object_response,category,condition,imagename
diff --git a/tests/test_evaluation/test_metrics/test_shape_bias_metric.py b/tests/test_evaluation/test_metrics/test_shape_bias_metric.py
new file mode 100644
index 00000000..403d3222
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_shape_bias_metric.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmselfsup.evaluation import ShapeBiasMetric
+
+
+def test_shape_bias_metric():
+ data_sample = dict()
+ data_sample['pred_label'] = dict(
+ score=torch.rand(1000, ), label=torch.tensor(1))
+ data_sample['gt_label'] = dict(label=torch.tensor(1))
+ data_sample['img_path'] = 'tests/airplane/test.JPEG'
+ evaluator = ShapeBiasMetric(
+ csv_dir='tests/data', dataset_name='test', model_name='test')
+ evaluator.process(None, [data_sample])
diff --git a/tools/analysis_tools/utils.py b/tools/analysis_tools/utils.py
new file mode 100644
index 00000000..184cb32a
--- /dev/null
+++ b/tools/analysis_tools/utils.py
@@ -0,0 +1,277 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/bethgelab/model-vs-human
+from typing import Any, Dict, List, Optional, Tuple
+
+import matplotlib as mpl
+import pandas as pd
+from matplotlib import _api
+from matplotlib import transforms as mtransforms
+
+
+class _DummyAxis:
+ """Define the minimal interface for a dummy axis.
+
+ Args:
+ minpos (float): The minimum positive value for the axis. Defaults to 0.
+ """
+ __name__ = 'dummy'
+
+ # Once the deprecation elapses, replace dataLim and viewLim by plain
+ # _view_interval and _data_interval private tuples.
+ dataLim = _api.deprecate_privatize_attribute(
+ '3.6', alternative='get_data_interval() and set_data_interval()')
+ viewLim = _api.deprecate_privatize_attribute(
+ '3.6', alternative='get_view_interval() and set_view_interval()')
+
+ def __init__(self, minpos: float = 0) -> None:
+ self._dataLim = mtransforms.Bbox.unit()
+ self._viewLim = mtransforms.Bbox.unit()
+ self._minpos = minpos
+
+    def get_view_interval(self) -> Tuple[float, float]:
+ """Return the view interval as a tuple (*vmin*, *vmax*)."""
+ return self._viewLim.intervalx
+
+ def set_view_interval(self, vmin: float, vmax: float) -> None:
+ """Set the view interval to (*vmin*, *vmax*)."""
+ self._viewLim.intervalx = vmin, vmax
+
+ def get_minpos(self) -> float:
+ """Return the minimum positive value for the axis."""
+ return self._minpos
+
+    def get_data_interval(self) -> Tuple[float, float]:
+ """Return the data interval as a tuple (*vmin*, *vmax*)."""
+ return self._dataLim.intervalx
+
+ def set_data_interval(self, vmin: float, vmax: float) -> None:
+ """Set the data interval to (*vmin*, *vmax*)."""
+ self._dataLim.intervalx = vmin, vmax
+
+ def get_tick_space(self) -> int:
+ """Return the number of ticks to use."""
+ # Just use the long-standing default of nbins==9
+ return 9
+
+
+class TickHelper:
+ """A helper class for ticks and tick labels."""
+ axis = None
+
+ def set_axis(self, axis: Any) -> None:
+ """Set the axis instance."""
+ self.axis = axis
+
+ def create_dummy_axis(self, **kwargs) -> None:
+ """Create a dummy axis if no axis is set."""
+ if self.axis is None:
+ self.axis = _DummyAxis(**kwargs)
+
+ @_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
+ def set_view_interval(self, vmin: float, vmax: float) -> None:
+ """Set the view interval to (*vmin*, *vmax*)."""
+ self.axis.set_view_interval(vmin, vmax)
+
+ @_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
+ def set_data_interval(self, vmin: float, vmax: float) -> None:
+ """Set the data interval to (*vmin*, *vmax*)."""
+ self.axis.set_data_interval(vmin, vmax)
+
+ @_api.deprecated(
+ '3.5',
+ alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
+ def set_bounds(self, vmin: float, vmax: float) -> None:
+ """Set the view and data interval to (*vmin*, *vmax*)."""
+ self.set_view_interval(vmin, vmax)
+ self.set_data_interval(vmin, vmax)
+
+
+class Formatter(TickHelper):
+ """Create a string based on a tick value and location."""
+ # some classes want to see all the locs to help format
+ # individual ones
+ locs = []
+
+    def __call__(self, x: float, pos: Optional[Any] = None) -> str:
+ """Return the format for tick value *x* at position pos.
+
+ ``pos=None`` indicates an unspecified location.
+
+ This method must be overridden in the derived class.
+
+ Args:
+            x (float): The tick value.
+ pos (Optional[Any]): The tick position. Defaults to None.
+ """
+ raise NotImplementedError('Derived must override')
+
+ def format_ticks(self, values: pd.Series) -> List[str]:
+ """Return the tick labels for all the ticks at once.
+
+ Args:
+ values (pd.Series): The tick values.
+
+ Returns:
+ List[str]: The tick labels.
+ """
+ self.set_locs(values)
+ return [self(value, i) for i, value in enumerate(values)]
+
+ def format_data(self, value: Any) -> str:
+ """Return the full string representation of the value with the position
+ unspecified.
+
+ Args:
+ value (Any): The tick value.
+
+ Returns:
+ str: The full string representation of the value.
+ """
+ return self.__call__(value)
+
+ def format_data_short(self, value: Any) -> str:
+ """Return a short string version of the tick value.
+
+ Defaults to the position-independent long value.
+
+ Args:
+ value (Any): The tick value.
+
+ Returns:
+ str: The short string representation of the value.
+ """
+ return self.format_data(value)
+
+ def get_offset(self) -> str:
+ """Return the offset string."""
+ return ''
+
+ def set_locs(self, locs: List[Any]) -> None:
+ """Set the locations of the ticks.
+
+ This method is called before computing the tick labels because some
+ formatters need to know all tick locations to do so.
+ """
+ self.locs = locs
+
+ @staticmethod
+ def fix_minus(s: str) -> str:
+ """Some classes may want to replace a hyphen for minus with the proper
+ Unicode symbol (U+2212) for typographical correctness.
+
+ This is a
+ helper method to perform such a replacement when it is enabled via
+ :rc:`axes.unicode_minus`.
+
+ Args:
+ s (str): The string to replace the hyphen with the Unicode symbol.
+ """
+ return (s.replace('-', '\N{MINUS SIGN}')
+ if mpl.rcParams['axes.unicode_minus'] else s)
+
+ def _set_locator(self, locator: Any) -> None:
+ """Subclasses may want to override this to set a locator."""
+ pass
+
+
+class FormatStrFormatter(Formatter):
+ """Use an old-style ('%' operator) format string to format the tick.
+
+ The format string should have a single variable format (%) in it.
+ It will be applied to the value (not the position) of the tick.
+
+ Negative numeric values will use a dash, not a Unicode minus; use mathtext
+ to get a Unicode minus by wrapping the format specifier with $ (e.g.
+ "$%g$").
+
+ Args:
+ fmt (str): Format string.
+ """
+
+ def __init__(self, fmt: str) -> None:
+ self.fmt = fmt
+
+    def __call__(self, x: float, pos: Optional[Any] = None) -> str:
+ """Return the formatted label string.
+
+ Only the value *x* is formatted. The position is ignored.
+
+ Args:
+ x (str): The value to format.
+ pos (Any): The position of the tick. Ignored.
+ """
+ return self.fmt % x
+
+
+class ShapeBias:
+ """Compute the shape bias of a model.
+
+ Reference: `ImageNet-trained CNNs are biased towards texture;
+ increasing shape bias improves accuracy and robustness
+    <https://arxiv.org/abs/1811.12231>`_.
+ """
+ num_input_models = 1
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.plotting_name = 'shape-bias'
+
+ @staticmethod
+ def _check_dataframe(df: pd.DataFrame) -> None:
+ """Check that the dataframe is valid."""
+ assert len(df) > 0, 'empty dataframe'
+
+ def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
+ """Compute the shape bias of a model.
+
+ Args:
+ df (pd.DataFrame): The dataframe containing the data.
+
+ Returns:
+ Dict[str, float]: The shape bias.
+ """
+ self._check_dataframe(df)
+
+ df = df.copy()
+ df['correct_texture'] = df['imagename'].apply(
+ self.get_texture_category)
+ df['correct_shape'] = df['category']
+
+ # remove those rows where shape = texture, i.e. no cue conflict present
+ df2 = df.loc[df.correct_shape != df.correct_texture]
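+        # note: both fractions below use len(df) (all trials) as the
+        # denominator; the shape-bias ratio is unaffected, since the
+        # denominator cancels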
+ fraction_correct_shape = len(
+ df2.loc[df2.object_response == df2.correct_shape]) / len(df)
+ fraction_correct_texture = len(
+ df2.loc[df2.object_response == df2.correct_texture]) / len(df)
+ shape_bias = fraction_correct_shape / (
+ fraction_correct_shape + fraction_correct_texture)
+
+ result_dict = {
+ 'fraction-correct-shape': fraction_correct_shape,
+ 'fraction-correct-texture': fraction_correct_texture,
+ 'shape-bias': shape_bias
+ }
+ return result_dict
+
+ def get_texture_category(self, imagename: str) -> str:
+ """Return texture category from imagename.
+
+    e.g. 'XXX_dog10-bird2.png' -> 'bird'
+
+ Args:
+ imagename (str): Name of the image.
+
+ Returns:
+ str: Texture category.
+ """
+        assert isinstance(imagename, str)
+
+ # remove unnecessary words
+ a = imagename.split('_')[-1]
+ # remove .png etc.
+ b = a.split('.')[0]
+ # get texture category (last word)
+ c = b.split('-')[-1]
+ # remove number, e.g. 'bird2' -> 'bird'
+ d = ''.join([i for i in c if not i.isdigit()])
+ return d
diff --git a/tools/analysis_tools/visualize_shape_bias.py b/tools/analysis_tools/visualize_shape_bias.py
new file mode 100644
index 00000000..43187e6f
--- /dev/null
+++ b/tools/analysis_tools/visualize_shape_bias.py
@@ -0,0 +1,264 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/bethgelab/model-vs-human
+import argparse
+import os
+import os.path as osp
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from utils import FormatStrFormatter, ShapeBias
+
+# global default boundary settings for thin gray transparent
+# boundaries to avoid not being able to see the difference
+# between two partially overlapping datapoints of the same color:
+PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)
+PLOTTING_EDGE_WIDTH = 0.02
+ICONS_DIR = osp.join(
+ osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons')
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--csv-dir', type=str, help='directory of csv files')
+parser.add_argument(
+ '--result-dir', type=str, help='directory to save plotting results')
+parser.add_argument('--model-names', nargs='+', default=[], help='model name')
+parser.add_argument(
+ '--colors',
+ nargs='+',
+ type=float,
+ default=[],
+ help= # noqa
+ 'the colors for the plots of each model, and they should be in the same order as model_names' # noqa: E501
+)
+parser.add_argument(
+ '--markers',
+ nargs='+',
+ type=str,
+ default=[],
+ help= # noqa
+ 'the markers for the plots of each model, and they should be in the same order as model_names' # noqa: E501
+)
+parser.add_argument(
+ '--plotting-names',
+ nargs='+',
+ default=[],
+ help= # noqa
+ 'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501
+)
+
+humans = [
+ 'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
+ 'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
+]
+
+
+def read_csvs(csv_dir: str) -> pd.DataFrame:
+ """Reads all csv files in a directory and returns a single dataframe.
+
+ Args:
+ csv_dir (str): directory of csv files.
+
+ Returns:
+ pd.DataFrame: dataframe containing all csv files
+ """
+ df = pd.DataFrame()
+ for csv in os.listdir(csv_dir):
+ if csv.endswith('.csv'):
+ cur_df = pd.read_csv(osp.join(csv_dir, csv))
+ cur_df.columns = [c.lower() for c in cur_df.columns]
+            # DataFrame.append was removed in pandas 2.0; use concat instead
+            df = pd.concat([df, cur_df], ignore_index=True)
+ df.condition = df.condition.astype(str)
+ return df
+
+
+def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
+ """Plots a matrixplot of shape bias.
+
+ Args:
+ args (argparse.Namespace): arguments.
+ analysis (ShapeBias): shape bias analysis. Defaults to ShapeBias().
+ """
+ mpl.rcParams['font.family'] = ['serif']
+ mpl.rcParams['font.serif'] = ['Times New Roman']
+
+
+ df = read_csvs(args.csv_dir)
+
+ fontsize = 15
+ ticklength = 10
+ markersize = 250
+ label_size = 20
+
+ classes = df['category'].unique()
+ num_classes = len(classes)
+
+ # plot setup
+ fig = plt.figure(1, figsize=(12, 12), dpi=300.)
+ ax = plt.gca()
+
+ ax.set_xlim([0, 1])
+ ax.set_ylim([-.5, num_classes - 0.5])
+
+ # secondary reversed x axis
+ ax_top = ax.secondary_xaxis(
+ 'top', functions=(lambda x: 1 - x, lambda x: 1 - x))
+
+ # labels, ticks
+ plt.tick_params(
+ axis='y', which='both', left=False, right=False, labelleft=False)
+ ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size)
+ ax.set_xlabel(
+ "Fraction of 'texture' decisions", fontsize=label_size, labelpad=25)
+ ax_top.set_xlabel(
+ "Fraction of 'shape' decisions", fontsize=label_size, labelpad=25)
+ ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
+ ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g'))
+ ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1))
+ ax_top.set_ticks(np.arange(0, 1.1, 0.1))
+ ax.tick_params(
+ axis='both', which='major', labelsize=fontsize, length=ticklength)
+ ax_top.tick_params(
+ axis='both', which='major', labelsize=fontsize, length=ticklength)
+
+ # arrows on x axes
+ plt.arrow(
+ x=0,
+ y=-1.75,
+ dx=1,
+ dy=0,
+ fc='black',
+ head_width=0.4,
+ head_length=0.03,
+ clip_on=False,
+ length_includes_head=True,
+ overhang=0.5)
+ plt.arrow(
+ x=1,
+ y=num_classes + 0.75,
+ dx=-1,
+ dy=0,
+ fc='black',
+ head_width=0.4,
+ head_length=0.03,
+ clip_on=False,
+ length_includes_head=True,
+ overhang=0.5)
+
+ # icons besides y axis
+ # determine order of icons
+ df_selection = df.loc[(df['subj'].isin(humans))]
+ class_avgs = []
+ for cl in classes:
+ df_class_selection = df_selection.query("category == '{}'".format(cl))
+ class_avgs.append(1 - analysis.analysis(
+ df=df_class_selection)['shape-bias'])
+ sorted_indices = np.argsort(class_avgs)
+ classes = classes[sorted_indices]
+
+ # icon placement is calculated in axis coordinates
+ WIDTH = 1 / num_classes
+ # placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH)
+ XPOS = -1.25 * WIDTH
+ YPOS = -0.5
+ HEIGHT = 1
+ MARGINX = 1 / 10 * WIDTH # vertical whitespace between icons
+ MARGINY = 1 / 10 * HEIGHT # horizontal whitespace between icons
+
+ left = XPOS + MARGINX
+ right = XPOS + WIDTH - MARGINX
+
+ for i in range(num_classes):
+ bottom = i + MARGINY + YPOS
+ top = (i + 1) - MARGINY + YPOS
+ iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i]))
+ plt.imshow(
+ plt.imread(iconpath),
+ extent=[left, right, bottom, top],
+ aspect='auto',
+ clip_on=False)
+
+ # plot horizontal intersection lines
+ for i in range(num_classes - 1):
+ plt.plot([0, 1], [i + .5, i + .5],
+ c='gray',
+ linestyle='dotted',
+ alpha=0.4)
+
+ # plot average shapebias + scatter points
+ for i in range(len(args.model_names)):
+ df_selection = df.loc[(df['subj'].isin(args.model_names[i]))]
+ result_df = analysis.analysis(df=df_selection)
+ avg = 1 - result_df['shape-bias']
+ ax.plot([avg, avg], [-1, num_classes], color=args.colors[i])
+ class_avgs = []
+ for cl in classes:
+ df_class_selection = df_selection.query(
+ "category == '{}'".format(cl))
+ class_avgs.append(1 - analysis.analysis(
+ df=df_class_selection)['shape-bias'])
+
+ ax.scatter(
+ class_avgs,
+ classes,
+ color=args.colors[i],
+ marker=args.markers[i],
+ label=args.plotting_names[i],
+ s=markersize,
+ clip_on=False,
+ edgecolors=PLOTTING_EDGE_COLOR,
+ linewidths=PLOTTING_EDGE_WIDTH,
+ zorder=3)
+ plt.legend(frameon=True, labelspacing=1, loc=9)
+
+ figure_path = osp.join(args.result_dir,
+ 'cue-conflict_shape-bias_matrixplot.pdf')
+ fig.savefig(figure_path, bbox_inches='tight')
+ plt.close()
+
+
+if __name__ == '__main__':
+ icon_names = [
+ 'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
+ 'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png',
+ 'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
+ 'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
+ 'knife.png', 'response_icons_vertical.png', 'truck.png'
+ ]
+ root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
+ os.makedirs(ICONS_DIR, exist_ok=True)
+ for icon_name in icon_names:
+ url = osp.join(root_url, icon_name)
+ os.system('wget -O {} {}'.format(osp.join(ICONS_DIR, icon_name), url))
+
+ args = parser.parse_args()
+ assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
+ must be 3 times the number of models. Every three colors are the RGB \
+ values for one model.'
+
+ # preprocess colors
+ args.colors = [c / 255. for c in args.colors]
+ colors = []
+ for i in range(len(args.model_names)):
+ colors.append(args.colors[3 * i:3 * i + 3])
+ args.colors = colors
+ args.colors.append([165 / 255., 30 / 255., 55 / 255.]) # human color
+
+ # if plotting names are not specified, use model names
+ if len(args.plotting_names) == 0:
+ args.plotting_names = args.model_names
+
+ # preprocess markers
+ args.markers.append('D') # human marker
+
+ # preprocess model names
+ args.model_names = [[m] for m in args.model_names]
+ args.model_names.append(humans)
+
+ # preprocess plotting names
+ args.plotting_names.append('Humans')
+
+ plot_shape_bias_matrixplot(args)
+
+ os.system('rm -rf {}'.format(ICONS_DIR))