mmselfsup/tools/analysis_tools/visualize_shape_bias.py
Yuan Liu 304e81650a
[Feature]: Shape bias (#635)
* [Feature]: Add shape bias vis

* [Fix]: Fix lint

* [Feature]: Add shape bias metrics

* [Fix]: Fix lint

* [Fix]: Delete redundant code

* [Feature]: Add shape bias doc

* [Fix]: Fix lint

* [Feature]: add UT

* [Fix]: Fix lint

* [Fix]: Fix typo

* [Fix]: Fix typo

* [Fix]: Fix args param style

* [Feature]: Download pic automatically
2022-12-30 22:11:18 +08:00

265 lines
8.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
import argparse
import os
import os.path as osp
import shutil
from urllib.request import urlretrieve

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from utils import FormatStrFormatter, ShapeBias
# Outline style applied to every scatter marker: a thin, semi-transparent gray
# edge, so that two partially overlapping points of the same color can still
# be told apart.
PLOTTING_EDGE_COLOR = (0.3, ) * 4
PLOTTING_EDGE_WIDTH = 0.02

# Location of the category icons, resolved relative to this script:
# <script dir>/../../resources/shape_bias_icons
ICONS_DIR = osp.join(
    osp.dirname(__file__), osp.pardir, osp.pardir, 'resources',
    'shape_bias_icons')
# Command-line interface of the script. The list options below are parallel:
# entry i of --colors (as an RGB triplet), --markers and --plotting-names
# belongs to entry i of --model-names.
parser = argparse.ArgumentParser(
    description='Plot the shape bias of models and humans on the '
    'cue-conflict data as a matrix plot.')
parser.add_argument('--csv-dir', type=str, help='directory of csv files')
parser.add_argument(
    '--result-dir', type=str, help='directory to save plotting results')
parser.add_argument(
    '--model-names',
    nargs='+',
    default=[],
    help='names of the models to plot, as they appear in the "subj" column '
    'of the csv files')
parser.add_argument(
    '--colors',
    nargs='+',
    type=float,
    default=[],
    help='RGB colors (values in 0-255) of the models: exactly three values '
    'per model, in the same order as --model-names')
parser.add_argument(
    '--markers',
    nargs='+',
    type=str,
    default=[],
    help='matplotlib marker of each model, exactly one per model, in the '
    'same order as --model-names')
parser.add_argument(
    '--plotting-names',
    nargs='+',
    default=[],
    help='legend name of each model, in the same order as --model-names; '
    'defaults to the model names')
# Identifiers of the ten human observers, matched against the "subj" column
# of the csv files: 'subject-01' ... 'subject-10'.
humans = ['subject-{:02d}'.format(idx) for idx in range(1, 11)]
def read_csvs(csv_dir: str) -> pd.DataFrame:
    """Reads all csv files in a directory and returns a single dataframe.

    Column names of every file are lower-cased before concatenation, and the
    ``condition`` column is cast to ``str``. Files are read in sorted
    filename order so the row order is reproducible.

    Args:
        csv_dir (str): directory of csv files.

    Returns:
        pd.DataFrame: dataframe containing the rows of all csv files (the
        per-file row indices are kept, as with ``DataFrame.append``).

    Raises:
        FileNotFoundError: if ``csv_dir`` contains no ``.csv`` file.
    """
    frames = []
    # os.listdir order is file-system dependent; sort for determinism.
    for csv in sorted(os.listdir(csv_dir)):
        if csv.endswith('.csv'):
            cur_df = pd.read_csv(osp.join(csv_dir, csv))
            cur_df.columns = [c.lower() for c in cur_df.columns]
            frames.append(cur_df)
    if not frames:
        # Without this check the failure surfaces later as an opaque
        # AttributeError on the missing 'condition' column.
        raise FileNotFoundError('No csv files found in {}'.format(csv_dir))
    # DataFrame.append was removed in pandas 2.0; one concat is also linear
    # instead of re-copying the accumulated frame for every file.
    df = pd.concat(frames)
    df['condition'] = df['condition'].astype(str)
    return df
def _texture_fractions(df: pd.DataFrame, classes, analysis) -> list:
    """Return ``1 - shape-bias`` (fraction of texture decisions) per class.

    Args:
        df (pd.DataFrame): trials of the subjects to evaluate; must have a
            ``category`` column.
        classes: category names, in the desired output order.
        analysis (ShapeBias): object whose ``analysis(df=...)`` result is
            indexed with ``'shape-bias'``.

    Returns:
        list: one value per entry of ``classes``, in the same order.
    """
    # A boolean mask (instead of DataFrame.query with a formatted string)
    # also works for category names that contain quotes.
    return [
        1 - analysis.analysis(df=df[df['category'] == cl])['shape-bias']
        for cl in classes
    ]


def _setup_axes(ax, num_classes: int) -> None:
    """Set limits, the reversed top x axis, labels, ticks and x-axis arrows.

    Args:
        ax (matplotlib.axes.Axes): axes to configure.
        num_classes (int): number of categories (rows) in the plot.
    """
    fontsize = 15
    ticklength = 10
    label_size = 20

    ax.set_xlim([0, 1])
    ax.set_ylim([-.5, num_classes - 0.5])
    # secondary x axis on top showing the complementary fraction 1 - x
    ax_top = ax.secondary_xaxis(
        'top', functions=(lambda x: 1 - x, lambda x: 1 - x))

    # the y axis shows icons instead of ticks/tick labels
    ax.tick_params(
        axis='y', which='both', left=False, right=False, labelleft=False)
    ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size)
    ax.set_xlabel(
        "Fraction of 'texture' decisions", fontsize=label_size, labelpad=25)
    ax_top.set_xlabel(
        "Fraction of 'shape' decisions", fontsize=label_size, labelpad=25)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1))
    ax_top.set_ticks(np.arange(0, 1.1, 0.1))
    ax.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)
    ax_top.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)

    # arrows below and above the plot, pointing in opposite directions
    arrow_style = dict(
        fc='black',
        head_width=0.4,
        head_length=0.03,
        clip_on=False,
        length_includes_head=True,
        overhang=0.5)
    ax.arrow(x=0, y=-1.75, dx=1, dy=0, **arrow_style)
    ax.arrow(x=1, y=num_classes + 0.75, dx=-1, dy=0, **arrow_style)


