[Fix]: Fix distributed setting for shape bias (#689)

* [Fix]: Fix dist bug

* [Fix]: Fix lint

* [Feature]: Add download log info

* [Fix]: Fix lint
This commit is contained in:
Yuan Liu 2023-02-08 14:23:59 +08:00 committed by GitHub
parent 0f634de85b
commit d3f57edf79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 18 deletions

View File

@ -6,6 +6,7 @@ from typing import List, Sequence
import numpy as np import numpy as np
import torch import torch
from mmengine.dist.utils import get_rank
from mmengine.evaluator import BaseMetric from mmengine.evaluator import BaseMetric
from mmselfsup.registry import METRICS from mmselfsup.registry import METRICS
@ -75,7 +76,8 @@ class ShapeBiasMetric(BaseMetric):
self.csv_dir = csv_dir self.csv_dir = csv_dir
self.model_name = model_name self.model_name = model_name
self.dataset_name = dataset_name self.dataset_name = dataset_name
self.csv_path = self.create_csv() if get_rank() == 0:
self.csv_path = self.create_csv()
def process(self, data_batch, data_samples: Sequence[dict]) -> None: def process(self, data_batch, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples. """Process one batch of data samples.
@ -126,7 +128,7 @@ class ShapeBiasMetric(BaseMetric):
os.remove(csv_path) os.remove(csv_path)
directory = osp.dirname(csv_path) directory = osp.dirname(csv_path)
if not osp.exists(directory): if not osp.exists(directory):
os.makedirs(directory) os.makedirs(directory, exist_ok=True)
with open(csv_path, 'w') as f: with open(csv_path, 'w') as f:
writer = csv.writer(f) writer = csv.writer(f)
writer.writerow([ writer.writerow([
@ -161,7 +163,8 @@ class ShapeBiasMetric(BaseMetric):
Returns: Returns:
dict: A dict of metrics. dict: A dict of metrics.
""" """
self.dump_results_to_csv(results) if get_rank() == 0:
self.dump_results_to_csv(results)
metrics = dict() metrics = dict()
metrics['accuracy/top1'] = np.mean([ metrics['accuracy/top1'] = np.mean([
result['pred_category'][0] == result['gt_category'] result['pred_category'][0] == result['gt_category']

View File

@ -8,6 +8,7 @@ import matplotlib as mpl
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from mmengine.logging import MMLogger
from utils import FormatStrFormatter, ShapeBias from utils import FormatStrFormatter, ShapeBias
# global default boundary settings for thin gray transparent # global default boundary settings for thin gray transparent
@ -46,12 +47,24 @@ parser.add_argument(
help= # noqa help= # noqa
'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501 'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501
) )
parser.add_argument(
'--delete-icons',
action='store_true',
help='whether to delete the icons after plotting')
humans = [ humans = [
'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05', 'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10' 'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
] ]
icon_names = [
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png',
'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
'knife.png', 'response_icons_vertical.png', 'truck.png'
]
def read_csvs(csv_dir: str) -> pd.DataFrame: def read_csvs(csv_dir: str) -> pd.DataFrame:
"""Reads all csv files in a directory and returns a single dataframe. """Reads all csv files in a directory and returns a single dataframe.
@ -83,7 +96,6 @@ def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
mpl.rcParams['font.serif'] = ['Times New Roman'] mpl.rcParams['font.serif'] = ['Times New Roman']
plt.figure(figsize=(9, 7)) plt.figure(figsize=(9, 7))
df = read_csvs(args.csv_dir) df = read_csvs(args.csv_dir)
fontsize = 15 fontsize = 15
@ -218,19 +230,27 @@ def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
plt.close() plt.close()
if __name__ == '__main__': def check_icons() -> bool:
icon_names = [ """Check if icons are present, if not download them."""
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png', if not osp.exists(ICONS_DIR):
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', return False
'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
'knife.png', 'response_icons_vertical.png', 'truck.png'
]
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
os.makedirs(ICONS_DIR, exist_ok=True)
for icon_name in icon_names: for icon_name in icon_names:
url = osp.join(root_url, icon_name) if not osp.exists(osp.join(ICONS_DIR, icon_name)):
os.system('wget -O {} {}'.format(osp.join(ICONS_DIR, icon_name), url)) return False
return True
if __name__ == '__main__':
if not check_icons():
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
os.makedirs(ICONS_DIR, exist_ok=True)
MMLogger.get_current_instance().info(
f'Downloading icons to {ICONS_DIR}')
for icon_name in icon_names:
url = osp.join(root_url, icon_name)
os.system('wget -O {} {}'.format(
osp.join(ICONS_DIR, icon_name), url))
args = parser.parse_args() args = parser.parse_args()
assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \ assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
@ -260,5 +280,5 @@ if __name__ == '__main__':
args.plotting_names.append('Humans') args.plotting_names.append('Humans')
plot_shape_bias_matrixplot(args) plot_shape_bias_matrixplot(args)
if args.delete_icons:
os.system('rm -rf {}'.format(ICONS_DIR)) os.system('rm -rf {}'.format(ICONS_DIR))