# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
import argparse
import os
import os.path as osp
import shutil
import subprocess

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mmengine.logging import MMLogger
from utils import FormatStrFormatter, ShapeBias

# global default boundary settings for thin gray transparent
# boundaries to avoid not being able to see the difference
# between two partially overlapping datapoints of the same color:
PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)
PLOTTING_EDGE_WIDTH = 0.02
ICONS_DIR = osp.join(
    osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons')

parser = argparse.ArgumentParser()
# Both directories are required: the script cannot read data or save the
# figure without them.
parser.add_argument(
    '--csv-dir', type=str, required=True, help='directory of csv files')
parser.add_argument(
    '--result-dir',
    type=str,
    required=True,
    help='directory to save plotting results')
parser.add_argument('--model-names', nargs='+', default=[], help='model name')
parser.add_argument(
    '--colors',
    nargs='+',
    type=float,
    default=[],
    help='the colors for the plots of each model, and they should be in '
    'the same order as model_names')
parser.add_argument(
    '--markers',
    nargs='+',
    type=str,
    default=[],
    help='the markers for the plots of each model, and they should be in '
    'the same order as model_names')
parser.add_argument(
    '--plotting-names',
    nargs='+',
    default=[],
    help='the plotting names for the plots of each model, and they should '
    'be in the same order as model_names')
parser.add_argument(
    '--delete-icons',
    action='store_true',
    help='whether to delete the icons after plotting')

# Values of the ``subj`` column that are treated as human observers.
humans = [
    'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
    'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
]

# Files expected inside ``ICONS_DIR``; all of them are fetched when any one
# is missing.
icon_names = [
    'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
    'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png',
    'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
    'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
    'knife.png', 'response_icons_vertical.png', 'truck.png'
]


def read_csvs(csv_dir: str) -> pd.DataFrame:
    """Reads all csv files in a directory and returns a single dataframe.

    Column names are lower-cased and the ``condition`` column is cast to
    ``str``.

    Args:
        csv_dir (str): directory of csv files.

    Returns:
        pd.DataFrame: dataframe containing all csv files

    Raises:
        FileNotFoundError: If ``csv_dir`` contains no ``.csv`` file.
    """
    frames = []
    # Sorted so the row order does not depend on the filesystem's listing
    # order.
    for csv in sorted(os.listdir(csv_dir)):
        if not csv.endswith('.csv'):
            continue
        cur_df = pd.read_csv(osp.join(csv_dir, csv))
        cur_df.columns = [c.lower() for c in cur_df.columns]
        frames.append(cur_df)

    if not frames:
        raise FileNotFoundError(f'No csv files found in {csv_dir}')

    # ``DataFrame.append`` was removed in pandas 2.0; concatenate once.
    df = pd.concat(frames)
    df['condition'] = df['condition'].astype(str)
    return df


def _texture_fractions(df: pd.DataFrame, classes, analysis) -> list:
    """Return ``1 - shape-bias`` of ``df`` restricted to each class."""
    return [
        1 - analysis.analysis(df=df[df['category'] == cl])['shape-bias']
        for cl in classes
    ]


def plot_shape_bias_matrixplot(args: argparse.Namespace,
                               analysis=None) -> None:
    """Plots a matrixplot of shape bias.

    The figure is saved as ``cue-conflict_shape-bias_matrixplot.pdf`` inside
    ``args.result_dir``.

    Args:
        args (argparse.Namespace): arguments. Reads ``csv_dir``,
            ``result_dir``, ``model_names`` (a list of subject-name lists),
            ``colors``, ``markers`` and ``plotting_names``; the last four
            are indexed in parallel.
        analysis (ShapeBias, optional): shape bias analysis. If None, a new
            ``ShapeBias()`` is created. Defaults to None.
    """
    if analysis is None:
        analysis = ShapeBias()

    mpl.rcParams['font.family'] = ['serif']
    mpl.rcParams['font.serif'] = ['Times New Roman']

    plt.figure(figsize=(9, 7))
    df = read_csvs(args.csv_dir)

    fontsize = 15
    ticklength = 10
    markersize = 250
    label_size = 20

    classes = df['category'].unique()
    num_classes = len(classes)

    # plot setup
    # NOTE(review): if the figure created above is figure number 1, this call
    # re-activates it instead of creating a new one -- confirm which figsize
    # is intended.
    fig = plt.figure(1, figsize=(12, 12), dpi=300.)
    ax = plt.gca()

    ax.set_xlim([0, 1])
    ax.set_ylim([-.5, num_classes - 0.5])

    # secondary reversed x axis
    ax_top = ax.secondary_xaxis(
        'top', functions=(lambda x: 1 - x, lambda x: 1 - x))

    # labels, ticks
    plt.tick_params(
        axis='y', which='both', left=False, right=False, labelleft=False)
    ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size)
    ax.set_xlabel(
        "Fraction of 'texture' decisions", fontsize=label_size, labelpad=25)
    ax_top.set_xlabel(
        "Fraction of 'shape' decisions", fontsize=label_size, labelpad=25)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1))
    ax_top.set_ticks(np.arange(0, 1.1, 0.1))
    ax.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)
    ax_top.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)

    # arrows on x axes
    plt.arrow(
        x=0,
        y=-1.75,
        dx=1,
        dy=0,
        fc='black',
        head_width=0.4,
        head_length=0.03,
        clip_on=False,
        length_includes_head=True,
        overhang=0.5)
    plt.arrow(
        x=1,
        y=num_classes + 0.75,
        dx=-1,
        dy=0,
        fc='black',
        head_width=0.4,
        head_length=0.03,
        clip_on=False,
        length_includes_head=True,
        overhang=0.5)

    # icons besides y axis
    # determine order of icons: classes are sorted by the humans'
    # ``1 - shape-bias`` value, ascending
    df_selection = df.loc[(df['subj'].isin(humans))]
    class_avgs = _texture_fractions(df_selection, classes, analysis)
    sorted_indices = np.argsort(class_avgs)
    classes = classes[sorted_indices]

    # icon placement is calculated in axis coordinates
    WIDTH = 1 / num_classes
    # placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH)
    XPOS = -1.25 * WIDTH
    YPOS = -0.5
    HEIGHT = 1
    MARGINX = 1 / 10 * WIDTH  # padding applied to the left/right edges
    MARGINY = 1 / 10 * HEIGHT  # padding applied to the bottom/top edges

    left = XPOS + MARGINX
    right = XPOS + WIDTH - MARGINX

    for i in range(num_classes):
        bottom = i + MARGINY + YPOS
        top = (i + 1) - MARGINY + YPOS
        iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i]))
        plt.imshow(
            plt.imread(iconpath),
            extent=[left, right, bottom, top],
            aspect='auto',
            clip_on=False)

    # plot horizontal intersection lines
    for i in range(num_classes - 1):
        plt.plot([0, 1], [i + .5, i + .5],
                 c='gray',
                 linestyle='dotted',
                 alpha=0.4)

    # plot average shapebias + scatter points
    for i, subjects in enumerate(args.model_names):
        df_selection = df.loc[(df['subj'].isin(subjects))]
        result_df = analysis.analysis(df=df_selection)
        avg = 1 - result_df['shape-bias']
        # vertical line at the average over all classes
        ax.plot([avg, avg], [-1, num_classes], color=args.colors[i])
        class_avgs = _texture_fractions(df_selection, classes, analysis)

        ax.scatter(
            class_avgs,
            classes,
            color=args.colors[i],
            marker=args.markers[i],
            label=args.plotting_names[i],
            s=markersize,
            clip_on=False,
            edgecolors=PLOTTING_EDGE_COLOR,
            linewidths=PLOTTING_EDGE_WIDTH,
            zorder=3)
        plt.legend(frameon=True, labelspacing=1, loc=9)

    # make sure the output directory exists before saving
    os.makedirs(args.result_dir, exist_ok=True)
    figure_path = osp.join(args.result_dir,
                           'cue-conflict_shape-bias_matrixplot.pdf')
    fig.savefig(figure_path, bbox_inches='tight')
    plt.close()


