[Feature]: Shape bias (#635)

* [Feature]: Add shape bias vis

* [Fix]: Fix lint

* [Feature]: Add shape bias metrics

* [Fix]: Fix lint

* [Fix]: Delete redundant code

* [Feature]: Add shape bias doc

* [Fix]: Fix lint

* [Feature]: add UT

* [Fix]: Fix lint

* [Fix]: Fix typo

* [Fix]: Fix typo

* [Fix]: Fix args param style

* [Feature]: Download pic automatically
Yuan Liu 2022-12-30 22:11:18 +08:00 committed by GitHub
parent 1fd3509f7b
commit 304e81650a
9 changed files with 820 additions and 0 deletions

View File

@ -12,6 +12,11 @@ Visualization can give an intuitive interpretation of the performance of the mod
- [Visualize Datasets](#visualize-datasets)
- [Visualize t-SNE](#visualize-t-sne)
- [Visualize Low-level Feature Reconstruction](#visualize-low-level-feature-reconstruction)
- [Visualize Shape Bias](#visualize-shape-bias)
- [Prepare the dataset](#prepare-the-dataset)
- [Modify the config for classification](#modify-the-config-for-classification)
- [Inference your model with the modified config file](#inference-your-model-with-the-modified-config-file)
- [Plot shape bias](#plot-shape-bias)
<!-- /TOC -->
@ -205,3 +210,84 @@ Results of MaskFeat:
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/200465876-7e7dcb6f-5e8d-4d80-b300-9e1847cb975f.jpg" width="800" />
</div>
## Visualize Shape Bias

Shape bias measures how much a model relies on shape, compared to texture, to perceive the semantics in images. For more details,
we refer interested readers to this [paper](https://arxiv.org/abs/2106.07411). MMSelfSup provides an off-the-shelf toolbox to
obtain the shape bias of a classification model. You can follow the steps below:
### Prepare the dataset

First, download the [cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) dataset to the `data` folder
and unzip it. After that, your `data` folder should have the following structure:
```text
data
├── cue-conflict
│   ├── airplane
│   ├── bear
│   ├── ...
│   └── truck
```
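
If you prefer to script this step, here is a minimal sketch using only the Python standard library (the archive URL is the one above; the local paths are assumptions):

```python
import os
import tarfile
import urllib.request

url = ('https://github.com/bethgelab/model-vs-human/releases/'
       'download/v0.1/cue-conflict.tar.gz')
os.makedirs('data', exist_ok=True)
archive = 'data/cue-conflict.tar.gz'
urllib.request.urlretrieve(url, archive)
with tarfile.open(archive) as tar:
    tar.extractall('data')  # should yield data/cue-conflict/<category>/...
```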
### Modify the config for classification

Replace the original `test_dataloader` and `test_evaluator` with the following configuration:
```python
test_dataloader = dict(
    dataset=dict(
        type='CustomDataset',
        data_root='data/cue-conflict',
        _delete_=True),
    drop_last=False)
test_evaluator = dict(
    type='mmselfsup.ShapeBiasMetric',
    _delete_=True,
    csv_dir='directory/to/save/the/csv/file',
    model_name='your_model_name')
```
Please note that you should modify `csv_dir` and `model_name` to fit your own setup.
### Inference your model with the modified config file

Then run inference with your model on the `cue-conflict` dataset, using the modified config file.
```shell
# For Slurm
GPUS_PER_NODE=1 GPUS=1 bash tools/benchmarks/classification/mim_slurm_test.sh $partition $config $checkpoint
```
```shell
# For PyTorch
GPUS=1 bash tools/benchmarks/classification/mim_dist_test.sh $config $checkpoint
```
After that, you should obtain a csv file named `cue-conflict_model-name_session-1.csv`, where `model-name` is the `model_name`
you set in the config. Besides this file, you should also download these [csv files](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict),
which record the human results used for comparison, to the `csv_dir`.
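
If you would rather script this download, here is a minimal sketch; it assumes the human files follow the same `<dataset>_<subject>_session-1.csv` naming convention that `ShapeBiasMetric` uses, so please check the repository listing if a file is missing:

```python
import os
import urllib.request

base_url = ('https://raw.githubusercontent.com/bethgelab/model-vs-human/'
            'master/raw-data/cue-conflict')
csv_dir = 'directory/to/save/the/csv/file'  # the csv_dir in your config
os.makedirs(csv_dir, exist_ok=True)
for i in range(1, 11):  # the ten human subjects
    name = f'cue-conflict_subject-{i:02d}_session-1.csv'
    urllib.request.urlretrieve(f'{base_url}/{name}',
                               os.path.join(csv_dir, name))
```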
### Plot shape bias

Then we can start to plot the shape bias:

```shell
python tools/analysis_tools/visualize_shape_bias.py --csv-dir $CSV_DIR --result-dir $CSV_DIR --colors $RGB --markers o --plotting-names $YOUR_MODEL_NAME --model-names $YOUR_MODEL_NAME
```
- `--csv-dir`: the directory where the csv files are saved, i.e. the `csv_dir` above
- `--colors`: RGB values, formatted as `R G B`, e.g. `100 100 100`; you can pass multiple RGB triplets if you want
  to plot the shape bias of several models
- `--plotting-names`: the legend name in the shape bias figure, which you can set to your model name. To plot
  several models, give `--plotting-names` multiple values
- `--model-names`: should be the same name(s) specified in your config, and can be multiple names if you want to plot the
  shape bias of several models
Please note that every three values for `--colors` correspond to one value for `--model-names`; for example, plotting two
models requires six color values. After all of the above steps, you are expected to obtain the following figure:
<div align="center">
<img src="https://user-images.githubusercontent.com/30762564/208357938-c744d3c3-7e08-468e-82b7-fc5f1804da59.png" width="400" />
</div>

View File

@ -1,2 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .functional import * # noqa: F401,F403
from .metrics import * # noqa: F401,F403

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .shape_bias_label import ShapeBiasMetric
__all__ = ['ShapeBiasMetric']

View File

@ -0,0 +1,171 @@
# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os
import os.path as osp
from typing import List, Sequence
import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmselfsup.registry import METRICS
@METRICS.register_module()
class ShapeBiasMetric(BaseMetric):
"""Evaluate the model on ``cue_conflict`` dataset.
This module will evaluate the model on an OOD dataset, cue_conflict, in
order to measure the shape bias of the model. In addition to compuate the
Top-1 accuracy, this module also generate a csv file to record the
detailed prediction results, such that this csv file can be used to
generate the shape bias curve.
Args:
csv_dir (str): The directory to save the csv file.
model_name (str): The name of the csv file. Please note that the
model name should be a unique identifier.
dataset_name (str): The name of the dataset. Defaults to
'cue_conflict'.
"""
# mapping several classes from ImageNet-1K to the same category
airplane_indices = [404]
bear_indices = [294, 295, 296, 297]
bicycle_indices = [444, 671]
bird_indices = [
8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83,
87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129,
130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
145
]
boat_indices = [472, 554, 625, 814, 914]
bottle_indices = [440, 720, 737, 898, 899, 901, 907]
car_indices = [436, 511, 817]
cat_indices = [281, 282, 283, 284, 285, 286]
chair_indices = [423, 559, 765, 857]
clock_indices = [409, 530, 892]
dog_indices = [
152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194,
195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209,
210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254,
255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268
]
elephant_indices = [385, 386]
keyboard_indices = [508, 878]
knife_indices = [499]
oven_indices = [766]
truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]
def __init__(self,
csv_dir: str,
model_name: str,
dataset_name: str = 'cue_conflict',
**kwargs) -> None:
super().__init__(**kwargs)
self.categories = sorted([
'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock',
'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car',
'bird', 'dog'
])
self.csv_dir = csv_dir
self.model_name = model_name
self.dataset_name = dataset_name
self.csv_path = self.create_csv()
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
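# The category aggregation below requires per-class prediction scores;
# this metric assumes a classification head over the 1000 ImageNet-1K
# classes, which are pooled into the 16 categories defined above.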
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
result['pred_label'] = pred_label['label'].cpu()
result['gt_label'] = gt_label['label'].cpu()
result['gt_category'] = data_sample['img_path'].split('/')[-2]
result['img_name'] = data_sample['img_path'].split('/')[-1]
aggregated_category_probabilities = []
# get the prediction for each category of current instance
for category in self.categories:
category_indices = getattr(self, f'{category}_indices')
category_probabilities = torch.gather(
result['pred_score'], 0,
torch.tensor(category_indices)).mean()
aggregated_category_probabilities.append(
category_probabilities)
# sort the probabilities in descending order
pred_indices = torch.stack(aggregated_category_probabilities
).argsort(descending=True).numpy()
result['pred_category'] = np.take(self.categories, pred_indices)
# Save the result to `self.results`.
self.results.append(result)
def create_csv(self) -> str:
"""Create a csv file to store the results."""
session_name = 'session-1'
csv_path = osp.join(
self.csv_dir, self.dataset_name + '_' + self.model_name + '_' +
session_name + '.csv')
if osp.exists(csv_path):
os.remove(csv_path)
directory = osp.dirname(csv_path)
if not osp.exists(directory):
os.makedirs(directory)
with open(csv_path, 'w') as f:
writer = csv.writer(f)
writer.writerow([
'subj', 'session', 'trial', 'rt', 'object_response',
'category', 'condition', 'imagename'
])
return csv_path
def dump_results_to_csv(self, results: List[dict]) -> None:
"""Dump the results to a csv file.
Args:
results (List[dict]): A list of results.
"""
for i, result in enumerate(results):
img_name = result['img_name']
category = result['gt_category']
condition = 'NaN'
with open(self.csv_path, 'a') as f:
writer = csv.writer(f)
writer.writerow([
self.model_name, 1, i + 1, 'NaN',
result['pred_category'][0], category, condition, img_name
])
def compute_metrics(self, results: List[dict]) -> dict:
"""Compute the metrics from the results.
Args:
results (List[dict]): A list of results.
Returns:
dict: A dict of metrics.
"""
self.dump_results_to_csv(results)
metrics = dict()
metrics['accuracy/top1'] = np.mean([
result['pred_category'][0] == result['gt_category']
for result in results
])
return metrics
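
Below is a minimal standalone sketch of the metric's life cycle (the `work_dirs` path and `demo` name are placeholders, and the random 1000-class score stands in for an ImageNet-1K head):

```python
import torch
from mmselfsup.evaluation import ShapeBiasMetric

metric = ShapeBiasMetric(csv_dir='work_dirs', model_name='demo')
sample = dict(
    pred_label=dict(score=torch.rand(1000), label=torch.tensor(404)),
    gt_label=dict(label=torch.tensor(0)),
    img_path='data/cue-conflict/airplane/airplane1-dog2.png')
metric.process(None, [sample])
print(metric.compute_metrics(metric.results))
# e.g. {'accuracy/top1': 1.0} when the pooled airplane score is highest
```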

View File

@ -1 +1,2 @@
faiss-gpu==1.7.2
pandas

View File

@ -0,0 +1 @@
subj,session,trial,rt,object_response,category,condition,imagename

View File

@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmselfsup.evaluation import ShapeBiasMetric
def test_shape_bias_metric():
data_sample = dict()
data_sample['pred_label'] = dict(
score=torch.rand(1000, ), label=torch.tensor(1))
data_sample['gt_label'] = dict(label=torch.tensor(1))
data_sample['img_path'] = 'tests/airplane/test.JPEG'
evaluator = ShapeBiasMetric(
csv_dir='tests/data', dataset_name='test', model_name='test')
evaluator.process(None, [data_sample])

View File

@ -0,0 +1,277 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
from typing import Any, Dict, List, Optional, Tuple
import matplotlib as mpl
import pandas as pd
from matplotlib import _api
from matplotlib import transforms as mtransforms
class _DummyAxis:
"""Define the minimal interface for a dummy axis.
Args:
minpos (float): The minimum positive value for the axis. Defaults to 0.
"""
__name__ = 'dummy'
# Once the deprecation elapses, replace dataLim and viewLim by plain
# _view_interval and _data_interval private tuples.
dataLim = _api.deprecate_privatize_attribute(
'3.6', alternative='get_data_interval() and set_data_interval()')
viewLim = _api.deprecate_privatize_attribute(
'3.6', alternative='get_view_interval() and set_view_interval()')
def __init__(self, minpos: float = 0) -> None:
self._dataLim = mtransforms.Bbox.unit()
self._viewLim = mtransforms.Bbox.unit()
self._minpos = minpos
def get_view_interval(self) -> Tuple[float, float]:
"""Return the view interval as a tuple (*vmin*, *vmax*)."""
return self._viewLim.intervalx
def set_view_interval(self, vmin: float, vmax: float) -> None:
"""Set the view interval to (*vmin*, *vmax*)."""
self._viewLim.intervalx = vmin, vmax
def get_minpos(self) -> float:
"""Return the minimum positive value for the axis."""
return self._minpos
def get_data_interval(self) -> Tuple[float, float]:
"""Return the data interval as a tuple (*vmin*, *vmax*)."""
return self._dataLim.intervalx
def set_data_interval(self, vmin: float, vmax: float) -> None:
"""Set the data interval to (*vmin*, *vmax*)."""
self._dataLim.intervalx = vmin, vmax
def get_tick_space(self) -> int:
"""Return the number of ticks to use."""
# Just use the long-standing default of nbins==9
return 9
class TickHelper:
"""A helper class for ticks and tick labels."""
axis = None
def set_axis(self, axis: Any) -> None:
"""Set the axis instance."""
self.axis = axis
def create_dummy_axis(self, **kwargs) -> None:
"""Create a dummy axis if no axis is set."""
if self.axis is None:
self.axis = _DummyAxis(**kwargs)
@_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
def set_view_interval(self, vmin: float, vmax: float) -> None:
"""Set the view interval to (*vmin*, *vmax*)."""
self.axis.set_view_interval(vmin, vmax)
@_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
def set_data_interval(self, vmin: float, vmax: float) -> None:
"""Set the data interval to (*vmin*, *vmax*)."""
self.axis.set_data_interval(vmin, vmax)
@_api.deprecated(
'3.5',
alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
def set_bounds(self, vmin: float, vmax: float) -> None:
"""Set the view and data interval to (*vmin*, *vmax*)."""
self.set_view_interval(vmin, vmax)
self.set_data_interval(vmin, vmax)
class Formatter(TickHelper):
"""Create a string based on a tick value and location."""
# some classes want to see all the locs to help format
# individual ones
locs = []
def __call__(self, x: float, pos: Optional[Any] = None) -> str:
"""Return the format for tick value *x* at position pos.
``pos=None`` indicates an unspecified location.
This method must be overridden in the derived class.
Args:
x (float): The tick value.
pos (Optional[Any]): The tick position. Defaults to None.
"""
raise NotImplementedError('Derived must override')
def format_ticks(self, values: pd.Series) -> List[str]:
"""Return the tick labels for all the ticks at once.
Args:
values (pd.Series): The tick values.
Returns:
List[str]: The tick labels.
"""
self.set_locs(values)
return [self(value, i) for i, value in enumerate(values)]
def format_data(self, value: Any) -> str:
"""Return the full string representation of the value with the position
unspecified.
Args:
value (Any): The tick value.
Returns:
str: The full string representation of the value.
"""
return self.__call__(value)
def format_data_short(self, value: Any) -> str:
"""Return a short string version of the tick value.
Defaults to the position-independent long value.
Args:
value (Any): The tick value.
Returns:
str: The short string representation of the value.
"""
return self.format_data(value)
def get_offset(self) -> str:
"""Return the offset string."""
return ''
def set_locs(self, locs: List[Any]) -> None:
"""Set the locations of the ticks.
This method is called before computing the tick labels because some
formatters need to know all tick locations to do so.
"""
self.locs = locs
@staticmethod
def fix_minus(s: str) -> str:
"""Some classes may want to replace a hyphen for minus with the proper
Unicode symbol (U+2212) for typographical correctness.
This is a
helper method to perform such a replacement when it is enabled via
:rc:`axes.unicode_minus`.
Args:
s (str): The string to replace the hyphen with the Unicode symbol.
"""
return (s.replace('-', '\N{MINUS SIGN}')
if mpl.rcParams['axes.unicode_minus'] else s)
def _set_locator(self, locator: Any) -> None:
"""Subclasses may want to override this to set a locator."""
pass
class FormatStrFormatter(Formatter):
"""Use an old-style ('%' operator) format string to format the tick.
The format string should have a single variable format (%) in it.
It will be applied to the value (not the position) of the tick.
Negative numeric values will use a dash, not a Unicode minus; use mathtext
to get a Unicode minus by wrapping the format specifier with $ (e.g.
"$%g$").
Args:
fmt (str): Format string.
"""
def __init__(self, fmt: str) -> None:
self.fmt = fmt
def __call__(self, x: float, pos: Optional[Any] = None) -> str:
"""Return the formatted label string.
Only the value *x* is formatted. The position is ignored.
Args:
x (float): The value to format.
pos (Any): The position of the tick. Ignored.
"""
return self.fmt % x
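# Illustrative usage (not part of the committed file):
#   >>> FormatStrFormatter('%g')(0.5, None)
#   '0.5'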
class ShapeBias:
"""Compute the shape bias of a model.
Reference: `ImageNet-trained CNNs are biased towards texture;
increasing shape bias improves accuracy and robustness
<https://arxiv.org/abs/1811.12231>`_.
"""
num_input_models = 1
def __init__(self) -> None:
super().__init__()
self.plotting_name = 'shape-bias'
@staticmethod
def _check_dataframe(df: pd.DataFrame) -> None:
"""Check that the dataframe is valid."""
assert len(df) > 0, 'empty dataframe'
def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
"""Compute the shape bias of a model.
Args:
df (pd.DataFrame): The dataframe containing the data.
Returns:
Dict[str, float]: The shape bias.
"""
self._check_dataframe(df)
df = df.copy()
df['correct_texture'] = df['imagename'].apply(
self.get_texture_category)
df['correct_shape'] = df['category']
# remove those rows where shape = texture, i.e. no cue conflict present
df2 = df.loc[df.correct_shape != df.correct_texture]
fraction_correct_shape = len(
df2.loc[df2.object_response == df2.correct_shape]) / len(df)
fraction_correct_texture = len(
df2.loc[df2.object_response == df2.correct_texture]) / len(df)
shape_bias = fraction_correct_shape / (
fraction_correct_shape + fraction_correct_texture)
result_dict = {
'fraction-correct-shape': fraction_correct_shape,
'fraction-correct-texture': fraction_correct_texture,
'shape-bias': shape_bias
}
return result_dict
def get_texture_category(self, imagename: str) -> str:
"""Return texture category from imagename.
e.g. 'XXX_dog10-bird2.png' -> 'bird'.
Args:
imagename (str): Name of the image.
Returns:
str: Texture category.
"""
assert type(imagename) is str
# remove unnecessary words
a = imagename.split('_')[-1]
# remove .png etc.
b = a.split('.')[0]
# get texture category (last word)
c = b.split('-')[-1]
# remove number, e.g. 'bird2' -> 'bird'
d = ''.join([i for i in c if not i.isdigit()])
return d
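
As a quick sanity check of `ShapeBias.analysis`, the hand-made dataframe below (illustrative, not part of the committed file) contains one shape decision and two texture decisions, so the expected `shape-bias` is 1/3:

```python
import pandas as pd

from utils import ShapeBias  # the module above; the import path is an assumption

df = pd.DataFrame({
    'imagename': ['airplane1-dog2.png', 'bear3-cat1.png', 'car2-truck4.png'],
    'category': ['airplane', 'bear', 'car'],
    'object_response': ['airplane', 'cat', 'truck'],
})
print(ShapeBias().analysis(df))
# {'fraction-correct-shape': 0.333, 'fraction-correct-texture': 0.667,
#  'shape-bias': 0.333} (approximately)
```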

View File

@ -0,0 +1,264 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
import argparse
import os
import os.path as osp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from utils import FormatStrFormatter, ShapeBias
# global default boundary settings for thin gray transparent
# boundaries to avoid not being able to see the difference
# between two partially overlapping datapoints of the same color:
PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)
PLOTTING_EDGE_WIDTH = 0.02
ICONS_DIR = osp.join(
osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons')
parser = argparse.ArgumentParser()
parser.add_argument('--csv-dir', type=str, help='directory of csv files')
parser.add_argument(
'--result-dir', type=str, help='directory to save plotting results')
parser.add_argument('--model-names', nargs='+', default=[], help='model name')
parser.add_argument(
'--colors',
nargs='+',
type=float,
default=[],
help= # noqa
'the colors for the plots of each model, and they should be in the same order as model_names' # noqa: E501
)
parser.add_argument(
'--markers',
nargs='+',
type=str,
default=[],
help= # noqa
'the markers for the plots of each model, and they should be in the same order as model_names' # noqa: E501
)
parser.add_argument(
'--plotting-names',
nargs='+',
default=[],
help= # noqa
'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501
)
humans = [
'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
]
def read_csvs(csv_dir: str) -> pd.DataFrame:
"""Reads all csv files in a directory and returns a single dataframe.
Args:
csv_dir (str): directory of csv files.
Returns:
pd.DataFrame: dataframe containing all csv files
"""
df = pd.DataFrame()
for csv in os.listdir(csv_dir):
if csv.endswith('.csv'):
cur_df = pd.read_csv(osp.join(csv_dir, csv))
cur_df.columns = [c.lower() for c in cur_df.columns]
# ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat``
# is the forward-compatible equivalent.
df = pd.concat([df, cur_df], ignore_index=True)
df.condition = df.condition.astype(str)
return df
def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
"""Plots a matrixplot of shape bias.
Args:
args (argparse.Namespace): arguments.
analysis (ShapeBias): shape bias analysis. Defaults to ShapeBias().
"""
mpl.rcParams['font.family'] = ['serif']
mpl.rcParams['font.serif'] = ['Times New Roman']
plt.figure(figsize=(9, 7))
df = read_csvs(args.csv_dir)
fontsize = 15
ticklength = 10
markersize = 250
label_size = 20
classes = df['category'].unique()
num_classes = len(classes)
# plot setup
fig = plt.figure(1, figsize=(12, 12), dpi=300.)
ax = plt.gca()
ax.set_xlim([0, 1])
ax.set_ylim([-.5, num_classes - 0.5])
# secondary reversed x axis
ax_top = ax.secondary_xaxis(
'top', functions=(lambda x: 1 - x, lambda x: 1 - x))
# labels, ticks
plt.tick_params(
axis='y', which='both', left=False, right=False, labelleft=False)
ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size)
ax.set_xlabel(
"Fraction of 'texture' decisions", fontsize=label_size, labelpad=25)
ax_top.set_xlabel(
"Fraction of 'shape' decisions", fontsize=label_size, labelpad=25)
ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g'))
ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1))
ax_top.set_ticks(np.arange(0, 1.1, 0.1))
ax.tick_params(
axis='both', which='major', labelsize=fontsize, length=ticklength)
ax_top.tick_params(
axis='both', which='major', labelsize=fontsize, length=ticklength)
# arrows on x axes
plt.arrow(
x=0,
y=-1.75,
dx=1,
dy=0,
fc='black',
head_width=0.4,
head_length=0.03,
clip_on=False,
length_includes_head=True,
overhang=0.5)
plt.arrow(
x=1,
y=num_classes + 0.75,
dx=-1,
dy=0,
fc='black',
head_width=0.4,
head_length=0.03,
clip_on=False,
length_includes_head=True,
overhang=0.5)
# icons besides y axis
# determine order of icons
df_selection = df.loc[(df['subj'].isin(humans))]
class_avgs = []
for cl in classes:
df_class_selection = df_selection.query("category == '{}'".format(cl))
class_avgs.append(1 - analysis.analysis(
df=df_class_selection)['shape-bias'])
sorted_indices = np.argsort(class_avgs)
classes = classes[sorted_indices]
# icon placement is calculated in axis coordinates
WIDTH = 1 / num_classes
# placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH)
XPOS = -1.25 * WIDTH
YPOS = -0.5
HEIGHT = 1
MARGINX = 1 / 10 * WIDTH # vertical whitespace between icons
MARGINY = 1 / 10 * HEIGHT # horizontal whitespace between icons
left = XPOS + MARGINX
right = XPOS + WIDTH - MARGINX
for i in range(num_classes):
bottom = i + MARGINY + YPOS
top = (i + 1) - MARGINY + YPOS
iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i]))
plt.imshow(
plt.imread(iconpath),
extent=[left, right, bottom, top],
aspect='auto',
clip_on=False)
# plot horizontal intersection lines
for i in range(num_classes - 1):
plt.plot([0, 1], [i + .5, i + .5],
c='gray',
linestyle='dotted',
alpha=0.4)
# plot average shapebias + scatter points
for i in range(len(args.model_names)):
df_selection = df.loc[(df['subj'].isin(args.model_names[i]))]
result_df = analysis.analysis(df=df_selection)
avg = 1 - result_df['shape-bias']
ax.plot([avg, avg], [-1, num_classes], color=args.colors[i])
class_avgs = []
for cl in classes:
df_class_selection = df_selection.query(
"category == '{}'".format(cl))
class_avgs.append(1 - analysis.analysis(
df=df_class_selection)['shape-bias'])
ax.scatter(
class_avgs,
classes,
color=args.colors[i],
marker=args.markers[i],
label=args.plotting_names[i],
s=markersize,
clip_on=False,
edgecolors=PLOTTING_EDGE_COLOR,
linewidths=PLOTTING_EDGE_WIDTH,
zorder=3)
plt.legend(frameon=True, labelspacing=1, loc=9)
figure_path = osp.join(args.result_dir,
'cue-conflict_shape-bias_matrixplot.pdf')
fig.savefig(figure_path, bbox_inches='tight')
plt.close()
if __name__ == '__main__':
icon_names = [
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png',
'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
'knife.png', 'response_icons_vertical.png', 'truck.png'
]
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
os.makedirs(ICONS_DIR, exist_ok=True)
for icon_name in icon_names:
# osp.join would mangle URLs on Windows; join with a plain '/' instead
url = '{}/{}'.format(root_url, icon_name)
os.system('wget -O {} {}'.format(osp.join(ICONS_DIR, icon_name), url))
args = parser.parse_args()
assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
must be 3 times the number of models. Every three colors are the RGB \
values for one model.'
# preprocess colors
args.colors = [c / 255. for c in args.colors]
colors = []
for i in range(len(args.model_names)):
colors.append(args.colors[3 * i:3 * i + 3])
args.colors = colors
args.colors.append([165 / 255., 30 / 255., 55 / 255.]) # human color
# if plotting names are not specified, use model names
if len(args.plotting_names) == 0:
args.plotting_names = args.model_names
# preprocess markers
args.markers.append('D') # human marker
# preprocess model names
args.model_names = [[m] for m in args.model_names]
args.model_names.append(humans)
# preprocess plotting names
args.plotting_names.append('Humans')
plot_shape_bias_matrixplot(args)
os.system('rm -rf {}'.format(ICONS_DIR))