def _draw_class_icons(ax, classes) -> None:
    """Draw ``ICONS_DIR/<class>.png`` left of the y axis for each class.

    Icon ``i`` is vertically centred on ``y = i``, the row in which the
    scatter points of ``classes[i]`` are drawn.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on; assumes its x limits
            were already set to [0, 1].
        classes: category names, bottom row first.
    """
    num_classes = len(classes)
    # x positions are in data units, i.e. axes fractions for xlim [0, 1]
    width = 1 / num_classes
    # placement left of y axis (-width) plus some spacing (-.25 * width)
    xpos = -1.25 * width
    ypos = -0.5
    height = 1
    margin_x = 1 / 10 * width  # horizontal whitespace around an icon
    margin_y = 1 / 10 * height  # vertical whitespace between icons
    left = xpos + margin_x
    right = xpos + width - margin_x
    for i, cls in enumerate(classes):
        bottom = i + margin_y + ypos
        top = (i + 1) - margin_y + ypos
        iconpath = osp.join(ICONS_DIR, '{}.png'.format(cls))
        ax.imshow(
            plt.imread(iconpath),
            extent=[left, right, bottom, top],
            aspect='auto',
            clip_on=False)


def plot_shape_bias_matrixplot(args, analysis=None) -> None:
    """Plots a matrixplot of shape bias and saves it as a pdf.

    One row per category (sorted by the humans' fraction of texture
    decisions), one scatter series plus a vertical average line per entry
    of ``args.model_names``. The figure is written to
    ``<args.result_dir>/cue-conflict_shape-bias_matrixplot.pdf``.

    Args:
        args (argparse.Namespace): arguments. Reads ``csv_dir``,
            ``result_dir`` and the parallel lists ``model_names`` (each entry
            a list of "subj" values), ``colors``, ``markers`` and
            ``plotting_names``.
        analysis (ShapeBias, optional): shape bias analysis. Defaults to a
            new ``ShapeBias()`` instance.
    """
    if analysis is None:
        # built lazily rather than once at import time as a default argument
        analysis = ShapeBias()
    mpl.rcParams['font.family'] = ['serif']
    mpl.rcParams['font.serif'] = ['Times New Roman']
    markersize = 250

    df = read_csvs(args.csv_dir)
    classes = df['category'].unique()
    num_classes = len(classes)

    # Exactly one figure, kept by handle. Calling plt.figure(1, ...) after
    # plt.figure() only re-activates figure 1 (ignoring the new figsize) or
    # draws on an unrelated figure 1 if one already exists. The 9x7 size is
    # the one that was effectively rendered; dpi=300 is the requested one.
    fig = plt.figure(figsize=(9, 7), dpi=300.)
    try:
        ax = fig.gca()
        _setup_axes(ax, num_classes)

        # order the rows by the humans' fraction of texture decisions
        human_df = df.loc[df['subj'].isin(humans)]
        classes = classes[np.argsort(
            _texture_fractions(human_df, classes, analysis))]
        _draw_class_icons(ax, classes)

        # horizontal separators between rows
        for i in range(num_classes - 1):
            ax.plot([0, 1], [i + .5, i + .5],
                    c='gray',
                    linestyle='dotted',
                    alpha=0.4)

        # Explicit row positions (rather than the class strings, which would
        # rely on the categorical axis' first-seen ordering) so each point
        # lines up with its icon.
        y_positions = np.arange(num_classes)
        # average shape bias line + per-class scatter points for each series
        for i, subjects in enumerate(args.model_names):
            model_df = df.loc[df['subj'].isin(subjects)]
            avg = 1 - analysis.analysis(df=model_df)['shape-bias']
            ax.plot([avg, avg], [-1, num_classes], color=args.colors[i])
            ax.scatter(
                _texture_fractions(model_df, classes, analysis),
                y_positions,
                color=args.colors[i],
                marker=args.markers[i],
                label=args.plotting_names[i],
                s=markersize,
                clip_on=False,
                edgecolors=PLOTTING_EDGE_COLOR,
                linewidths=PLOTTING_EDGE_WIDTH,
                zorder=3)

        ax.legend(frameon=True, labelspacing=1, loc=9)
        figure_path = osp.join(args.result_dir,
                               'cue-conflict_shape-bias_matrixplot.pdf')
        fig.savefig(figure_path, bbox_inches='tight')
    finally:
        # close this figure (not whichever happens to be current), even on
        # error
        plt.close(fig)
