285 lines
9.1 KiB
Python
285 lines
9.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Modified from https://github.com/bethgelab/model-vs-human
|
|
import argparse
|
|
import os
|
|
import os.path as osp
|
|
|
|
import matplotlib as mpl
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
from mmengine.logging import MMLogger
|
|
from utils import FormatStrFormatter, ShapeBias
|
|
|
|
# Default marker outline: a thin, semi-transparent gray edge (RGBA) so that
# two partially overlapping data points of the same color can still be told
# apart.
PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)
PLOTTING_EDGE_WIDTH = 0.02
# Icon directory, located two levels above this file's directory under
# ``resources/shape_bias_icons``.
ICONS_DIR = osp.join(
    osp.dirname(__file__), os.pardir, os.pardir, 'resources',
    'shape_bias_icons')
|
|
|
|
# Command-line interface of the script. ``--csv-dir`` and ``--result-dir``
# are required: without them the script would otherwise fail later with an
# unrelated error (``None`` passed to ``os.listdir`` / ``osp.join``).
parser = argparse.ArgumentParser()
parser.add_argument(
    '--csv-dir', type=str, required=True, help='directory of csv files')
parser.add_argument(
    '--result-dir',
    type=str,
    required=True,
    help='directory to save plotting results')
parser.add_argument('--model-names', nargs='+', default=[], help='model name')
# Colors are given as a flat list of floats; the script entry point groups
# them three at a time (R, G, B) and divides each value by 255.
parser.add_argument(
    '--colors',
    nargs='+',
    type=float,
    default=[],
    help=  # noqa
    'the colors for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
)
parser.add_argument(
    '--markers',
    nargs='+',
    type=str,
    default=[],
    help=  # noqa
    'the markers for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
)
parser.add_argument(
    '--plotting-names',
    nargs='+',
    default=[],
    help=  # noqa
    'the plotting names for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
)
parser.add_argument(
    '--delete-icons',
    action='store_true',
    help='whether to delete the icons after plotting')
|
|
|
|
# Subject identifiers ('subject-01' ... 'subject-10') used to select the
# human rows of the csv data via the ``subj`` column.
humans = ['subject-{:02d}'.format(idx) for idx in range(1, 11)]
|
|
|
|
# File names expected inside ICONS_DIR; the script entry point downloads
# them, in this order, when any of them is missing.
icon_names = [
    'airplane.png',
    'response_icons_vertical_reverse.png',
    'bottle.png',
    'car.png',
    'oven.png',
    'elephant.png',
    'dog.png',
    'boat.png',
    'clock.png',
    'chair.png',
    'keyboard.png',
    'bird.png',
    'bicycle.png',
    'response_icons_horizontal.png',
    'cat.png',
    'bear.png',
    'colorbar.pdf',
    'knife.png',
    'response_icons_vertical.png',
    'truck.png',
]
|
|
|
|
|
|
def read_csvs(csv_dir: str) -> pd.DataFrame:
    """Read all csv files in a directory and return a single dataframe.

    Column names are lower-cased per file before concatenation, and the
    ``condition`` column of the result is converted to ``str``. Files are
    read in sorted name order so the row order is deterministic.

    Args:
        csv_dir (str): directory of csv files.

    Returns:
        pd.DataFrame: dataframe containing the rows of all csv files. The
            original per-file row indices are kept.

    Raises:
        FileNotFoundError: if ``csv_dir`` contains no ``.csv`` file.
        KeyError: if the combined data has no ``condition`` column.
    """
    frames = []
    for fname in sorted(os.listdir(csv_dir)):
        if not fname.endswith('.csv'):
            continue
        cur_df = pd.read_csv(osp.join(csv_dir, fname))
        cur_df.columns = [c.lower() for c in cur_df.columns]
        frames.append(cur_df)

    if not frames:
        raise FileNotFoundError(f'No csv files found in {csv_dir}')

    # ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat`` is the
    # supported replacement and also avoids re-copying on every file.
    df = pd.concat(frames)
    df['condition'] = df['condition'].astype(str)
    return df
