# Copyright (c) OpenMMLab. All rights reserved. # Modified from https://github.com/bethgelab/model-vs-human import argparse import os import os.path as osp import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pandas as pd from utils import FormatStrFormatter, ShapeBias # global default boundary settings for thin gray transparent # boundaries to avoid not being able to see the difference # between two partially overlapping datapoints of the same color: PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3) PLOTTING_EDGE_WIDTH = 0.02 ICONS_DIR = osp.join( osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons') parser = argparse.ArgumentParser() parser.add_argument('--csv-dir', type=str, help='directory of csv files') parser.add_argument( '--result-dir', type=str, help='directory to save plotting results') parser.add_argument('--model-names', nargs='+', default=[], help='model name') parser.add_argument( '--colors', nargs='+', type=float, default=[], help= # noqa 'the colors for the plots of each model, and they should be in the same order as model_names' # noqa: E501 ) parser.add_argument( '--markers', nargs='+', type=str, default=[], help= # noqa 'the markers for the plots of each model, and they should be in the same order as model_names' # noqa: E501 ) parser.add_argument( '--plotting-names', nargs='+', default=[], help= # noqa 'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501 ) humans = [ 'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05', 'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10' ] def read_csvs(csv_dir: str) -> pd.DataFrame: """Reads all csv files in a directory and returns a single dataframe. Args: csv_dir (str): directory of csv files. Returns: pd.DataFrame: dataframe containing all csv files """ df = pd.DataFrame() for csv in os.listdir(csv_dir): if csv.endswith('.csv'): cur_df = pd.read_csv(osp.join(csv_dir, csv)) cur_df.columns = [c.lower() for c in cur_df.columns] df = df.append(cur_df) df.condition = df.condition.astype(str) return df def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None: """Plots a matrixplot of shape bias. Args: args (argparse.Namespace): arguments. analysis (ShapeBias): shape bias analysis. Defaults to ShapeBias(). """ mpl.rcParams['font.family'] = ['serif'] mpl.rcParams['font.serif'] = ['Times New Roman'] plt.figure(figsize=(9, 7)) df = read_csvs(args.csv_dir) fontsize = 15 ticklength = 10 markersize = 250 label_size = 20 classes = df['category'].unique() num_classes = len(classes) # plot setup fig = plt.figure(1, figsize=(12, 12), dpi=300.) ax = plt.gca() ax.set_xlim([0, 1]) ax.set_ylim([-.5, num_classes - 0.5]) # secondary reversed x axis ax_top = ax.secondary_xaxis( 'top', functions=(lambda x: 1 - x, lambda x: 1 - x)) # labels, ticks plt.tick_params( axis='y', which='both', left=False, right=False, labelleft=False) ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size) ax.set_xlabel( "Fraction of 'texture' decisions", fontsize=label_size, labelpad=25) ax_top.set_xlabel( "Fraction of 'shape' decisions", fontsize=label_size, labelpad=25) ax.xaxis.set_major_formatter(FormatStrFormatter('%g')) ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g')) ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1)) ax_top.set_ticks(np.arange(0, 1.1, 0.1)) ax.tick_params( axis='both', which='major', labelsize=fontsize, length=ticklength) ax_top.tick_params( axis='both', which='major', labelsize=fontsize, length=ticklength) # arrows on x axes plt.arrow( x=0, y=-1.75, dx=1, dy=0, fc='black', head_width=0.4, head_length=0.03, clip_on=False, length_includes_head=True, overhang=0.5) plt.arrow( x=1, y=num_classes + 0.75, dx=-1, dy=0, fc='black', head_width=0.4, head_length=0.03, clip_on=False, length_includes_head=True, overhang=0.5) # icons besides y axis # determine order of icons df_selection = df.loc[(df['subj'].isin(humans))] class_avgs = [] for cl in classes: df_class_selection = df_selection.query("category == '{}'".format(cl)) class_avgs.append(1 - analysis.analysis( df=df_class_selection)['shape-bias']) sorted_indices = np.argsort(class_avgs) classes = classes[sorted_indices] # icon placement is calculated in axis coordinates WIDTH = 1 / num_classes # placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH) XPOS = -1.25 * WIDTH YPOS = -0.5 HEIGHT = 1 MARGINX = 1 / 10 * WIDTH # vertical whitespace between icons MARGINY = 1 / 10 * HEIGHT # horizontal whitespace between icons left = XPOS + MARGINX right = XPOS + WIDTH - MARGINX for i in range(num_classes): bottom = i + MARGINY + YPOS top = (i + 1) - MARGINY + YPOS iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i])) plt.imshow( plt.imread(iconpath), extent=[left, right, bottom, top], aspect='auto', clip_on=False) # plot horizontal intersection lines for i in range(num_classes - 1): plt.plot([0, 1], [i + .5, i + .5], c='gray', linestyle='dotted', alpha=0.4) # plot average shapebias + scatter points for i in range(len(args.model_names)): df_selection = df.loc[(df['subj'].isin(args.model_names[i]))] result_df = analysis.analysis(df=df_selection) avg = 1 - result_df['shape-bias'] ax.plot([avg, avg], [-1, num_classes], color=args.colors[i]) class_avgs = [] for cl in classes: df_class_selection = df_selection.query( "category == '{}'".format(cl)) class_avgs.append(1 - analysis.analysis( df=df_class_selection)['shape-bias']) ax.scatter( class_avgs, classes, color=args.colors[i], marker=args.markers[i], label=args.plotting_names[i], s=markersize, clip_on=False, edgecolors=PLOTTING_EDGE_COLOR, linewidths=PLOTTING_EDGE_WIDTH, zorder=3) plt.legend(frameon=True, labelspacing=1, loc=9) figure_path = osp.join(args.result_dir, 'cue-conflict_shape-bias_matrixplot.pdf') fig.savefig(figure_path, bbox_inches='tight') plt.close() if __name__ == '__main__': icon_names = [ 'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png', 'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png', 'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf', 'knife.png', 'response_icons_vertical.png', 'truck.png' ] root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501 os.makedirs(ICONS_DIR, exist_ok=True) for icon_name in icon_names: url = osp.join(root_url, icon_name) os.system('wget -O {} {}'.format(osp.join(ICONS_DIR, icon_name), url)) args = parser.parse_args() assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \ must be 3 times the number of models. Every three colors are the RGB \ values for one model.' # preprocess colors args.colors = [c / 255. for c in args.colors] colors = [] for i in range(len(args.model_names)): colors.append(args.colors[3 * i:3 * i + 3]) args.colors = colors args.colors.append([165 / 255., 30 / 255., 55 / 255.]) # human color # if plotting names are not specified, use model names if len(args.plotting_names) == 0: args.plotting_names = args.model_names # preprocess markers args.markers.append('D') # human marker # preprocess model names args.model_names = [[m] for m in args.model_names] args.model_names.append(humans) # preprocess plotting names args.plotting_names.append('Humans') plot_shape_bias_matrixplot(args) os.system('rm -rf {}'.format(ICONS_DIR))