def _download_icons(icons_dir: str) -> None:
    """Download the icons from bethgelab/model-vs-human into ``icons_dir``.

    Raises whatever ``urlretrieve`` raises if a download fails.
    """
    icon_names = [
        'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
        'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png',
        'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
        'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
        'knife.png', 'response_icons_vertical.png', 'truck.png'
    ]
    root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons'  # noqa: E501
    os.makedirs(icons_dir, exist_ok=True)
    for icon_name in icon_names:
        # URLs always use '/': osp.join would insert '\\' on Windows.
        url = '{}/{}'.format(root_url, icon_name)
        urlretrieve(url, osp.join(icons_dir, icon_name))


def main() -> None:
    """Parse and validate the command line, then draw the matrix plot.

    The models given on the command line are plotted together with an extra
    'Humans' series (all subjects in ``humans``, fixed color and marker 'D').
    The icons are downloaded into ``ICONS_DIR`` for the duration of the plot
    and the directory is removed afterwards, also when plotting fails.
    """
    # Parse first so that e.g. --help or a usage error does not download.
    args = parser.parse_args()
    num_models = len(args.model_names)
    if len(args.colors) != 3 * num_models:
        parser.error('Number of colors must be 3 times the number of models. '
                     'Every three colors are the RGB values for one model.')
    # Extra markers/names used to be silently assigned to the Humans series.
    if len(args.markers) != num_models:
        parser.error('Number of markers must equal the number of models.')
    if args.plotting_names and len(args.plotting_names) != num_models:
        parser.error(
            'Number of plotting names must equal the number of models.')
    if args.result_dir is None:
        parser.error('--result-dir is required.')

    # preprocess colors: every three 0-255 values form one RGB color in [0, 1]
    args.colors = [[c / 255. for c in args.colors[3 * i:3 * i + 3]]
                   for i in range(num_models)]
    args.colors.append([165 / 255., 30 / 255., 55 / 255.])  # human color
    # if plotting names are not specified, use model names
    args.plotting_names = list(args.plotting_names or args.model_names)
    args.plotting_names.append('Humans')
    args.markers = list(args.markers) + ['D']  # human marker
    # each series is the list of "subj" values it aggregates
    args.model_names = [[m] for m in args.model_names] + [humans]

    os.makedirs(args.result_dir, exist_ok=True)
    try:
        _download_icons(ICONS_DIR)
        plot_shape_bias_matrixplot(args)
    finally:
        # Unlike an unquoted `rm -rf` shell command, rmtree is safe for paths
        # containing spaces and works on every platform.
        shutil.rmtree(ICONS_DIR, ignore_errors=True)


if __name__ == '__main__':
    main()