|
|
|
|
|
|
def _texture_fractions(analysis, df, classes):
    """Return ``1 - shape-bias`` for the rows of ``df`` in each class."""
    return [
        1 - analysis.analysis(
            df=df.query("category == '{}'".format(cl)))['shape-bias']
        for cl in classes
    ]


def plot_shape_bias_matrixplot(args, analysis=None) -> None:
    """Plots a matrixplot of shape bias.

    Reads every csv file in ``args.csv_dir``, orders the shape categories by
    the human subjects' fraction of 'texture' decisions, and draws, for each
    entry of ``args.model_names``, a vertical line at its overall value plus
    one scatter point per category. The figure is written to
    ``<args.result_dir>/cue-conflict_shape-bias_matrixplot.pdf``.

    Args:
        args (argparse.Namespace): arguments. The attributes read are
            ``csv_dir``, ``result_dir``, ``model_names`` (each entry is a
            collection of ``subj`` values), ``colors``, ``markers`` and
            ``plotting_names``; the last three are indexed in parallel with
            ``model_names``.
        analysis (ShapeBias, optional): shape bias analysis. Its
            ``analysis(df=...)`` result is indexed with ``'shape-bias'``.
            Defaults to None, in which case a new ``ShapeBias()`` is
            created for this call.
    """
    # Build the default per call instead of once at definition time.
    if analysis is None:
        analysis = ShapeBias()

    mpl.rcParams['font.family'] = ['serif']
    mpl.rcParams['font.serif'] = ['Times New Roman']

    # NOTE(review): if this is the first figure created, it gets number 1,
    # and the ``plt.figure(1, ...)`` call below presumably reuses it and
    # ignores its own figsize/dpi -- confirm which size is intended.
    plt.figure(figsize=(9, 7))
    df = read_csvs(args.csv_dir)

    fontsize = 15
    ticklength = 10
    markersize = 250
    label_size = 20

    classes = df['category'].unique()
    num_classes = len(classes)

    # plot setup
    fig = plt.figure(1, figsize=(12, 12), dpi=300.)
    ax = plt.gca()

    ax.set_xlim([0, 1])
    ax.set_ylim([-.5, num_classes - 0.5])

    # secondary reversed x axis
    ax_top = ax.secondary_xaxis(
        'top', functions=(lambda x: 1 - x, lambda x: 1 - x))

    # labels, ticks
    plt.tick_params(
        axis='y', which='both', left=False, right=False, labelleft=False)
    ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size)
    ax.set_xlabel(
        "Fraction of 'texture' decisions", fontsize=label_size, labelpad=25)
    ax_top.set_xlabel(
        "Fraction of 'shape' decisions", fontsize=label_size, labelpad=25)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1))
    ax_top.set_ticks(np.arange(0, 1.1, 0.1))
    ax.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)
    ax_top.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)

    # arrows on x axes: left-to-right below the plot, right-to-left above it
    arrow_style = dict(
        dy=0,
        fc='black',
        head_width=0.4,
        head_length=0.03,
        clip_on=False,
        length_includes_head=True,
        overhang=0.5)
    plt.arrow(x=0, y=-1.75, dx=1, **arrow_style)
    plt.arrow(x=1, y=num_classes + 0.75, dx=-1, **arrow_style)

    # icons besides y axis
    # determine order of icons: ascending by the human subjects' per-class
    # fraction of 'texture' decisions
    df_humans = df.loc[(df['subj'].isin(humans))]
    human_avgs = _texture_fractions(analysis, df_humans, classes)
    classes = classes[np.argsort(human_avgs)]

    # icon placement is calculated in axis coordinates
    WIDTH = 1 / num_classes
    # placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH)
    XPOS = -1.25 * WIDTH
    YPOS = -0.5
    HEIGHT = 1
    MARGINX = 1 / 10 * WIDTH  # horizontal whitespace around each icon
    MARGINY = 1 / 10 * HEIGHT  # vertical whitespace between icons

    left = XPOS + MARGINX
    right = XPOS + WIDTH - MARGINX

    for i, cl in enumerate(classes):
        bottom = i + MARGINY + YPOS
        top = (i + 1) - MARGINY + YPOS
        iconpath = osp.join(ICONS_DIR, '{}.png'.format(cl))
        plt.imshow(
            plt.imread(iconpath),
            extent=[left, right, bottom, top],
            aspect='auto',
            clip_on=False)

    # plot horizontal intersection lines
    for i in range(num_classes - 1):
        plt.plot([0, 1], [i + .5, i + .5],
                 c='gray',
                 linestyle='dotted',
                 alpha=0.4)

    # plot average shapebias + scatter points
    for i, subjects in enumerate(args.model_names):
        df_selection = df.loc[(df['subj'].isin(subjects))]
        result_df = analysis.analysis(df=df_selection)
        avg = 1 - result_df['shape-bias']
        ax.plot([avg, avg], [-1, num_classes], color=args.colors[i])
        class_avgs = _texture_fractions(analysis, df_selection, classes)

        ax.scatter(
            class_avgs,
            classes,
            color=args.colors[i],
            marker=args.markers[i],
            label=args.plotting_names[i],
            s=markersize,
            clip_on=False,
            edgecolors=PLOTTING_EDGE_COLOR,
            linewidths=PLOTTING_EDGE_WIDTH,
            zorder=3)

    # Build the legend once, after all series have been added.
    if args.model_names:
        plt.legend(frameon=True, labelspacing=1, loc=9)

    # Make sure the output directory exists before saving.
    os.makedirs(args.result_dir, exist_ok=True)
    figure_path = osp.join(args.result_dir,
                           'cue-conflict_shape-bias_matrixplot.pdf')
    fig.savefig(figure_path, bbox_inches='tight')
    plt.close(fig)