def check_icons() -> bool:
    """Check whether every file in ``icon_names`` exists in ``ICONS_DIR``.

    This function only checks; it does not download anything.

    Returns:
        bool: True if the directory and all icon files exist.
    """
    if not osp.exists(ICONS_DIR):
        return False
    return all(
        osp.exists(osp.join(ICONS_DIR, icon_name))
        for icon_name in icon_names)


def _download_icons() -> None:
    """Fetch every file in ``icon_names`` into ``ICONS_DIR`` using wget."""
    root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons'  # noqa: E501
    os.makedirs(ICONS_DIR, exist_ok=True)
    logger = MMLogger.get_current_instance()
    logger.info(f'Downloading icons to {ICONS_DIR}')
    for icon_name in icon_names:
        # URLs always use '/', so do not build them with os.path.join.
        url = f'{root_url}/{icon_name}'
        dest = osp.join(ICONS_DIR, icon_name)
        try:
            # Argument list instead of a shell string, so paths containing
            # spaces or shell metacharacters are passed through verbatim.
            subprocess.run(['wget', '-O', dest, url], check=True)
        except (OSError, subprocess.CalledProcessError) as err:
            # Best effort, as before: keep going with the remaining icons.
            # ``wget -O`` creates the output file even when the download
            # fails; remove it so ``check_icons`` does not mistake an empty
            # file for a downloaded icon on the next run.
            logger.warning(f'Failed to download {url}: {err}')
            if osp.exists(dest):
                os.remove(dest)


def _prepare_args(args: argparse.Namespace) -> argparse.Namespace:
    """Validate the parsed arguments and append the human baseline entry."""
    num_models = len(args.model_names)
    if len(args.colors) != 3 * num_models:
        parser.error('Number of colors must be 3 times the number of models. '
                     'Every three colors are the RGB values for one model.')
    if len(args.markers) != num_models:
        parser.error('Number of markers must equal the number of models.')
    if args.plotting_names and len(args.plotting_names) != num_models:
        parser.error(
            'Number of plotting names must equal the number of models.')

    # preprocess colors: scale 0-255 values to 0-1 and group them into one
    # RGB triple per model
    scaled = [c / 255. for c in args.colors]
    args.colors = [scaled[3 * i:3 * i + 3] for i in range(num_models)]
    args.colors.append([165 / 255., 30 / 255., 55 / 255.])  # human color

    # if plotting names are not specified, use model names (copied, so the
    # two lists are never aliased)
    plotting_names = list(args.plotting_names or args.model_names)
    args.plotting_names = plotting_names + ['Humans']

    # preprocess markers
    args.markers = list(args.markers) + ['D']  # human marker

    # preprocess model names: one list of subject names per plotted entry
    args.model_names = [[m] for m in args.model_names]
    args.model_names.append(humans)
    return args


def main() -> None:
    """Parse arguments, fetch icons if needed and draw the matrixplot."""
    # Parse first so ``--help`` and invalid arguments never trigger a
    # download.
    args = _prepare_args(parser.parse_args())

    if not check_icons():
        _download_icons()

    plot_shape_bias_matrixplot(args)

    if args.delete_icons:
        shutil.rmtree(ICONS_DIR, ignore_errors=True)


if __name__ == '__main__':
    main()