From 304e81650acca248ca1129b27cf3cb1546e3bc1e Mon Sep 17 00:00:00 2001 From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Date: Fri, 30 Dec 2022 22:11:18 +0800 Subject: [PATCH] [Feature]: Shape bias (#635) * [Feature]: Add shape bias vis * [Fix]: Fix lint * [Feature]: Add shape bias metrics * [Fix]: Fix lint * [Fix]: Delete redundant code * [Feature]: Add shape bias doc * [Fix]: Fix lint * [Feature]: add UT * [Fix]: Fix lint * [Fix]: Fix typo * [Fix]: Fix typo * [Fix]: Fix args param style * [Feature]: Download pic automatically --- docs/en/user_guides/visualization.md | 86 ++++++ mmselfsup/evaluation/__init__.py | 1 + mmselfsup/evaluation/metrics/__init__.py | 4 + .../evaluation/metrics/shape_bias_label.py | 171 +++++++++++ requirements/optional.txt | 1 + tests/data/test_test_session-1.csv | 1 + .../test_metrics/test_shape_bias_metric.py | 15 + tools/analysis_tools/utils.py | 277 ++++++++++++++++++ tools/analysis_tools/visualize_shape_bias.py | 264 +++++++++++++++++ 9 files changed, 820 insertions(+) create mode 100644 mmselfsup/evaluation/metrics/__init__.py create mode 100644 mmselfsup/evaluation/metrics/shape_bias_label.py create mode 100644 tests/data/test_test_session-1.csv create mode 100644 tests/test_evaluation/test_metrics/test_shape_bias_metric.py create mode 100644 tools/analysis_tools/utils.py create mode 100644 tools/analysis_tools/visualize_shape_bias.py diff --git a/docs/en/user_guides/visualization.md b/docs/en/user_guides/visualization.md index 3f6e710a..f6163635 100644 --- a/docs/en/user_guides/visualization.md +++ b/docs/en/user_guides/visualization.md @@ -12,6 +12,11 @@ Visualization can give an intuitive interpretation of the performance of the mod - [Visualize Datasets](#visualize-datasets) - [Visualize t-SNE](#visualize-t-sne) - [Visualize Low-level Feature Reconstruction](#visualize-low-level-feature-reconstruction) + - [Visualize Shape Bias](#visualize-shape-bias) + - [Prepare the dataset](#prepare-the-dataset) + - [Modify the config for classification](#modify-the-config-for-classification) + - [Inference your model with above modified config file](#inference-your-model-with-above-modified-config-file) + - [Plot shape bias](#plot-shape-bias) @@ -205,3 +210,84 @@ Results of MaskFeat:
+
+## Visualize Shape Bias
+
+Shape bias measures how much a model relies on shape, as opposed to texture, to recognize the semantics in images. For
+more details, we refer interested readers to this [paper](https://arxiv.org/abs/2106.07411). MMSelfSup provides an
+off-the-shelf toolbox to obtain the shape bias of a classification model. You can follow the steps below:
+
+### Prepare the dataset
+
+First, download the [cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) dataset to the `data` folder
+and unzip it. After that, your `data` folder should have the following structure:
+
+```text
+data
+├── cue-conflict
+│   ├── airplane
+│   ├── bear
+│   ├── ...
+│   └── truck
+```
+
+### Modify the config for classification
+
+Replace the original `test_dataloader` and `test_evaluator` with the following configurations:
+
+```python
+test_dataloader = dict(
+    dataset=dict(
+        type='CustomDataset',
+        data_root='data/cue-conflict',
+        _delete_=True),
+    drop_last=False)
+test_evaluator = dict(
+    type='mmselfsup.ShapeBiasMetric',
+    _delete_=True,
+    csv_dir='directory/to/save/the/csv/file',
+    model_name='your_model_name')
+```
+
+Please note that you should set `csv_dir` and `model_name` to your own values.
+
+### Inference your model with above modified config file
+
+Then run inference on the `cue-conflict` dataset with the modified config file.
+
+```shell
+# For Slurm
+GPUS_PER_NODE=1 GPUS=1 bash tools/benchmarks/classification/mim_slurm_test.sh $partition $config $checkpoint
+```
+
+```shell
+# For PyTorch
+GPUS=1 bash tools/benchmarks/classification/mim_dist_test.sh $config $checkpoint
+```
+
+After that, you should obtain a csv file named `cue-conflict_model-name_session-1.csv`, where `model-name` is the
+`model_name` you set above. Besides this file, you should also download these
+[csv files](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict), which contain the results
+of humans and other reference models, to the same `csv_dir`.
+
+### Plot shape bias
+
+Now you can plot the shape bias:
+
+```shell
+python tools/analysis_tools/visualize_shape_bias.py --csv-dir $CSV_DIR --result-dir $CSV_DIR --colors $RGB --markers o --plotting-names $YOUR_MODEL_NAME --model-names $YOUR_MODEL_NAME
+```
+
+- `--csv-dir`, the directory that stores the csv files, i.e. the `csv_dir` used above
+- `--colors`, RGB values formatted as `R G B`, e.g. `100 100 100`; multiple RGB triples can be given if you want to
+  plot the shape bias of several models
+- `--plotting-names`, the legend names in the shape bias figure; you can set them to your model names. Multiple values
+  are allowed when plotting several models
+- `--model-names`, should be the same names specified in your configs; multiple names are allowed when plotting the
+  shape bias of several models
+
+Please note that every three values of `--colors` correspond to one value of `--model-names`. After all of the above
+steps, you are expected to obtain the following figure.
+
+ +
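
For intuition, the shape-bias number that this toolbox reports can also be reproduced directly from the generated csv file. Below is a minimal sketch that mirrors the `ShapeBias.analysis` logic added in `tools/analysis_tools/utils.py` later in this patch; the filename `cue-conflict_your_model_name_session-1.csv` is a placeholder for the file produced by `ShapeBiasMetric`.

```python
import pandas as pd


def texture_category(imagename: str) -> str:
    """'XXX_dog10-bird2.png' -> 'bird': take the texture suffix, drop digits."""
    stem = imagename.split('_')[-1].split('.')[0]
    return ''.join(ch for ch in stem.split('-')[-1] if not ch.isdigit())


# placeholder path: replace with the csv generated by ShapeBiasMetric
df = pd.read_csv('cue-conflict_your_model_name_session-1.csv')
df['texture'] = df['imagename'].apply(texture_category)

# keep only true cue-conflict trials, where the shape label and the
# texture label disagree
conflict = df[df['category'] != df['texture']]
correct_shape = (conflict['object_response'] == conflict['category']).sum()
correct_texture = (conflict['object_response'] == conflict['texture']).sum()

# shape bias = fraction of cue-conflict decisions that follow the shape
print('shape bias:', correct_shape / (correct_shape + correct_texture))
```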
diff --git a/mmselfsup/evaluation/__init__.py b/mmselfsup/evaluation/__init__.py
index 228f58c2..f70dc226 100644
--- a/mmselfsup/evaluation/__init__.py
+++ b/mmselfsup/evaluation/__init__.py
@@ -1,2 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .functional import *  # noqa: F401,F403
+from .metrics import *  # noqa: F401,F403
diff --git a/mmselfsup/evaluation/metrics/__init__.py b/mmselfsup/evaluation/metrics/__init__.py
new file mode 100644
index 00000000..5b18dcd8
--- /dev/null
+++ b/mmselfsup/evaluation/metrics/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .shape_bias_label import ShapeBiasMetric
+
+__all__ = ['ShapeBiasMetric']
diff --git a/mmselfsup/evaluation/metrics/shape_bias_label.py b/mmselfsup/evaluation/metrics/shape_bias_label.py
new file mode 100644
index 00000000..c136e04a
--- /dev/null
+++ b/mmselfsup/evaluation/metrics/shape_bias_label.py
@@ -0,0 +1,171 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import csv
+import os
+import os.path as osp
+from typing import List, Sequence
+
+import numpy as np
+import torch
+from mmengine.evaluator import BaseMetric
+
+from mmselfsup.registry import METRICS
+
+
+@METRICS.register_module()
+class ShapeBiasMetric(BaseMetric):
+    """Evaluate the model on the ``cue_conflict`` dataset.
+
+    This module will evaluate the model on an OOD dataset, cue_conflict, in
+    order to measure the shape bias of the model. In addition to computing
+    the Top-1 accuracy, this module also generates a csv file to record the
+    detailed prediction results, such that this csv file can be used to
+    generate the shape bias curve.
+
+    Args:
+        csv_dir (str): The directory to save the csv file.
+        model_name (str): The name of the csv file. Please note that the
+            model name should be a unique identifier.
+        dataset_name (str): The name of the dataset. Default: 'cue_conflict'.
+ """ + + # mapping several classes from ImageNet-1K to the same category + airplane_indices = [404] + bear_indices = [294, 295, 296, 297] + bicycle_indices = [444, 671] + bird_indices = [ + 8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83, + 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129, + 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, + 145 + ] + boat_indices = [472, 554, 625, 814, 914] + bottle_indices = [440, 720, 737, 898, 899, 901, 907] + car_indices = [436, 511, 817] + cat_indices = [281, 282, 283, 284, 285, 286] + chair_indices = [423, 559, 765, 857] + clock_indices = [409, 530, 892] + dog_indices = [ + 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, + 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194, + 195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, + 224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, + 239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254, + 255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268 + ] + elephant_indices = [385, 386] + keyboard_indices = [508, 878] + knife_indices = [499] + oven_indices = [766] + truck_indices = [555, 569, 656, 675, 717, 734, 864, 867] + + def __init__(self, + csv_dir: str, + model_name: str, + dataset_name: str = 'cue_conflict', + **kwargs) -> None: + super().__init__(**kwargs) + + self.categories = sorted([ + 'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock', + 'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car', + 'bird', 'dog' + ]) + self.csv_dir = csv_dir + self.model_name = model_name + self.dataset_name = dataset_name + self.csv_path = self.create_csv() + + def process(self, data_batch, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + pred_label = data_sample['pred_label'] + gt_label = data_sample['gt_label'] + if 'score' in pred_label: + result['pred_score'] = pred_label['score'].cpu() + else: + result['pred_label'] = pred_label['label'].cpu() + result['gt_label'] = gt_label['label'].cpu() + result['gt_category'] = data_sample['img_path'].split('/')[-2] + result['img_name'] = data_sample['img_path'].split('/')[-1] + + aggregated_category_probabilities = [] + # get the prediction for each category of current instance + for category in self.categories: + category_indices = getattr(self, f'{category}_indices') + category_probabilities = torch.gather( + result['pred_score'], 0, + torch.tensor(category_indices)).mean() + aggregated_category_probabilities.append( + category_probabilities) + # sort the probabilities in descending order + pred_indices = torch.stack(aggregated_category_probabilities + ).argsort(descending=True).numpy() + result['pred_category'] = np.take(self.categories, pred_indices) + + # Save the result to `self.results`. 
+            self.results.append(result)
+
+    def create_csv(self) -> str:
+        """Create a csv file to store the results."""
+        session_name = 'session-1'
+        csv_path = osp.join(
+            self.csv_dir, self.dataset_name + '_' + self.model_name + '_' +
+            session_name + '.csv')
+        if osp.exists(csv_path):
+            os.remove(csv_path)
+        directory = osp.dirname(csv_path)
+        if not osp.exists(directory):
+            os.makedirs(directory)
+        with open(csv_path, 'w') as f:
+            writer = csv.writer(f)
+            writer.writerow([
+                'subj', 'session', 'trial', 'rt', 'object_response',
+                'category', 'condition', 'imagename'
+            ])
+        return csv_path
+
+    def dump_results_to_csv(self, results: List[dict]) -> None:
+        """Dump the results to a csv file.
+
+        Args:
+            results (List[dict]): A list of results.
+        """
+        # open the csv file once and append one row per sample
+        with open(self.csv_path, 'a') as f:
+            writer = csv.writer(f)
+            for i, result in enumerate(results):
+                img_name = result['img_name']
+                category = result['gt_category']
+                condition = 'NaN'
+                writer.writerow([
+                    self.model_name, 1, i + 1, 'NaN',
+                    result['pred_category'][0], category, condition, img_name
+                ])
+
+    def compute_metrics(self, results: List[dict]) -> dict:
+        """Compute the metrics from the results.
+
+        Args:
+            results (List[dict]): A list of results.
+
+        Returns:
+            dict: A dict of metrics.
+        """
+        self.dump_results_to_csv(results)
+        metrics = dict()
+        metrics['accuracy/top1'] = np.mean([
+            result['pred_category'][0] == result['gt_category']
+            for result in results
+        ])
+
+        return metrics
diff --git a/requirements/optional.txt b/requirements/optional.txt
index f749cfd9..d9690686 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -1 +1,2 @@
 faiss-gpu==1.7.2
+pandas
diff --git a/tests/data/test_test_session-1.csv b/tests/data/test_test_session-1.csv
new file mode 100644
index 00000000..2c6a5158
--- /dev/null
+++ b/tests/data/test_test_session-1.csv
@@ -0,0 +1 @@
+subj,session,trial,rt,object_response,category,condition,imagename
diff --git a/tests/test_evaluation/test_metrics/test_shape_bias_metric.py b/tests/test_evaluation/test_metrics/test_shape_bias_metric.py
new file mode 100644
index 00000000..403d3222
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_shape_bias_metric.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmselfsup.evaluation import ShapeBiasMetric
+
+
+def test_shape_bias_metric():
+    data_sample = dict()
+    data_sample['pred_label'] = dict(
+        score=torch.rand(1000, ), label=torch.tensor(1))
+    data_sample['gt_label'] = dict(label=torch.tensor(1))
+    data_sample['img_path'] = 'tests/airplane/test.JPEG'
+    evaluator = ShapeBiasMetric(
+        csv_dir='tests/data', dataset_name='test', model_name='test')
+    evaluator.process(None, [data_sample])
diff --git a/tools/analysis_tools/utils.py b/tools/analysis_tools/utils.py
new file mode 100644
index 00000000..184cb32a
--- /dev/null
+++ b/tools/analysis_tools/utils.py
@@ -0,0 +1,277 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/bethgelab/model-vs-human
+from typing import Any, Dict, List, Optional
+
+import matplotlib as mpl
+import pandas as pd
+from matplotlib import _api
+from matplotlib import transforms as mtransforms
+
+
+class _DummyAxis:
+    """Define the minimal interface for a dummy axis.
+
+    Args:
+        minpos (float): The minimum positive value for the axis.
+            Defaults to 0.
+    """
+    __name__ = 'dummy'
+
+    # Once the deprecation elapses, replace dataLim and viewLim by plain
+    # _view_interval and _data_interval private tuples.
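+    # NOTE: direct access to `dataLim`/`viewLim` is deprecated since
+    # Matplotlib 3.6; use the get/set interval methods below instead.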
+    dataLim = _api.deprecate_privatize_attribute(
+        '3.6', alternative='get_data_interval() and set_data_interval()')
+    viewLim = _api.deprecate_privatize_attribute(
+        '3.6', alternative='get_view_interval() and set_view_interval()')
+
+    def __init__(self, minpos: float = 0) -> None:
+        self._dataLim = mtransforms.Bbox.unit()
+        self._viewLim = mtransforms.Bbox.unit()
+        self._minpos = minpos
+
+    def get_view_interval(self):
+        """Return the view interval as a tuple (*vmin*, *vmax*)."""
+        return self._viewLim.intervalx
+
+    def set_view_interval(self, vmin: float, vmax: float) -> None:
+        """Set the view interval to (*vmin*, *vmax*)."""
+        self._viewLim.intervalx = vmin, vmax
+
+    def get_minpos(self) -> float:
+        """Return the minimum positive value for the axis."""
+        return self._minpos
+
+    def get_data_interval(self):
+        """Return the data interval as a tuple (*vmin*, *vmax*)."""
+        return self._dataLim.intervalx
+
+    def set_data_interval(self, vmin: float, vmax: float) -> None:
+        """Set the data interval to (*vmin*, *vmax*)."""
+        self._dataLim.intervalx = vmin, vmax
+
+    def get_tick_space(self) -> int:
+        """Return the number of ticks to use."""
+        # Just use the long-standing default of nbins==9
+        return 9
+
+
+class TickHelper:
+    """A helper class for ticks and tick labels."""
+    axis = None
+
+    def set_axis(self, axis: Any) -> None:
+        """Set the axis instance."""
+        self.axis = axis
+
+    def create_dummy_axis(self, **kwargs) -> None:
+        """Create a dummy axis if no axis is set."""
+        if self.axis is None:
+            self.axis = _DummyAxis(**kwargs)
+
+    @_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
+    def set_view_interval(self, vmin: float, vmax: float) -> None:
+        """Set the view interval to (*vmin*, *vmax*)."""
+        self.axis.set_view_interval(vmin, vmax)
+
+    @_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
+    def set_data_interval(self, vmin: float, vmax: float) -> None:
+        """Set the data interval to (*vmin*, *vmax*)."""
+        self.axis.set_data_interval(vmin, vmax)
+
+    @_api.deprecated(
+        '3.5',
+        alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
+    def set_bounds(self, vmin: float, vmax: float) -> None:
+        """Set the view and data interval to (*vmin*, *vmax*)."""
+        self.set_view_interval(vmin, vmax)
+        self.set_data_interval(vmin, vmax)
+
+
+class Formatter(TickHelper):
+    """Create a string based on a tick value and location."""
+    # some classes want to see all the locs to help format
+    # individual ones
+    locs = []
+
+    def __call__(self, x: Any, pos: Optional[Any] = None) -> str:
+        """Return the format for tick value *x* at position pos.
+
+        ``pos=None`` indicates an unspecified location.
+
+        This method must be overridden in the derived class.
+
+        Args:
+            x (Any): The tick value.
+            pos (Optional[Any]): The tick position. Defaults to None.
+        """
+        raise NotImplementedError('Derived must override')
+
+    def format_ticks(self, values: pd.Series) -> List[str]:
+        """Return the tick labels for all the ticks at once.
+
+        Args:
+            values (pd.Series): The tick values.
+
+        Returns:
+            List[str]: The tick labels.
+        """
+        self.set_locs(values)
+        return [self(value, i) for i, value in enumerate(values)]
+
+    def format_data(self, value: Any) -> str:
+        """Return the full string representation of the value with the
+        position unspecified.
+
+        Args:
+            value (Any): The tick value.
+
+        Returns:
+            str: The full string representation of the value.
+        """
+        return self.__call__(value)
+
+    def format_data_short(self, value: Any) -> str:
+        """Return a short string version of the tick value.
+
+        Defaults to the position-independent long value.
+
+        Args:
+            value (Any): The tick value.
+
+        Returns:
+            str: The short string representation of the value.
+        """
+        return self.format_data(value)
+
+    def get_offset(self) -> str:
+        """Return the offset string."""
+        return ''
+
+    def set_locs(self, locs: List[Any]) -> None:
+        """Set the locations of the ticks.
+
+        This method is called before computing the tick labels because some
+        formatters need to know all tick locations to do so.
+        """
+        self.locs = locs
+
+    @staticmethod
+    def fix_minus(s: str) -> str:
+        """Some classes may want to replace a hyphen for minus with the
+        proper Unicode symbol (U+2212) for typographical correctness.
+
+        This is a helper method to perform such a replacement when it is
+        enabled via :rc:`axes.unicode_minus`.
+
+        Args:
+            s (str): The string in which to replace the hyphen with the
+                Unicode symbol.
+        """
+        return (s.replace('-', '\N{MINUS SIGN}')
+                if mpl.rcParams['axes.unicode_minus'] else s)
+
+    def _set_locator(self, locator: Any) -> None:
+        """Subclasses may want to override this to set a locator."""
+        pass
+
+
+class FormatStrFormatter(Formatter):
+    """Use an old-style ('%' operator) format string to format the tick.
+
+    The format string should have a single variable format (%) in it.
+    It will be applied to the value (not the position) of the tick.
+
+    Negative numeric values will use a dash, not a Unicode minus; use mathtext
+    to get a Unicode minus by wrapping the format specifier with $ (e.g.
+    "$%g$").
+
+    Args:
+        fmt (str): Format string.
+    """
+
+    def __init__(self, fmt: str) -> None:
+        self.fmt = fmt
+
+    def __call__(self, x: Any, pos: Optional[Any] = None) -> str:
+        """Return the formatted label string.
+
+        Only the value *x* is formatted. The position is ignored.
+
+        Args:
+            x (Any): The value to format.
+            pos (Any): The position of the tick. Ignored.
+        """
+        return self.fmt % x
+
+
+class ShapeBias:
+    """Compute the shape bias of a model.
+
+    Reference: `ImageNet-trained CNNs are biased towards texture;
+    increasing shape bias improves accuracy and robustness
+    <https://arxiv.org/abs/1811.12231>`_.
+    """
+    num_input_models = 1
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.plotting_name = 'shape-bias'
+
+    @staticmethod
+    def _check_dataframe(df: pd.DataFrame) -> None:
+        """Check that the dataframe is valid."""
+        assert len(df) > 0, 'empty dataframe'
+
+    def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
+        """Compute the shape bias of a model.
+
+        Args:
+            df (pd.DataFrame): The dataframe containing the data.
+
+        Returns:
+            Dict[str, float]: The shape bias.
+        """
+        self._check_dataframe(df)
+
+        df = df.copy()
+        df['correct_texture'] = df['imagename'].apply(
+            self.get_texture_category)
+        df['correct_shape'] = df['category']
+
+        # remove those rows where shape = texture, i.e.
+        # no cue conflict present
+        df2 = df.loc[df.correct_shape != df.correct_texture]
+        fraction_correct_shape = len(
+            df2.loc[df2.object_response == df2.correct_shape]) / len(df)
+        fraction_correct_texture = len(
+            df2.loc[df2.object_response == df2.correct_texture]) / len(df)
+        shape_bias = fraction_correct_shape / (
+            fraction_correct_shape + fraction_correct_texture)
+
+        result_dict = {
+            'fraction-correct-shape': fraction_correct_shape,
+            'fraction-correct-texture': fraction_correct_texture,
+            'shape-bias': shape_bias
+        }
+        return result_dict
+
+    def get_texture_category(self, imagename: str) -> str:
+        """Return the texture category from an imagename.
+
+        e.g. 'XXX_dog10-bird2.png' -> 'bird'
+
+        Args:
+            imagename (str): Name of the image.
+
+        Returns:
+            str: Texture category.
+        """
+        assert isinstance(imagename, str)
+
+        # remove unnecessary words
+        a = imagename.split('_')[-1]
+        # remove .png etc.
+        b = a.split('.')[0]
+        # get texture category (last word)
+        c = b.split('-')[-1]
+        # remove number, e.g. 'bird2' -> 'bird'
+        d = ''.join([i for i in c if not i.isdigit()])
+        return d
diff --git a/tools/analysis_tools/visualize_shape_bias.py b/tools/analysis_tools/visualize_shape_bias.py
new file mode 100644
index 00000000..43187e6f
--- /dev/null
+++ b/tools/analysis_tools/visualize_shape_bias.py
@@ -0,0 +1,264 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/bethgelab/model-vs-human
+import argparse
+import os
+import os.path as osp
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from utils import FormatStrFormatter, ShapeBias
+
+# global default boundary settings for thin gray transparent
+# boundaries to avoid not being able to see the difference
+# between two partially overlapping datapoints of the same color:
+PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)
+PLOTTING_EDGE_WIDTH = 0.02
+ICONS_DIR = osp.join(
+    osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons')
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--csv-dir', type=str, help='directory of csv files')
+parser.add_argument(
+    '--result-dir', type=str, help='directory to save plotting results')
+parser.add_argument('--model-names', nargs='+', default=[], help='model name')
+parser.add_argument(
+    '--colors',
+    nargs='+',
+    type=float,
+    default=[],
+    help=  # noqa
+    'the colors for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
+)
+parser.add_argument(
+    '--markers',
+    nargs='+',
+    type=str,
+    default=[],
+    help=  # noqa
+    'the markers for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
+)
+parser.add_argument(
+    '--plotting-names',
+    nargs='+',
+    default=[],
+    help=  # noqa
+    'the plotting names for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
+)
+
+humans = [
+    'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
+    'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
+]
+
+
+def read_csvs(csv_dir: str) -> pd.DataFrame:
+    """Read all csv files in a directory and return a single dataframe.
+
+    Args:
+        csv_dir (str): directory of csv files.
+ + Returns: + pd.DataFrame: dataframe containing all csv files + """ + df = pd.DataFrame() + for csv in os.listdir(csv_dir): + if csv.endswith('.csv'): + cur_df = pd.read_csv(osp.join(csv_dir, csv)) + cur_df.columns = [c.lower() for c in cur_df.columns] + df = df.append(cur_df) + df.condition = df.condition.astype(str) + return df + + +def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None: + """Plots a matrixplot of shape bias. + + Args: + args (argparse.Namespace): arguments. + analysis (ShapeBias): shape bias analysis. Defaults to ShapeBias(). + """ + mpl.rcParams['font.family'] = ['serif'] + mpl.rcParams['font.serif'] = ['Times New Roman'] + + plt.figure(figsize=(9, 7)) + + df = read_csvs(args.csv_dir) + + fontsize = 15 + ticklength = 10 + markersize = 250 + label_size = 20 + + classes = df['category'].unique() + num_classes = len(classes) + + # plot setup + fig = plt.figure(1, figsize=(12, 12), dpi=300.) + ax = plt.gca() + + ax.set_xlim([0, 1]) + ax.set_ylim([-.5, num_classes - 0.5]) + + # secondary reversed x axis + ax_top = ax.secondary_xaxis( + 'top', functions=(lambda x: 1 - x, lambda x: 1 - x)) + + # labels, ticks + plt.tick_params( + axis='y', which='both', left=False, right=False, labelleft=False) + ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size) + ax.set_xlabel( + "Fraction of 'texture' decisions", fontsize=label_size, labelpad=25) + ax_top.set_xlabel( + "Fraction of 'shape' decisions", fontsize=label_size, labelpad=25) + ax.xaxis.set_major_formatter(FormatStrFormatter('%g')) + ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g')) + ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1)) + ax_top.set_ticks(np.arange(0, 1.1, 0.1)) + ax.tick_params( + axis='both', which='major', labelsize=fontsize, length=ticklength) + ax_top.tick_params( + axis='both', which='major', labelsize=fontsize, length=ticklength) + + # arrows on x axes + plt.arrow( + x=0, + y=-1.75, + dx=1, + dy=0, + fc='black', + head_width=0.4, + head_length=0.03, + clip_on=False, + length_includes_head=True, + overhang=0.5) + plt.arrow( + x=1, + y=num_classes + 0.75, + dx=-1, + dy=0, + fc='black', + head_width=0.4, + head_length=0.03, + clip_on=False, + length_includes_head=True, + overhang=0.5) + + # icons besides y axis + # determine order of icons + df_selection = df.loc[(df['subj'].isin(humans))] + class_avgs = [] + for cl in classes: + df_class_selection = df_selection.query("category == '{}'".format(cl)) + class_avgs.append(1 - analysis.analysis( + df=df_class_selection)['shape-bias']) + sorted_indices = np.argsort(class_avgs) + classes = classes[sorted_indices] + + # icon placement is calculated in axis coordinates + WIDTH = 1 / num_classes + # placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH) + XPOS = -1.25 * WIDTH + YPOS = -0.5 + HEIGHT = 1 + MARGINX = 1 / 10 * WIDTH # vertical whitespace between icons + MARGINY = 1 / 10 * HEIGHT # horizontal whitespace between icons + + left = XPOS + MARGINX + right = XPOS + WIDTH - MARGINX + + for i in range(num_classes): + bottom = i + MARGINY + YPOS + top = (i + 1) - MARGINY + YPOS + iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i])) + plt.imshow( + plt.imread(iconpath), + extent=[left, right, bottom, top], + aspect='auto', + clip_on=False) + + # plot horizontal intersection lines + for i in range(num_classes - 1): + plt.plot([0, 1], [i + .5, i + .5], + c='gray', + linestyle='dotted', + alpha=0.4) + + # plot average shapebias + scatter points + for i in range(len(args.model_names)): + df_selection = 
df.loc[(df['subj'].isin(args.model_names[i]))] + result_df = analysis.analysis(df=df_selection) + avg = 1 - result_df['shape-bias'] + ax.plot([avg, avg], [-1, num_classes], color=args.colors[i]) + class_avgs = [] + for cl in classes: + df_class_selection = df_selection.query( + "category == '{}'".format(cl)) + class_avgs.append(1 - analysis.analysis( + df=df_class_selection)['shape-bias']) + + ax.scatter( + class_avgs, + classes, + color=args.colors[i], + marker=args.markers[i], + label=args.plotting_names[i], + s=markersize, + clip_on=False, + edgecolors=PLOTTING_EDGE_COLOR, + linewidths=PLOTTING_EDGE_WIDTH, + zorder=3) + plt.legend(frameon=True, labelspacing=1, loc=9) + + figure_path = osp.join(args.result_dir, + 'cue-conflict_shape-bias_matrixplot.pdf') + fig.savefig(figure_path, bbox_inches='tight') + plt.close() + + +if __name__ == '__main__': + icon_names = [ + 'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png', + 'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', + 'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png', + 'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf', + 'knife.png', 'response_icons_vertical.png', 'truck.png' + ] + root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501 + os.makedirs(ICONS_DIR, exist_ok=True) + for icon_name in icon_names: + url = osp.join(root_url, icon_name) + os.system('wget -O {} {}'.format(osp.join(ICONS_DIR, icon_name), url)) + + args = parser.parse_args() + assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \ + must be 3 times the number of models. Every three colors are the RGB \ + values for one model.' + + # preprocess colors + args.colors = [c / 255. for c in args.colors] + colors = [] + for i in range(len(args.model_names)): + colors.append(args.colors[3 * i:3 * i + 3]) + args.colors = colors + args.colors.append([165 / 255., 30 / 255., 55 / 255.]) # human color + + # if plotting names are not specified, use model names + if len(args.plotting_names) == 0: + args.plotting_names = args.model_names + + # preprocess markers + args.markers.append('D') # human marker + + # preprocess model names + args.model_names = [[m] for m in args.model_names] + args.model_names.append(humans) + + # preprocess plotting names + args.plotting_names.append('Humans') + + plot_shape_bias_matrixplot(args) + + os.system('rm -rf {}'.format(ICONS_DIR))
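
To make the color handling above concrete: the flat list passed via `--colors` is normalized to `[0, 1]` and regrouped into one RGB triple per model, which is why the script asserts `len(model_names) * 3 == len(colors)`. A hypothetical two-model invocation behaves like this sketch (the model names here are illustrative placeholders):

```python
# --colors 100 100 100 165 30 55  ->  one RGB triple per model
flat = [100.0, 100.0, 100.0, 165.0, 30.0, 55.0]
model_names = ['model-a', 'model-b']  # hypothetical names
assert len(model_names) * 3 == len(flat)

# normalize to [0, 1] and group every three values into one color
normalized = [v / 255. for v in flat]
colors = [normalized[3 * i:3 * i + 3] for i in range(len(model_names))]
print(colors)  # [[0.392..., 0.392..., 0.392...], [0.647..., 0.117..., 0.215...]]
```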