|
|
|
|
|
|
def check_icons() -> bool:
    """Check whether the icon directory and all expected icons exist.

    This function only checks; it does not download anything.

    Returns:
        bool: True if ``ICONS_DIR`` exists and contains every file listed
            in ``icon_names``, False otherwise.
    """
    if osp.exists(ICONS_DIR):
        return all(
            osp.exists(osp.join(ICONS_DIR, name)) for name in icon_names)
    return False
|
|
|
|
|
|
if __name__ == '__main__':
    # Only the script entry point needs these modules.
    import shutil
    import urllib.request

    # Parse and validate first, so that ``--help`` or a bad command line
    # does not trigger any download.
    args = parser.parse_args()
    if args.csv_dir is None or args.result_dir is None:
        parser.error('--csv-dir and --result-dir are required.')

    num_models = len(args.model_names)
    if len(args.colors) != 3 * num_models:
        parser.error('Number of colors must be 3 times the number of models. '
                     'Every three colors are the RGB values for one model.')
    if len(args.markers) != num_models:
        parser.error('Number of markers must equal the number of models.')
    if args.plotting_names and len(args.plotting_names) != num_models:
        parser.error(
            'Number of plotting names must equal the number of models.')

    if not check_icons():
        root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons'  # noqa: E501
        os.makedirs(ICONS_DIR, exist_ok=True)
        logger = MMLogger.get_current_instance()
        logger.info(f'Downloading icons to {ICONS_DIR}')
        for icon_name in icon_names:
            # URLs always use '/', independent of the OS path separator.
            url = f'{root_url}/{icon_name}'
            try:
                urllib.request.urlretrieve(url, osp.join(ICONS_DIR, icon_name))
            except OSError as err:
                # Best effort: report the failure and keep going.
                logger.warning(f'Failed to download {url}: {err}')

    # preprocess colors: scale to [0, 1] and group into one RGB triple per
    # model, then add the human color
    scaled = [c / 255. for c in args.colors]
    args.colors = [scaled[3 * i:3 * i + 3] for i in range(num_models)]
    args.colors.append([165 / 255., 30 / 255., 55 / 255.])  # human color

    # if plotting names are not specified, use model names (copied, so the
    # two lists never share the same object)
    args.plotting_names = list(args.plotting_names or args.model_names)
    args.plotting_names.append('Humans')

    # preprocess markers
    args.markers = list(args.markers)
    args.markers.append('D')  # human marker

    # preprocess model names: one subject list per plotted series, with the
    # human subjects as the last series
    args.model_names = [[m] for m in args.model_names]
    args.model_names.append(humans)

    plot_shape_bias_matrixplot(args)
    if args.delete_icons:
        shutil.rmtree(ICONS_DIR, ignore_errors=True)
|